Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-07 04:54:47 +08:00)

Compare commits: add-attn-m ... cogvideox/ (10 commits)
| SHA1 |
|---|
| 43ec0bd224 |
| 9aa2e97e69 |
| e07fe043a8 |
| cce65ab93f |
| 052eeb5da9 |
| c9454bd4f1 |
| 2e7502f810 |
| 17b7f8ab8e |
| a012fa5748 |
| 6e03e72cff |
@@ -1054,9 +1054,8 @@ class FreeNoiseTransformerBlock(nn.Module):
|
||||
accumulated_values = torch.zeros_like(hidden_states)
|
||||
|
||||
for i, (frame_start, frame_end) in enumerate(frame_indices):
|
||||
# The reason for slicing here is to ensure that if (frame_end - frame_start) is to handle
|
||||
# cases like frame_indices=[(0, 16), (16, 20)], if the user provided a video with 19 frames, or
|
||||
# essentially a non-multiple of `context_length`.
|
||||
# The reason for slicing here is to handle cases like frame_indices=[(0, 16), (16, 20)],
|
||||
# if the user provided a video with 19 frames, or essentially a non-multiple of `context_length`.
|
||||
weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
|
||||
weights *= frame_weights
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -354,9 +354,12 @@ class CogVideoXPatchEmbed(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.in_channels = in_channels
|
||||
self.embed_dim = embed_dim
|
||||
self.sample_height = sample_height
|
||||
self.text_embed_dim = text_embed_dim
|
||||
self.bias = bias
|
||||
self.sample_width = sample_width
|
||||
self.sample_height = sample_height
|
||||
self.sample_frames = sample_frames
|
||||
self.temporal_compression_ratio = temporal_compression_ratio
|
||||
self.max_text_seq_length = max_text_seq_length
|
||||
@@ -377,7 +380,6 @@ class CogVideoXPatchEmbed(nn.Module):
|
||||
post_patch_height = sample_height // self.patch_size
|
||||
post_patch_width = sample_width // self.patch_size
|
||||
post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
|
||||
num_patches = post_patch_height * post_patch_width * post_time_compression_frames
|
||||
|
||||
pos_embedding = get_3d_sincos_pos_embed(
|
||||
self.embed_dim,
|
||||
@@ -387,12 +389,7 @@ class CogVideoXPatchEmbed(nn.Module):
|
||||
self.temporal_interpolation_scale,
|
||||
)
|
||||
pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
|
||||
joint_pos_embedding = torch.zeros(
|
||||
1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
|
||||
)
|
||||
joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
|
||||
|
||||
return joint_pos_embedding
|
||||
return pos_embedding
|
||||
|
||||
def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
|
||||
r"""
|
||||
@@ -409,11 +406,7 @@ class CogVideoXPatchEmbed(nn.Module):
|
||||
image_embeds = self.proj(image_embeds)
|
||||
image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
|
||||
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
|
||||
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
|
||||
|
||||
embeds = torch.cat(
|
||||
[text_embeds, image_embeds], dim=1
|
||||
).contiguous() # [batch, seq_length + num_frames x height x width, channels]
|
||||
image_embeds = image_embeds.flatten(1, 2).contiguous() # [batch, num_frames x height x width, channels]
|
||||
|
||||
if self.use_positional_embeddings:
|
||||
pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
|
||||
@@ -423,13 +416,114 @@ class CogVideoXPatchEmbed(nn.Module):
|
||||
or self.sample_frames != pre_time_compression_frames
|
||||
):
|
||||
pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
|
||||
pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
|
||||
pos_embedding = pos_embedding.to(image_embeds.device, dtype=image_embeds.dtype)
|
||||
else:
|
||||
pos_embedding = self.pos_embedding
|
||||
|
||||
embeds = embeds + pos_embedding
|
||||
image_embeds = image_embeds + pos_embedding
|
||||
|
||||
return embeds
|
||||
return text_embeds, image_embeds
|
||||
|
||||
|
||||
class FreeNoiseCogVideoXPatchEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 16,
|
||||
embed_dim: int = 1920,
|
||||
text_embed_dim: int = 4096,
|
||||
bias: bool = True,
|
||||
sample_width: int = 90,
|
||||
sample_height: int = 60,
|
||||
sample_frames: int = 49,
|
||||
temporal_compression_ratio: int = 4,
|
||||
max_text_seq_length: int = 226,
|
||||
spatial_interpolation_scale: float = 1.875,
|
||||
temporal_interpolation_scale: float = 1.0,
|
||||
use_positional_embeddings: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.in_channels = in_channels
|
||||
self.embed_dim = embed_dim
|
||||
self.text_embed_dim = text_embed_dim
|
||||
self.bias = bias
|
||||
self.sample_width = sample_width
|
||||
self.sample_height = sample_height
|
||||
self.sample_frames = sample_frames
|
||||
self.temporal_compression_ratio = temporal_compression_ratio
|
||||
self.max_text_seq_length = max_text_seq_length
|
||||
self.spatial_interpolation_scale = spatial_interpolation_scale
|
||||
self.temporal_interpolation_scale = temporal_interpolation_scale
|
||||
self.use_positional_embeddings = use_positional_embeddings
|
||||
|
||||
self.proj = nn.Conv2d(
|
||||
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
|
||||
)
|
||||
self.text_proj = nn.Linear(text_embed_dim, embed_dim)
|
||||
|
||||
if use_positional_embeddings:
|
||||
pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
|
||||
self.register_buffer("pos_embedding", pos_embedding, persistent=False)
|
||||
|
||||
# Copied from diffusers.models.embeddings.CogVideoXPatchEmbed
|
||||
def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
|
||||
post_patch_height = sample_height // self.patch_size
|
||||
post_patch_width = sample_width // self.patch_size
|
||||
post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
|
||||
|
||||
pos_embedding = get_3d_sincos_pos_embed(
|
||||
self.embed_dim,
|
||||
(post_patch_width, post_patch_height),
|
||||
post_time_compression_frames,
|
||||
self.spatial_interpolation_scale,
|
||||
self.temporal_interpolation_scale,
|
||||
)
|
||||
pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
|
||||
return pos_embedding
|
||||
|
||||
def forward(self, text_embeds: Union[torch.Tensor, Tuple[Dict[int, torch.Tensor]]], image_embeds: torch.Tensor):
|
||||
r"""
|
||||
Args:
|
||||
text_embeds (`torch.Tensor`):
|
||||
Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
|
||||
image_embeds (`torch.Tensor`):
|
||||
Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
|
||||
"""
|
||||
if isinstance(text_embeds, torch.Tensor):
|
||||
text_embeds = self.text_proj(text_embeds)
|
||||
else:
|
||||
assert isinstance(text_embeds, tuple)
|
||||
text_embeds_output = []
|
||||
for tuple_index in range(len(text_embeds)):
|
||||
text_embeds_output.append({})
|
||||
for key, text_embed in list(text_embeds[tuple_index].items()):
|
||||
text_embeds_output[tuple_index][key] = self.text_proj(text_embed)
|
||||
text_embeds = tuple(text_embeds_output)
|
||||
|
||||
batch, num_frames, channels, height, width = image_embeds.shape
|
||||
image_embeds = image_embeds.reshape(-1, channels, height, width)
|
||||
image_embeds = self.proj(image_embeds)
|
||||
image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
|
||||
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
|
||||
image_embeds = image_embeds.flatten(1, 2).contiguous() # [batch, num_frames x height x width, channels]
|
||||
|
||||
if self.use_positional_embeddings:
|
||||
pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
|
||||
if (
|
||||
self.sample_height != height
|
||||
or self.sample_width != width
|
||||
or self.sample_frames != pre_time_compression_frames
|
||||
):
|
||||
pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
|
||||
pos_embedding = pos_embedding.to(image_embeds.device, dtype=image_embeds.dtype)
|
||||
else:
|
||||
pos_embedding = self.pos_embedding
|
||||
|
||||
image_embeds = image_embeds + pos_embedding
|
||||
|
||||
return text_embeds, image_embeds
|
||||
|
||||
|
||||
def get_3d_rotary_pos_embed(
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -86,6 +86,20 @@ class CogVideoXBlock(nn.Module):
|
||||
attention_out_bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.time_embed_dim = time_embed_dim
|
||||
self.dropout = dropout
|
||||
self.activation_fn = activation_fn
|
||||
self.attention_bias = attention_bias
|
||||
self.qk_norm = qk_norm
|
||||
self.norm_elementwise_affine = norm_elementwise_affine
|
||||
self.norm_eps = norm_eps
|
||||
self.final_dropout = final_dropout
|
||||
self.ff_inner_dim = ff_inner_dim
|
||||
self.ff_bias = ff_bias
|
||||
self.attention_out_bias = attention_out_bias
|
||||
|
||||
# 1. Self Attention
|
||||
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
|
||||
@@ -119,6 +133,7 @@ class CogVideoXBlock(nn.Module):
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
num_frames: int = None,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
@@ -152,6 +167,410 @@ class CogVideoXBlock(nn.Module):
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
# norm_final is just nn.LayerNorm; all ops are on the channel dimension, so we don't have to care about the frame dimension
# proj_out is also applied only along the channel dimension
# norm_out is just a linear layer and nn.LayerNorm, again only on the channel dimension, so frame dims don't matter
# same story with norm1, norm2 and ff
# the patch embed layer also applies only on the channel dim and condenses to [B, FHW, C]
# only the attention layer actually does anything with the frame dimension, so it is the only place where FreeNoise needs to be applied
# Since it does not matter for norm1, norm2 and ff, and they might create a memory bottleneck, just use the FreeNoise frame split on them too
|
||||
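
# A minimal standalone sketch (not part of this PR) of the windowing idea described above:
# process overlapping frame windows, weight each window's output, accumulate, and divide by
# the per-frame visit count. Names here (`free_noise_average`, `process_window`) are
# hypothetical stand-ins; the real block runs attention per window instead of an identity.
import torch


def free_noise_average(frames: torch.Tensor, context_length: int, context_stride: int) -> torch.Tensor:
    # frames: [B, F, C] toy layout (the block below works on [B, F, HW, C]).
    def process_window(x: torch.Tensor) -> torch.Tensor:
        return x  # stand-in for the per-window transformer computation

    num_frames = frames.size(1)
    window_weights = torch.ones(1, context_length, 1, device=frames.device)  # "flat" scheme
    accumulated = torch.zeros_like(frames)
    counts = torch.zeros(1, num_frames, 1, device=frames.device)

    for start in range(0, num_frames - context_length + 1, context_stride):
        end = start + context_length
        accumulated[:, start:end] += process_window(frames[:, start:end]) * window_weights
        counts[:, start:end] += window_weights

    # Frames covered by several windows are averaged; a real implementation also handles a
    # trailing partial window when num_frames is not a multiple of context_length.
    return torch.where(counts > 0, accumulated / counts, accumulated)
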
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class FreeNoiseCogVideoXBlock(nn.Module):
|
||||
r"""
|
||||
FreeNoise block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
|
||||
|
||||
Parameters:
|
||||
dim (`int`):
|
||||
The number of channels in the input and output.
|
||||
num_attention_heads (`int`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`):
|
||||
The number of channels in each head.
|
||||
time_embed_dim (`int`):
|
||||
The number of channels in timestep embedding.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
The dropout probability to use.
|
||||
activation_fn (`str`, defaults to `"gelu-approximate"`):
|
||||
Activation function to be used in feed-forward.
|
||||
attention_bias (`bool`, defaults to `False`):
|
||||
Whether or not to use bias in attention projection layers.
|
||||
qk_norm (`bool`, defaults to `True`):
|
||||
Whether or not to use normalization after query and key projections in Attention.
|
||||
norm_elementwise_affine (`bool`, defaults to `True`):
|
||||
Whether to use learnable elementwise affine parameters for normalization.
|
||||
norm_eps (`float`, defaults to `1e-5`):
|
||||
Epsilon value for normalization layers.
|
||||
final_dropout (`bool` defaults to `False`):
|
||||
Whether to apply a final dropout after the last feed-forward layer.
|
||||
ff_inner_dim (`int`, *optional*, defaults to `None`):
|
||||
Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
|
||||
ff_bias (`bool`, defaults to `True`):
|
||||
Whether or not to use bias in Feed-forward layer.
|
||||
attention_out_bias (`bool`, defaults to `True`):
|
||||
Whether or not to use bias in Attention output projection layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
time_embed_dim: int,
|
||||
dropout: float = 0.0,
|
||||
activation_fn: str = "gelu-approximate",
|
||||
attention_bias: bool = False,
|
||||
qk_norm: bool = True,
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_eps: float = 1e-5,
|
||||
final_dropout: bool = True,
|
||||
ff_inner_dim: Optional[int] = None,
|
||||
ff_bias: bool = True,
|
||||
attention_out_bias: bool = True,
|
||||
context_length: int = 16,
|
||||
context_stride: int = 4,
|
||||
weighting_scheme: str = "pyramid",
|
||||
prompt_interpolation_callback: Callable[[int, int, torch.Tensor, torch.Tensor], torch.Tensor] = None,
|
||||
prompt_pooling_callback: Callable[[List[torch.Tensor]], torch.Tensor] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.time_embed_dim = time_embed_dim
|
||||
self.dropout = dropout
|
||||
self.activation_fn = activation_fn
|
||||
self.attention_bias = attention_bias
|
||||
self.qk_norm = qk_norm
|
||||
self.norm_elementwise_affine = norm_elementwise_affine
|
||||
self.norm_eps = norm_eps
|
||||
self.final_dropout = final_dropout
|
||||
self.ff_inner_dim = ff_inner_dim
|
||||
self.ff_bias = ff_bias
|
||||
self.attention_out_bias = attention_out_bias
|
||||
|
||||
self.set_free_noise_properties(
|
||||
context_length, context_stride, weighting_scheme, prompt_interpolation_callback, prompt_pooling_callback
|
||||
)
|
||||
|
||||
# 1. Self Attention
|
||||
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
|
||||
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
qk_norm="layer_norm" if qk_norm else None,
|
||||
eps=1e-6,
|
||||
bias=attention_bias,
|
||||
out_bias=attention_out_bias,
|
||||
processor=CogVideoXAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
# 2. Feed Forward
|
||||
self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
|
||||
|
||||
self.ff = FeedForward(
|
||||
dim,
|
||||
dropout=dropout,
|
||||
activation_fn=activation_fn,
|
||||
final_dropout=final_dropout,
|
||||
inner_dim=ff_inner_dim,
|
||||
bias=ff_bias,
|
||||
)
|
||||
|
||||
# Copied from diffusers.models.attention.FreeNoiseTransformerBlock.set_free_noise_properties
|
||||
def set_free_noise_properties(
|
||||
self,
|
||||
context_length: int,
|
||||
context_stride: int,
|
||||
weighting_scheme: str = "pyramid",
|
||||
prompt_interpolation_callback: Callable[[int, int, torch.Tensor, torch.Tensor], torch.Tensor] = None,
|
||||
prompt_pooling_callback: Callable[[List[torch.Tensor]], torch.Tensor] = None,
|
||||
) -> None:
|
||||
if prompt_interpolation_callback is None:
|
||||
raise ValueError("Must pass a callback to interpolate between prompt embeddings.")
|
||||
if prompt_pooling_callback is None:
|
||||
raise ValueError("Must pass a callback to pool prompt embeddings.")
|
||||
self.context_length = context_length
|
||||
self.context_stride = context_stride
|
||||
self.weighting_scheme = weighting_scheme
|
||||
self.prompt_interpolation_callback = prompt_interpolation_callback
|
||||
self.prompt_pooling_callback = prompt_pooling_callback
|
||||
|
||||
# Copied from diffusers.models.attention.FreeNoiseTransformerBlock._get_frame_indices
|
||||
def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
|
||||
frame_indices = []
|
||||
for i in range(0, num_frames - self.context_length + 1, self.context_stride):
|
||||
window_start = i
|
||||
window_end = min(num_frames, i + self.context_length)
|
||||
frame_indices.append((window_start, window_end))
|
||||
return frame_indices
|
||||
|
||||
# Copied from diffusers.models.attention.FreeNoiseTransformerBlock._get_frame_weights
|
||||
def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
|
||||
if weighting_scheme == "flat":
|
||||
weights = [1.0] * num_frames
|
||||
|
||||
elif weighting_scheme == "pyramid":
|
||||
if num_frames % 2 == 0:
|
||||
# num_frames = 4 => [1, 2, 2, 1]
|
||||
mid = num_frames // 2
|
||||
weights = list(range(1, mid + 1))
|
||||
weights = weights + weights[::-1]
|
||||
else:
|
||||
# num_frames = 5 => [1, 2, 3, 2, 1]
|
||||
mid = (num_frames + 1) // 2
|
||||
weights = list(range(1, mid))
|
||||
weights = weights + [mid] + weights[::-1]
|
||||
|
||||
elif weighting_scheme == "delayed_reverse_sawtooth":
|
||||
if num_frames % 2 == 0:
|
||||
# num_frames = 4 => [0.01, 2, 2, 1]
|
||||
mid = num_frames // 2
|
||||
weights = [0.01] * (mid - 1) + [mid]
|
||||
weights = weights + list(range(mid, 0, -1))
|
||||
else:
|
||||
# num_frames = 5 => [0.01, 0.01, 3, 2, 1]
|
||||
mid = (num_frames + 1) // 2
|
||||
weights = [0.01] * (mid - 1)
|
||||
weights = weights + list(range(mid, 0, -1))
|
||||
else:
|
||||
raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
|
||||
|
||||
return weights
|
||||
|
||||
def _prepare_free_noise_encoder_hidden_states(
|
||||
self,
|
||||
encoder_hidden_states: Union[
|
||||
torch.Tensor, List[torch.Tensor], Tuple[Dict[int, torch.Tensor], Optional[Dict[int, torch.Tensor]]]
|
||||
],
|
||||
frame_indices: List[int],
|
||||
) -> List[torch.Tensor]:
|
||||
if torch.is_tensor(encoder_hidden_states):
|
||||
encoder_hidden_states = [encoder_hidden_states.clone() for _ in range(len(frame_indices))]
|
||||
|
||||
elif isinstance(encoder_hidden_states, tuple):
|
||||
print("frame_indices:", frame_indices)
|
||||
pooled_prompt_embeds_list = []
|
||||
pooled_negative_prompt_embeds_list = []
|
||||
negative_prompt_embeds_dict, prompt_embeds_dict = encoder_hidden_states
|
||||
last_frame_start = 0
|
||||
|
||||
# For every batch of frames that is to be processed, pool the positive and negative prompt embeddings.
|
||||
# TODO(aryan): Since this is experimental, I didn't try many different things. I found from testing
|
||||
# that pooling with previous batch frame embeddings is necessary to produce better results and helps with
|
||||
# prompt transitions.
|
||||
for frame_start, frame_end in frame_indices:
|
||||
pooled_prompt_embeds = None
|
||||
pooled_negative_prompt_embeds = None
|
||||
|
||||
pooling_list = [
|
||||
prompt_embeds_dict[i] for i in range(last_frame_start, frame_end) if i in prompt_embeds_dict
|
||||
]
|
||||
if len(pooling_list) > 0:
|
||||
print("pooling", [i for i in range(last_frame_start, frame_end) if i in prompt_embeds_dict])
|
||||
pooled_prompt_embeds = self.prompt_pooling_callback(pooling_list)
|
||||
print("after pooling:", pooled_prompt_embeds.isnan().any())
|
||||
|
||||
if negative_prompt_embeds_dict is not None:
|
||||
pooling_list = [
|
||||
negative_prompt_embeds_dict[i]
|
||||
for i in range(last_frame_start, frame_end)
|
||||
if i in negative_prompt_embeds_dict
|
||||
]
|
||||
if len(pooling_list) > 0:
|
||||
print(
|
||||
"negative pooling", [i for i in range(last_frame_start, frame_end) if i in prompt_embeds_dict]
|
||||
)
|
||||
pooled_negative_prompt_embeds = self.prompt_pooling_callback(pooling_list)
|
||||
print("after negative pooling:", pooled_negative_prompt_embeds.isnan().any())
|
||||
|
||||
pooled_prompt_embeds_list.append(pooled_prompt_embeds)
|
||||
pooled_negative_prompt_embeds_list.append(pooled_negative_prompt_embeds)
|
||||
last_frame_start = frame_start
|
||||
|
||||
assert pooled_prompt_embeds_list[0] is not None
|
||||
assert pooled_prompt_embeds_list[-1] is not None
|
||||
if negative_prompt_embeds_dict is not None:
|
||||
assert pooled_negative_prompt_embeds_list[0] is not None
|
||||
assert pooled_negative_prompt_embeds_list[-1] is not None
|
||||
|
||||
# If there were no relevant prompts for certain frame batches, interpolate and fill in the gaps
|
||||
last_existent_embed_index = 0
|
||||
for i in range(1, len(frame_indices)):
|
||||
if pooled_prompt_embeds_list[i] is not None and i - last_existent_embed_index > 1:
|
||||
print("interpolating:", last_existent_embed_index, i)
|
||||
interpolated_embeds = self.prompt_interpolation_callback(
|
||||
last_existent_embed_index,
|
||||
i,
|
||||
pooled_prompt_embeds_list[last_existent_embed_index],
|
||||
pooled_prompt_embeds_list[i],
|
||||
)
|
||||
print("after interpolating", interpolated_embeds.isnan().any())
|
||||
pooled_prompt_embeds_list[last_existent_embed_index : i + 1] = interpolated_embeds.split(1, dim=0)
|
||||
last_existent_embed_index = i
|
||||
assert all(x is not None for x in pooled_prompt_embeds_list)
|
||||
|
||||
if negative_prompt_embeds_dict is not None:
|
||||
last_existent_embed_index = 0
|
||||
for i in range(1, len(frame_indices)):
|
||||
if pooled_negative_prompt_embeds_list[i] is not None and i - last_existent_embed_index > 1:
|
||||
print("negative interpolating:", last_existent_embed_index, i)
|
||||
interpolated_embeds = self.prompt_interpolation_callback(
|
||||
last_existent_embed_index,
|
||||
i,
|
||||
pooled_negative_prompt_embeds_list[last_existent_embed_index],
|
||||
pooled_negative_prompt_embeds_list[i],
|
||||
)
|
||||
print("after negative interpolating", interpolated_embeds.isnan().any())
|
||||
pooled_negative_prompt_embeds_list[
|
||||
last_existent_embed_index : i + 1
|
||||
] = interpolated_embeds.split(1, dim=0)
|
||||
last_existent_embed_index = i
|
||||
assert all(x is not None for x in pooled_negative_prompt_embeds_list)
|
||||
|
||||
if negative_prompt_embeds_dict is not None:
|
||||
# Classifier-Free Guidance
|
||||
pooled_prompt_embeds_list = [
|
||||
torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
for negative_prompt_embeds, prompt_embeds in zip(
|
||||
pooled_negative_prompt_embeds_list, pooled_prompt_embeds_list
|
||||
)
|
||||
]
|
||||
|
||||
encoder_hidden_states = pooled_prompt_embeds_list
|
||||
|
||||
elif not isinstance(encoder_hidden_states, list):
|
||||
raise ValueError(
|
||||
f"Expected `encoder_hidden_states` to be a tensor, list of tensor, or a tuple of dictionaries, but found {type(encoder_hidden_states)=}"
|
||||
)
|
||||
|
||||
assert isinstance(encoder_hidden_states, list) and len(encoder_hidden_states) == len(frame_indices)
|
||||
return encoder_hidden_states
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
num_frames: int = None,
|
||||
) -> torch.Tensor:
|
||||
# hidden_states: [B, F x H x W, C]
|
||||
device = hidden_states.device
|
||||
dtype = hidden_states.dtype
|
||||
|
||||
frame_indices = self._get_frame_indices(num_frames)
|
||||
frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme)
|
||||
frame_weights = (
|
||||
torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
||||
)
|
||||
is_last_frame_batch_complete = frame_indices[-1][1] == num_frames
|
||||
|
||||
# Handle out-of-bounds case if num_frames isn't perfectly divisible by context_length
|
||||
# For example, num_frames=25, context_length=16, context_stride=4, then we expect the ranges:
|
||||
# [(0, 16), (4, 20), (8, 24), (9, 25)]
|
||||
if not is_last_frame_batch_complete:
|
||||
if num_frames < self.context_length:
|
||||
raise ValueError(f"Expected {num_frames=} to be greater or equal than {self.context_length=}")
|
||||
last_frame_batch_length = num_frames - frame_indices[-1][1]
|
||||
frame_indices.append((num_frames - self.context_length, num_frames))
|
||||
|
||||
# Unflatten frame dimension: [B, F, HW, C]
|
||||
batch_size, frames_height_width, channels = hidden_states.shape
|
||||
hidden_states = hidden_states.reshape(batch_size, num_frames, frames_height_width // num_frames, channels)
|
||||
encoder_hidden_states = self._prepare_free_noise_encoder_hidden_states(encoder_hidden_states, frame_indices)
|
||||
|
||||
num_times_accumulated = torch.zeros((1, num_frames, 1, 1), device=device)
|
||||
accumulated_values = torch.zeros_like(hidden_states)
|
||||
|
||||
text_seq_length = _get_text_seq_length(encoder_hidden_states)
|
||||
|
||||
for i, (frame_start, frame_end) in enumerate(frame_indices):
|
||||
# The reason for slicing here is to handle cases like frame_indices=[(0, 16), (16, 20)],
|
||||
# if the user provided a video with 19 frames, or essentially a non-multiple of `context_length`.
|
||||
weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
|
||||
weights *= frame_weights
|
||||
|
||||
# Flatten frame dimension: [B, F'HW, C]
|
||||
hidden_states_chunk = hidden_states[:, frame_start:frame_end].flatten(1, 2)
|
||||
print(
|
||||
"debug:",
|
||||
text_seq_length,
|
||||
torch.isnan(hidden_states_chunk).any(),
|
||||
torch.isnan(encoder_hidden_states[i]).any(),
|
||||
)
|
||||
|
||||
# norm & modulate
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
|
||||
hidden_states_chunk, encoder_hidden_states[i], temb
|
||||
)
|
||||
|
||||
# attention
|
||||
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
hidden_states_chunk = hidden_states_chunk + gate_msa * attn_hidden_states
|
||||
encoder_hidden_states[i] = encoder_hidden_states[i] + enc_gate_msa * attn_encoder_hidden_states
|
||||
|
||||
# norm & modulate
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
|
||||
hidden_states_chunk, encoder_hidden_states[i], temb
|
||||
)
|
||||
|
||||
# feed-forward
|
||||
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
hidden_states_chunk = hidden_states_chunk + gate_ff * ff_output[:, text_seq_length:]
|
||||
encoder_hidden_states[i] = encoder_hidden_states[i] + enc_gate_ff * ff_output[:, :text_seq_length]
|
||||
|
||||
# Unflatten frame dimension: [B, F', HW, C]
|
||||
_num_frames = frame_end - frame_start
|
||||
hidden_states_chunk = hidden_states_chunk.reshape(batch_size, _num_frames, -1, channels)
|
||||
|
||||
if i == len(frame_indices) - 1 and not is_last_frame_batch_complete:
|
||||
accumulated_values[:, -last_frame_batch_length:] += (
|
||||
hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:]
|
||||
)
|
||||
num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length:]
|
||||
else:
|
||||
accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
|
||||
num_times_accumulated[:, frame_start:frame_end] += weights
|
||||
|
||||
# TODO(aryan): Maybe this could be done in a better way.
|
||||
#
|
||||
# Previously, this was:
|
||||
# hidden_states = torch.where(
|
||||
# num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
|
||||
# )
|
||||
#
|
||||
# The reasoning for the change here is `torch.where` became a bottleneck at some point when golfing memory
|
||||
# spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes
|
||||
# from tensors being copied - which is why we resort to splitting and concatenating here. I've not particularly
|
||||
# looked into this deeply because other memory optimizations led to more pronounced reductions.
|
||||
hidden_states = torch.cat(
|
||||
[
|
||||
torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
|
||||
for accumulated_split, num_times_split in zip(
|
||||
accumulated_values.split(self.context_length, dim=1),
|
||||
num_times_accumulated.split(self.context_length, dim=1),
|
||||
)
|
||||
],
|
||||
dim=1,
|
||||
).to(dtype)
|
||||
|
||||
# Flatten frame dimension: [B, FHW, C]
|
||||
hidden_states = hidden_states.flatten(1, 2)
|
||||
|
||||
return hidden_states, encoder_hidden_states
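
# Standalone sketch (not part of this PR) of the chunked normalization used at the end of the
# forward above: splitting along the frame dimension and applying `torch.where` per chunk gives
# the same result as one big `torch.where`, while keeping the temporaries small. The names
# below (`normalize_accumulated`, `accumulated`, `counts`) are hypothetical.
import torch


def normalize_accumulated(accumulated: torch.Tensor, counts: torch.Tensor, context_length: int) -> torch.Tensor:
    return torch.cat(
        [
            torch.where(count_chunk > 0, value_chunk / count_chunk, value_chunk)
            for value_chunk, count_chunk in zip(
                accumulated.split(context_length, dim=1), counts.split(context_length, dim=1)
            )
        ],
        dim=1,
    )


def _example_chunked_normalization() -> None:
    accumulated = torch.randn(1, 13, 4, 8)  # toy [B, F, HW, C]
    counts = torch.randint(0, 3, (1, 13, 1, 1)).float()
    chunked = normalize_accumulated(accumulated, counts, context_length=4)
    reference = torch.where(counts > 0, accumulated / counts, accumulated)
    assert torch.allclose(chunked, reference)
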
|
||||
|
||||
|
||||
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
|
||||
@@ -417,13 +836,10 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
|
||||
# 2. Patch embedding
|
||||
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
|
||||
text_seq_length = _get_text_seq_length(encoder_hidden_states)
|
||||
encoder_hidden_states, hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
|
||||
hidden_states = self.embedding_dropout(hidden_states)
|
||||
|
||||
text_seq_length = encoder_hidden_states.shape[1]
|
||||
encoder_hidden_states = hidden_states[:, :text_seq_length]
|
||||
hidden_states = hidden_states[:, text_seq_length:]
|
||||
|
||||
# 3. Transformer blocks
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
@@ -441,6 +857,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
encoder_hidden_states,
|
||||
emb,
|
||||
image_rotary_emb,
|
||||
num_frames,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
@@ -449,6 +866,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=emb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
num_frames=num_frames,
|
||||
)
|
||||
|
||||
if not self.config.use_rotary_positional_embeddings:
|
||||
@@ -472,3 +890,13 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
|
||||
|
||||
def _get_text_seq_length(x) -> int:
|
||||
if isinstance(x, torch.Tensor):
|
||||
return x.shape[1]
|
||||
if isinstance(x, list):
|
||||
return _get_text_seq_length(x[0])
|
||||
if isinstance(x, dict):
|
||||
return _get_text_seq_length(next(iter(x.values())))
|
||||
return None
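
# Quick check (not part of this PR) of the helper above for the three accepted layouts:
# a plain tensor, a per-window list of tensors, and a {frame_index: embedding} dict.
def _example_text_seq_length() -> None:
    text_embeds = torch.zeros(2, 226, 1920)
    assert _get_text_seq_length(text_embeds) == 226
    assert _get_text_seq_length([text_embeds, text_embeds]) == 226
    assert _get_text_seq_length({0: text_embeds, 5: text_embeds}) == 226
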
|
||||
|
||||
@@ -28,6 +28,7 @@ from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
|
||||
from ...utils import logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..free_noise_utils import CogVideoXFreeNoiseMixin
|
||||
from .pipeline_output import CogVideoXPipelineOutput
|
||||
|
||||
|
||||
@@ -136,7 +137,7 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class CogVideoXPipeline(DiffusionPipeline):
|
||||
class CogVideoXPipeline(DiffusionPipeline, CogVideoXFreeNoiseMixin):
|
||||
r"""
|
||||
Pipeline for text-to-video generation using CogVideoX.
|
||||
|
||||
@@ -316,9 +317,17 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
def prepare_latents(
|
||||
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
||||
):
|
||||
num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||
|
||||
# If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169)
|
||||
if self.free_noise_enabled:
|
||||
latents = self._prepare_latents_free_noise(
|
||||
batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents
|
||||
)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
|
||||
num_frames,
|
||||
num_channels_latents,
|
||||
height // self.vae_scale_factor_spatial,
|
||||
width // self.vae_scale_factor_spatial,
|
||||
@@ -392,8 +401,8 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
elif prompt is not None and not isinstance(prompt, (str, list, dict)):
|
||||
raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)=}")
|
||||
|
||||
if prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
@@ -470,7 +479,7 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt: Optional[Union[str, List[str], Dict[int, str]]] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: int = 480,
|
||||
width: int = 720,
|
||||
@@ -497,7 +506,7 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt (`str` or `List[str]` or `Dict[int, str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
@@ -569,11 +578,6 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
if num_frames > 49:
|
||||
raise ValueError(
|
||||
"The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
|
||||
)
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
@@ -595,7 +599,7 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Default call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
if prompt is not None and isinstance(prompt, (str, dict)):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
@@ -610,18 +614,33 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
if self.free_noise_enabled:
|
||||
prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
|
||||
prompt=prompt,
|
||||
num_frames=num_frames,
|
||||
device=device,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
)
|
||||
prompt_embeds_dtype = next(iter(prompt_embeds[0].values())).dtype
|
||||
else:
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
prompt_embeds_dtype = prompt_embeds.dtype
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
@@ -635,7 +654,7 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
num_frames,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
prompt_embeds_dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
@@ -699,7 +718,7 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
**extra_step_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
latents = latents.to(prompt_embeds.dtype)
|
||||
latents = latents.to(prompt_embeds_dtype)
|
||||
|
||||
# call the callback, if provided
|
||||
if callback_on_step_end is not None:
|
||||
|
||||
@@ -12,13 +12,21 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import collections
|
||||
import functools
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..models.attention import BasicTransformerBlock, FreeNoiseTransformerBlock
|
||||
from ..models.embeddings import CogVideoXPatchEmbed, FreeNoiseCogVideoXPatchEmbed
|
||||
from ..models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
|
||||
from ..models.transformers.cogvideox_transformer_3d import (
|
||||
CogVideoXBlock,
|
||||
CogVideoXTransformer3DModel,
|
||||
FreeNoiseCogVideoXBlock,
|
||||
)
|
||||
from ..models.transformers.transformer_2d import Transformer2DModel
|
||||
from ..models.unets.unet_motion_model import (
|
||||
AnimateDiffTransformer3D,
|
||||
@@ -26,7 +34,6 @@ from ..models.unets.unet_motion_model import (
|
||||
DownBlockMotion,
|
||||
UpBlockMotion,
|
||||
)
|
||||
from ..pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ..utils import logging
|
||||
from ..utils.torch_utils import randn_tensor
|
||||
|
||||
@@ -34,6 +41,56 @@ from ..utils.torch_utils import randn_tensor
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def _lerp(start_index: int, end_index: int, start_tensor: torch.Tensor, end_tensor: torch.Tensor) -> torch.Tensor:
|
||||
num_indices = end_index - start_index + 1
|
||||
interpolated_tensors = []
|
||||
|
||||
for i in range(num_indices):
|
||||
alpha = i / (num_indices - 1)
|
||||
interpolated_tensor = (1 - alpha) * start_tensor + alpha * end_tensor
|
||||
interpolated_tensors.append(interpolated_tensor)
|
||||
|
||||
interpolated_tensors = torch.cat(interpolated_tensors)
|
||||
return interpolated_tensors
|
||||
|
||||
|
||||
def _weighted_pooling(
|
||||
hidden_states: List[torch.Tensor], pooling_type: str = "inverse_distance", decay_rate: float = 0.9
|
||||
) -> torch.Tensor:
|
||||
length = len(hidden_states)
|
||||
|
||||
if length == 0:
|
||||
raise ValueError("This method cannot be called with an empty list.")
|
||||
if length == 1:
|
||||
return hidden_states[0]
|
||||
|
||||
if pooling_type == "average":
|
||||
weights = [1] * length
|
||||
elif pooling_type == "exponential_decay":
|
||||
weights = [decay_rate**i for i in range(length)]
|
||||
elif pooling_type == "reverse_exponential_decay":
|
||||
weights = [decay_rate**i for i in range(length)][::-1]
|
||||
elif pooling_type == "distance":
|
||||
weights = list(range(1, length + 1))
|
||||
elif pooling_type == "reverse_distance":
|
||||
weights = list(range(1, length + 1))[::-1]
|
||||
elif pooling_type == "pyramid":
|
||||
if length % 2 == 0:
|
||||
weights = list(range(1, length // 2 + 1))
|
||||
weights = weights + weights[::-1]
|
||||
else:
|
||||
weights = list(range(1, length // 2 + 1))
|
||||
weights = weights + [length // 2] + weights[::-1]
|
||||
else:
|
||||
raise ValueError("Invalid weighted pooling type.")
|
||||
|
||||
weights = torch.tensor(weights, device=hidden_states[0].device, dtype=hidden_states[0].dtype)
|
||||
weights = weights / weights.sum()
|
||||
|
||||
pooled_embeds = sum(weight * hidden_state for weight, hidden_state in zip(weights, hidden_states))
|
||||
return pooled_embeds
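
# Quick usage sketch (not part of this PR) for the two helpers above, with toy shapes that
# mimic CogVideoX prompt embeddings ([num_videos, seq_len, embed_dim]).
def _example_prompt_helpers() -> None:
    start = torch.randn(1, 226, 4096)
    end = torch.randn(1, 226, 4096)

    # `_lerp` fills the gap between window indices 2 and 5 with a linear blend, start -> end.
    filled = _lerp(2, 5, start, end)
    assert filled.shape == (4, 226, 4096)

    # `_weighted_pooling` merges embeddings that land in the same latent-frame bucket.
    pooled = _weighted_pooling([start, end], pooling_type="pyramid")
    assert pooled.shape == start.shape
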
|
||||
|
||||
|
||||
class SplitInferenceModule(nn.Module):
|
||||
r"""
|
||||
A wrapper module class that splits inputs along a specified dimension before performing a forward pass.
|
||||
@@ -143,7 +200,7 @@ class SplitInferenceModule(nn.Module):
|
||||
|
||||
|
||||
class AnimateDiffFreeNoiseMixin:
|
||||
r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169)."""
|
||||
r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169) as used in AnimateDiff."""
|
||||
|
||||
def _enable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]):
|
||||
r"""Helper function to enable FreeNoise in transformer blocks."""
|
||||
@@ -265,7 +322,7 @@ class AnimateDiffFreeNoiseMixin:
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
if negative_prompt is None:
|
||||
negative_prompt = ""
|
||||
|
||||
@@ -427,29 +484,13 @@ class AnimateDiffFreeNoiseMixin:
|
||||
latents = latents[:, :, :num_frames]
|
||||
return latents
|
||||
|
||||
def _lerp(
|
||||
self, start_index: int, end_index: int, start_tensor: torch.Tensor, end_tensor: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
num_indices = end_index - start_index + 1
|
||||
interpolated_tensors = []
|
||||
|
||||
for i in range(num_indices):
|
||||
alpha = i / (num_indices - 1)
|
||||
interpolated_tensor = (1 - alpha) * start_tensor + alpha * end_tensor
|
||||
interpolated_tensors.append(interpolated_tensor)
|
||||
|
||||
interpolated_tensors = torch.cat(interpolated_tensors)
|
||||
return interpolated_tensors
|
||||
|
||||
def enable_free_noise(
|
||||
self,
|
||||
context_length: Optional[int] = 16,
|
||||
context_stride: int = 4,
|
||||
weighting_scheme: str = "pyramid",
|
||||
noise_type: str = "shuffle_context",
|
||||
prompt_interpolation_callback: Optional[
|
||||
Callable[[DiffusionPipeline, int, int, torch.Tensor, torch.Tensor], torch.Tensor]
|
||||
] = None,
|
||||
prompt_interpolation_callback: Optional[Callable[[int, int, torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
) -> None:
|
||||
r"""
|
||||
Enable long video generation using FreeNoise.
|
||||
@@ -505,7 +546,7 @@ class AnimateDiffFreeNoiseMixin:
|
||||
self._free_noise_context_stride = context_stride
|
||||
self._free_noise_weighting_scheme = weighting_scheme
|
||||
self._free_noise_noise_type = noise_type
|
||||
self._free_noise_prompt_interpolation_callback = prompt_interpolation_callback or self._lerp
|
||||
self._free_noise_prompt_interpolation_callback = prompt_interpolation_callback or _lerp
|
||||
|
||||
if hasattr(self.unet.mid_block, "motion_modules"):
|
||||
blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
|
||||
@@ -594,3 +635,402 @@ class AnimateDiffFreeNoiseMixin:
|
||||
@property
|
||||
def free_noise_enabled(self):
|
||||
return hasattr(self, "_free_noise_context_length") and self._free_noise_context_length is not None
|
||||
|
||||
|
||||
class CogVideoXFreeNoiseMixin:
|
||||
r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169) as used in CogVideoX."""
|
||||
|
||||
def _enable_free_noise_in_block(self, transformer: CogVideoXTransformer3DModel):
|
||||
r"""Helper function to enable FreeNoise in transformer blocks."""
|
||||
patch_embed = transformer.patch_embed
|
||||
transformer.patch_embed = FreeNoiseCogVideoXPatchEmbed(
|
||||
patch_size=patch_embed.patch_size,
|
||||
in_channels=patch_embed.in_channels,
|
||||
embed_dim=patch_embed.embed_dim,
|
||||
text_embed_dim=patch_embed.text_embed_dim,
|
||||
bias=patch_embed.bias,
|
||||
sample_width=patch_embed.sample_width,
|
||||
sample_height=patch_embed.sample_height,
|
||||
sample_frames=patch_embed.sample_frames,
|
||||
temporal_compression_ratio=patch_embed.temporal_compression_ratio,
|
||||
max_text_seq_length=patch_embed.max_text_seq_length,
|
||||
spatial_interpolation_scale=patch_embed.spatial_interpolation_scale,
|
||||
temporal_interpolation_scale=patch_embed.temporal_interpolation_scale,
|
||||
use_positional_embeddings=patch_embed.use_positional_embeddings,
|
||||
).to(device=self.device, dtype=self.dtype)
|
||||
|
||||
transformer.patch_embed.load_state_dict(patch_embed.state_dict(), strict=True)
|
||||
|
||||
for i in range(len(transformer.transformer_blocks)):
|
||||
block = transformer.transformer_blocks[i]
|
||||
|
||||
if isinstance(block, FreeNoiseCogVideoXBlock):
|
||||
block.set_free_noise_properties(
|
||||
self._free_noise_context_length,
|
||||
self._free_noise_context_stride,
|
||||
self._free_noise_weighting_scheme,
|
||||
self._free_noise_prompt_interpolation_callback,
|
||||
self._free_noise_prompt_pooling_callback,
|
||||
)
|
||||
else:
|
||||
transformer.transformer_blocks[i] = FreeNoiseCogVideoXBlock(
|
||||
dim=block.dim,
|
||||
num_attention_heads=block.num_attention_heads,
|
||||
attention_head_dim=block.attention_head_dim,
|
||||
time_embed_dim=block.time_embed_dim,
|
||||
dropout=block.dropout,
|
||||
activation_fn=block.activation_fn,
|
||||
attention_bias=block.attention_bias,
|
||||
qk_norm=block.qk_norm,
|
||||
norm_elementwise_affine=block.norm_elementwise_affine,
|
||||
norm_eps=block.norm_eps,
|
||||
final_dropout=block.final_dropout,
|
||||
ff_inner_dim=block.ff_inner_dim,
|
||||
ff_bias=block.ff_bias,
|
||||
attention_out_bias=block.attention_out_bias,
|
||||
context_length=self._free_noise_context_length,
|
||||
context_stride=self._free_noise_context_stride,
|
||||
weighting_scheme=self._free_noise_weighting_scheme,
|
||||
prompt_interpolation_callback=self._free_noise_prompt_interpolation_callback,
|
||||
prompt_pooling_callback=self._free_noise_prompt_pooling_callback,
|
||||
).to(device=self.device, dtype=self.dtype)
|
||||
|
||||
transformer.transformer_blocks[i].load_state_dict(block.state_dict(), strict=True)
|
||||
|
||||
def _disable_free_noise_in_block(self, transformer: CogVideoXTransformer3DModel):
|
||||
r"""Helper function to disable FreeNoise in transformer blocks."""
|
||||
patch_embed = transformer.patch_embed
|
||||
transformer.patch_embed = CogVideoXPatchEmbed(
|
||||
patch_size=patch_embed.patch_size,
|
||||
in_channels=patch_embed.in_channels,
|
||||
embed_dim=patch_embed.embed_dim,
|
||||
text_embed_dim=patch_embed.text_embed_dim,
|
||||
bias=patch_embed.bias,
|
||||
sample_width=patch_embed.sample_width,
|
||||
sample_height=patch_embed.sample_height,
|
||||
sample_frames=patch_embed.sample_frames,
|
||||
temporal_compression_ratio=patch_embed.temporal_compression_ratio,
|
||||
max_text_seq_length=patch_embed.max_text_seq_length,
|
||||
spatial_interpolation_scale=patch_embed.spatial_interpolation_scale,
|
||||
temporal_interpolation_scale=patch_embed.temporal_interpolation_scale,
|
||||
use_positional_embeddings=patch_embed.use_positional_embeddings,
|
||||
).to(device=self.device, dtype=self.dtype)
|
||||
|
||||
transformer.patch_embed.load_state_dict(patch_embed.state_dict(), strict=True)
|
||||
|
||||
for i in range(len(transformer.transformer_blocks)):
|
||||
block = transformer.transformer_blocks[i]
|
||||
|
||||
if isinstance(block, FreeNoiseCogVideoXBlock):
|
||||
transformer.transformer_blocks[i] = CogVideoXBlock(
|
||||
dim=block.dim,
|
||||
num_attention_heads=block.num_attention_heads,
|
||||
attention_head_dim=block.attention_head_dim,
|
||||
time_embed_dim=block.time_embed_dim,
|
||||
dropout=block.dropout,
|
||||
activation_fn=block.activation_fn,
|
||||
attention_bias=block.attention_bias,
|
||||
qk_norm=block.qk_norm,
|
||||
norm_elementwise_affine=block.norm_elementwise_affine,
|
||||
norm_eps=block.norm_eps,
|
||||
final_dropout=block.final_dropout,
|
||||
ff_inner_dim=block.ff_inner_dim,
|
||||
ff_bias=block.ff_bias,
|
||||
attention_out_bias=block.attention_out_bias,
|
||||
).to(device=self.device, dtype=self.dtype)
|
||||
|
||||
transformer.transformer_blocks[i].load_state_dict(block.state_dict(), strict=True)
|
||||
|
||||
def _check_inputs_free_noise(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
num_frames,
|
||||
) -> None:
|
||||
if not isinstance(prompt, (str, dict)):
|
||||
raise ValueError(f"Expected `prompt` to have type `str` or `dict` but found {type(prompt)=}")
|
||||
|
||||
if negative_prompt is not None:
|
||||
if not isinstance(negative_prompt, (str, dict)):
|
||||
raise ValueError(
|
||||
f"Expected `negative_prompt` to have type `str` or `dict` but found {type(negative_prompt)=}"
|
||||
)
|
||||
|
||||
if prompt_embeds is not None or negative_prompt_embeds is not None:
|
||||
raise ValueError("`prompt_embeds` and `negative_prompt_embeds` is not supported in FreeNoise yet.")
|
||||
|
||||
frame_indices = [isinstance(x, int) for x in prompt.keys()]
|
||||
frame_prompts = [isinstance(x, str) for x in prompt.values()]
|
||||
min_frame = min(list(prompt.keys()))
|
||||
max_frame = max(list(prompt.keys()))
|
||||
|
||||
if not all(frame_indices):
|
||||
raise ValueError("Expected integer keys in `prompt` dict for FreeNoise.")
|
||||
if not all(frame_prompts):
|
||||
raise ValueError("Expected str values in `prompt` dict for FreeNoise.")
|
||||
if min_frame != 0:
|
||||
raise ValueError("The minimum frame index in `prompt` dict must be 0 as a starting prompt is necessary.")
|
||||
if max_frame >= num_frames:
|
||||
raise ValueError(
|
||||
f"The maximum frame index in `prompt` dict must be lesser than {num_frames=} and follow 0-based indexing."
|
||||
)
|
||||
|
||||
def _encode_prompt_free_noise(
|
||||
self,
|
||||
prompt: Union[str, Dict[int, str]],
|
||||
num_frames: int,
|
||||
device: torch.device,
|
||||
num_videos_per_prompt: int,
|
||||
do_classifier_free_guidance: bool,
|
||||
negative_prompt: Optional[Union[str, Dict[int, str]]] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: Optional[int] = None,
|
||||
) -> Tuple[Dict[int, torch.Tensor], Optional[Dict[int, torch.Tensor]]]:
|
||||
if negative_prompt is None:
|
||||
negative_prompt = ""
|
||||
|
||||
# Ensure that we have a dictionary of prompts
|
||||
if isinstance(prompt, str):
|
||||
prompt = {0: prompt}
|
||||
if isinstance(negative_prompt, str):
|
||||
negative_prompt = {0: negative_prompt}
|
||||
|
||||
self._check_inputs_free_noise(prompt, negative_prompt, prompt_embeds, negative_prompt_embeds, num_frames)
|
||||
|
||||
# Sort the prompts based on frame indices
|
||||
prompt = dict(sorted(prompt.items()))
|
||||
negative_prompt = dict(sorted(negative_prompt.items()))
|
||||
|
||||
# Ensure that we have a prompt for the last frame index
|
||||
prompt[num_frames - 1] = prompt[list(prompt.keys())[-1]]
|
||||
negative_prompt[num_frames - 1] = negative_prompt[list(negative_prompt.keys())[-1]]
|
||||
|
||||
frame_indices = list(prompt.keys())
|
||||
frame_prompts = list(prompt.values())
|
||||
frame_negative_indices = list(negative_prompt.keys())
|
||||
frame_negative_prompts = list(negative_prompt.values())
|
||||
|
||||
# Generate and bucketify positive prompts
|
||||
prompt_embeds, _ = self.encode_prompt(
|
||||
prompt=frame_prompts,
|
||||
device=device,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
do_classifier_free_guidance=False,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
# If many prompts fall into the same latent frame bucket, pool them
|
||||
prompt_embeds_frame_buckets = collections.defaultdict(lambda: [])
|
||||
for i in range(len(frame_indices)):
|
||||
latent_frame_index = frame_indices[i] // self.transformer.config.temporal_compression_ratio
|
||||
prompt_embeds_frame_buckets[latent_frame_index].append(prompt_embeds[i].unsqueeze(0))
|
||||
|
||||
for frame_index, embeds in list(prompt_embeds_frame_buckets.items()):
|
||||
prompt_embeds_frame_buckets[frame_index] = self._free_noise_prompt_pooling_callback(embeds)
|
||||
|
||||
# Generate and bucketify negative prompts
|
||||
negative_prompt_embeds_frame_buckets = None
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
_, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt=[""] * len(frame_negative_prompts),
|
||||
device=device,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
do_classifier_free_guidance=True,
|
||||
negative_prompt=frame_negative_prompts,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
# If many prompts fall into the same latent frame bucket, pool them
|
||||
negative_prompt_embeds_frame_buckets = collections.defaultdict(lambda: [])
|
||||
for i in range(len(frame_negative_indices)):
|
||||
latent_frame_index = frame_negative_indices[i] // self.transformer.config.temporal_compression_ratio
|
||||
negative_prompt_embeds_frame_buckets[latent_frame_index].append(negative_prompt_embeds[i].unsqueeze(0))
|
||||
|
||||
for frame_index, embeds in list(negative_prompt_embeds_frame_buckets.items()):
|
||||
negative_prompt_embeds_frame_buckets[frame_index] = self._free_noise_prompt_pooling_callback(embeds)
|
||||
|
||||
prompt_embeds_frame_buckets = (negative_prompt_embeds_frame_buckets, prompt_embeds_frame_buckets)
|
||||
else:
|
||||
prompt_embeds_frame_buckets = (prompt_embeds_frame_buckets,)
|
||||
|
||||
return prompt_embeds_frame_buckets, negative_prompt_embeds_frame_buckets
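
# Worked example (not part of this PR) of the bucketing above: with
# temporal_compression_ratio = 4 and prompt = {0: "a cat", 20: "a dog", 48: "a fox"},
# the pixel-frame keys map to latent-frame buckets 0 // 4 = 0, 20 // 4 = 5 and 48 // 4 = 12.
# Keys that collide (e.g. pixel frames 20 and 22 both map to bucket 5) are pooled together
# by `self._free_noise_prompt_pooling_callback`.
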
|
||||
|
||||
def _prepare_latents_free_noise(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_channels_latents: int,
|
||||
num_frames: int,
|
||||
height: int,
|
||||
width: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
context_num_frames = (
|
||||
self._free_noise_context_length if self._free_noise_noise_type == "repeat_context" else num_frames
|
||||
)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
context_num_frames,
|
||||
num_channels_latents,
|
||||
height // self.vae_scale_factor_spatial,
|
||||
width // self.vae_scale_factor_spatial,
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
if self._free_noise_noise_type == "random":
|
||||
return latents
|
||||
else:
|
||||
if latents.size(1) == num_frames:
|
||||
return latents
|
||||
elif latents.size(1) != self._free_noise_context_length:
|
||||
raise ValueError(
|
||||
f"You have passed `latents` as a parameter to FreeNoise. The expected number of frames is either {num_frames} or {self._free_noise_context_length}, but found {latents.size(2)}"
|
||||
)
|
||||
latents = latents.to(device)
|
||||
|
||||
if self._free_noise_noise_type == "shuffle_context":
|
||||
for i in range(self._free_noise_context_length, num_frames, self._free_noise_context_stride):
|
||||
# ensure window is within bounds
|
||||
window_start = max(0, i - self._free_noise_context_length)
|
||||
window_end = min(num_frames, window_start + self._free_noise_context_stride)
|
||||
window_length = window_end - window_start
|
||||
|
||||
if window_length == 0:
|
||||
break
|
||||
|
||||
indices = torch.LongTensor(list(range(window_start, window_end)))
|
||||
shuffled_indices = indices[torch.randperm(window_length, generator=generator)]
|
||||
|
||||
current_start = i
|
||||
current_end = min(num_frames, current_start + window_length)
|
||||
if current_end == current_start + window_length:
|
||||
# batch of frames perfectly fits the window
|
||||
latents[:, current_start:current_end] = latents[:, shuffled_indices]
|
||||
else:
|
||||
# handle the case where the last batch of frames does not fit perfectly with the window
|
||||
prefix_length = current_end - current_start
|
||||
shuffled_indices = shuffled_indices[:prefix_length]
|
||||
latents[:, current_start:current_end] = latents[:, shuffled_indices]
|
||||
|
||||
elif self._free_noise_noise_type == "repeat_context":
|
||||
num_repeats = (num_frames + self._free_noise_context_length - 1) // self._free_noise_context_length
|
||||
latents = torch.cat([latents] * num_repeats, dim=1)
|
||||
|
||||
latents = latents[:, :num_frames]
|
||||
return latents
|
||||
|
||||
def enable_free_noise(
|
||||
self,
|
||||
context_length: Optional[int] = 13, # 49 pixel-space frames
|
||||
context_stride: int = 4, # 16 pixel-space frames
|
||||
weighting_scheme: str = "pyramid",
|
||||
noise_type: str = "shuffle_context",
|
||||
prompt_interpolation_callback: Optional[Callable[[int, int, torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
prompt_pooling_type_or_callback: Optional[Union[str, Callable[[List[torch.Tensor]], torch.Tensor]]] = None,
|
||||
) -> None:
|
||||
r"""
|
||||
Enable long video generation using FreeNoise.
|
||||
|
||||
Args:
|
||||
context_length (`int`, defaults to `13`, *optional*):
The number of latent video frames to process at once. 13 latent frames correspond to 49
pixel-space frames under CogVideoX's 4x temporal compression.
|
||||
context_stride (`int`, *optional*):
|
||||
Long videos are generated by processing many frames. FreeNoise processes these frames in sliding
|
||||
windows of size `context_length`. Context stride allows you to specify how many frames to skip between
|
||||
each window. For example, a context length of 16 and context stride of 4 would process 24 frames as:
|
||||
[0, 15], [4, 19], [8, 23] (0-based indexing)
|
||||
weighting_scheme (`str`, defaults to `pyramid`):
|
||||
Weighting scheme for averaging latents after accumulation in FreeNoise blocks. The following weighting
|
||||
schemes are supported currently:
|
||||
- "flat"
|
||||
Performs weighted averaging with a flat weight pattern: [1, 1, 1, 1, 1].
|
||||
- "pyramid"
|
||||
Performs weighted averaging with a pyramid like weight pattern: [1, 2, 3, 2, 1].
|
||||
- "delayed_reverse_sawtooth"
|
||||
Performs weighted averaging with low weights for earlier frames and high-to-low weights for
|
||||
later frames: [0.01, 0.01, 3, 2, 1].
|
||||
noise_type (`str`, defaults to "shuffle_context"):
|
||||
Must be one of ["shuffle_context", "repeat_context", "random"].
|
||||
- "shuffle_context"
|
||||
Shuffles a fixed batch of `context_length` latents to create a final latent of size
|
||||
`num_frames`. This is usually the best setting for most generation scenarios. However, there
|
||||
might be visible repetition noticeable in the kinds of motion/animation generated.
|
||||
- "repeated_context"
|
||||
Repeats a fixed batch of `context_length` latents to create a final latent of size
|
||||
`num_frames`.
|
||||
- "random"
|
||||
The final latents are random without any repetition.
|
||||
"""
|
||||
|
||||
allowed_weighting_scheme = ["flat", "pyramid", "delayed_reverse_sawtooth"]
|
||||
allowed_noise_type = ["shuffle_context", "repeat_context", "random"]
|
||||
|
||||
# Note: we need to do this because when CogVideoX was originally introduced, it used the wrong value
|
||||
# for `sample_frames`. It should have been 13 and not 49 because the expected number of latent frames
|
||||
# in the transformer is 13.
|
||||
# Maybe try and look into what can be done before 1.0.0 release.
|
||||
if context_length % 2 == 0:
|
||||
sample_frames_context_length = context_length * self.transformer.config.temporal_compression_ratio
|
||||
else:
|
||||
sample_frames_context_length = (
|
||||
context_length - 1
|
||||
) * self.transformer.config.temporal_compression_ratio + 1
|
||||
|
||||
if sample_frames_context_length > self.transformer.config.sample_frames:
|
||||
logger.warning(
|
||||
f"You have set {context_length=} which is greater than {self.motion_adapter.config.motion_max_seq_length=}. This can lead to bad generation results."
|
||||
)
|
||||
if weighting_scheme not in allowed_weighting_scheme:
|
||||
raise ValueError(
|
||||
f"The parameter `weighting_scheme` must be one of {allowed_weighting_scheme}, but got {weighting_scheme=}"
|
||||
)
|
||||
if noise_type not in allowed_noise_type:
|
||||
raise ValueError(f"The parameter `noise_type` must be one of {allowed_noise_type}, but got {noise_type=}")
|
||||
|
||||
self._free_noise_context_length = context_length or 13  # fall back to the default latent context length (49 pixel-space frames)
|
||||
self._free_noise_context_stride = context_stride
|
||||
self._free_noise_weighting_scheme = weighting_scheme
|
||||
self._free_noise_noise_type = noise_type
|
||||
self._free_noise_prompt_interpolation_callback = prompt_interpolation_callback or _lerp
|
||||
|
||||
if prompt_pooling_type_or_callback is None:
|
||||
prompt_pooling_type_or_callback = "pyramid"
|
||||
if isinstance(prompt_pooling_type_or_callback, str):
|
||||
self._free_noise_prompt_pooling_callback = functools.partial(
|
||||
_weighted_pooling, pooling_type=prompt_pooling_type_or_callback
|
||||
)
|
||||
else:
|
||||
self._free_noise_prompt_pooling_callback = prompt_pooling_type_or_callback
|
||||
|
||||
self._enable_free_noise_in_block(self.transformer)
|
||||
|
||||
def disable_free_noise(self) -> None:
|
||||
r"""Disable the FreeNoise sampling mechanism."""
|
||||
self._free_noise_context_length = None
|
||||
|
||||
self._disable_free_noise_in_block(self.transformer)
|
||||
|
||||
@property
|
||||
def free_noise_enabled(self):
|
||||
return hasattr(self, "_free_noise_context_length") and self._free_noise_context_length is not None
|
||||
|
||||
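
# Usage sketch (not part of this PR) of how the pieces above are intended to fit together on
# this experimental branch. The model id, frame counts and prompts are illustrative; the
# multi-prompt dict maps starting pixel-frame indices to prompts, per `_check_inputs_free_noise`.
import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
pipe.enable_free_noise(context_length=13, context_stride=4, weighting_scheme="pyramid")

video = pipe(
    prompt={0: "a panda strumming a guitar in a bamboo forest", 80: "the panda napping under a tree"},
    num_frames=129,
    num_inference_steps=50,
    guidance_scale=6.0,
).frames[0]
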