Compare commits

...

1 Commits

Author: Aryan
SHA1: 86ac0fa6cc
Message: cogvideox dynamic pos embeds
Date: 2024-08-13 00:14:39 +02:00
2 changed files with 37 additions and 26 deletions

View File

@@ -211,27 +211,12 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
- post_patch_height = sample_height // patch_size
- post_patch_width = sample_width // patch_size
- post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
- self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
# 1. Patch embedding
self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
self.embedding_dropout = nn.Dropout(dropout)
# 2. 3D positional embeddings
- spatial_pos_embedding = get_3d_sincos_pos_embed(
-     inner_dim,
-     (post_patch_width, post_patch_height),
-     post_time_compression_frames,
-     spatial_interpolation_scale,
-     temporal_interpolation_scale,
- )
- spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
- pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
- pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
- self.register_buffer("pos_embedding", pos_embedding, persistent=False)
+ self._generate_or_reuse_positional_embeddings(max_text_seq_length, sample_height, sample_width, sample_frames)
# 3. Time embeddings
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
@@ -268,8 +253,32 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
self.gradient_checkpointing = False
- def _set_gradient_checkpointing(self, module, value=False):
+ def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
self.gradient_checkpointing = value
+ def _generate_or_reuse_positional_embeddings(self, max_text_seq_length: int, sample_height: int, sample_width: int, sample_frames: int) -> None:
+     inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+     post_patch_height = sample_height // self.config.patch_size
+     post_patch_width = sample_width // self.config.patch_size
+     post_time_compression_frames = (sample_frames - 1) // self.config.temporal_compression_ratio + 1
+     num_patches = post_patch_height * post_patch_width * post_time_compression_frames
+
+     if getattr(self, "pos_embedding", None) is not None:
+         if tuple(self.pos_embedding.shape) == (1, max_text_seq_length + num_patches, inner_dim):
+             return
+         del self.pos_embedding
+
+     spatial_pos_embedding = get_3d_sincos_pos_embed(
+         inner_dim,
+         (post_patch_width, post_patch_height),
+         post_time_compression_frames,
+         self.config.spatial_interpolation_scale,
+         self.config.temporal_interpolation_scale,
+     )
+     spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
+     pos_embedding = torch.zeros(1, max_text_seq_length + num_patches, inner_dim, requires_grad=False)
+     pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
+     self.register_buffer("pos_embedding", pos_embedding, persistent=False)
def forward(
self,
@@ -295,14 +304,21 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
# 3. Position embedding
- seq_length = height * width * num_frames // (self.config.patch_size**2)
+ text_seq_length = encoder_hidden_states.size(1)
+ video_seq_length = height * width * num_frames // (self.config.patch_size**2)
+ # We need to do this because the original config `sample_frames` incorrectly used 49 (pixel frames) instead
+ # of 13 (latent frames). Ideally, we should just pass num_frames here.
+ sample_frames = (num_frames - 1) * self.config.temporal_compression_ratio + 1
- pos_embeds = self.pos_embedding[:, : self.config.max_text_seq_length + seq_length]
+ self._generate_or_reuse_positional_embeddings(text_seq_length, height, width, sample_frames)
+ self.pos_embedding = self.pos_embedding.to(device=hidden_states.device, dtype=hidden_states.dtype)
+ pos_embeds = self.pos_embedding[:, : text_seq_length + video_seq_length]
hidden_states = hidden_states + pos_embeds
hidden_states = self.embedding_dropout(hidden_states)
- encoder_hidden_states = hidden_states[:, : self.config.max_text_seq_length]
- hidden_states = hidden_states[:, self.config.max_text_seq_length :]
+ encoder_hidden_states = hidden_states[:, :text_seq_length]
+ hidden_states = hidden_states[:, text_seq_length:]
# 5. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
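
Note: the sketch below is not part of the diff. It is a minimal standalone illustration of the sequence-length arithmetic that decides whether the cached pos_embedding buffer can be reused, assuming CogVideoX-2b-style defaults (patch_size=2, temporal_compression_ratio=4, max_text_seq_length=226, inner_dim=1920); the helper name expected_pos_embedding_shape is made up for this example.

# Standalone sketch (not from the diff): how the reuse check in
# _generate_or_reuse_positional_embeddings sees the sequence length.
PATCH_SIZE = 2                   # assumed CogVideoX-2b default
TEMPORAL_COMPRESSION_RATIO = 4   # assumed CogVideoX-2b default
MAX_TEXT_SEQ_LENGTH = 226        # assumed CogVideoX-2b default
INNER_DIM = 1920                 # 30 heads * 64 dims, assumed default

def expected_pos_embedding_shape(sample_height: int, sample_width: int, sample_frames: int) -> tuple:
    post_patch_height = sample_height // PATCH_SIZE
    post_patch_width = sample_width // PATCH_SIZE
    post_time_compression_frames = (sample_frames - 1) // TEMPORAL_COMPRESSION_RATIO + 1
    num_patches = post_patch_height * post_patch_width * post_time_compression_frames
    return (1, MAX_TEXT_SEQ_LENGTH + num_patches, INNER_DIM)

# Default latent grid (height=60, width=90) with 49 pixel frames: 30 * 45 * 13 = 17550 patches,
# so the buffer built in __init__ has shape (1, 17776, 1920) and is simply reused.
print(expected_pos_embedding_shape(60, 90, 49))

# The forward pass only sees latent frames (13 here), so it inverts the temporal compression
# before calling the helper: sample_frames = (13 - 1) * 4 + 1 = 49.

# A different latent resolution changes the sequence length, so the buffer is regenerated.
print(expected_pos_embedding_shape(96, 160, 49))  # (1, 50146, 1920)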

View File

@@ -534,10 +534,6 @@ class CogVideoXPipeline(DiffusionPipeline):
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
- assert (
-     num_frames <= 48 and num_frames % fps == 0 and fps == 8
- ), f"The number of frames must be divisible by {fps=} and less than 48 frames (for now). Other values are not supported in CogVideoX."
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
@@ -593,7 +589,6 @@ class CogVideoXPipeline(DiffusionPipeline):
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
- num_frames += 1
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
latent_channels,
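
Note: the snippet below is likewise not part of the diff. The two pipeline hunks drop the hard num_frames <= 48 assertion and the num_frames += 1 bump that previously forced 49 pixel frames; with the positional embeddings now rebuilt on demand, other frame counts should map cleanly onto latent frames. A tiny illustrative helper (the name latent_frames is made up here), assuming the same temporal compression ratio of 4:

# Standalone sketch (not from the diff): pixel frames -> latent frames seen by the transformer.
def latent_frames(pixel_frames: int, temporal_compression_ratio: int = 4) -> int:
    return (pixel_frames - 1) // temporal_compression_ratio + 1

print(latent_frames(49))  # 13 -> the length the old fixed pos_embedding buffer assumed
print(latent_frames(33))  # 9  -> a shorter clip, which should work once embeddings are regenerated dynamically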