Mirror of https://github.com/huggingface/diffusers.git
Compare commits
1 commit
modular-cu... ... cogvideox-...
| Author | SHA1 | Date |
|---|---|---|
| | 86ac0fa6cc | |
@@ -211,27 +211,12 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim
 
-        post_patch_height = sample_height // patch_size
-        post_patch_width = sample_width // patch_size
-        post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
-        self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
-
         # 1. Patch embedding
         self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
         self.embedding_dropout = nn.Dropout(dropout)
 
         # 2. 3D positional embeddings
-        spatial_pos_embedding = get_3d_sincos_pos_embed(
-            inner_dim,
-            (post_patch_width, post_patch_height),
-            post_time_compression_frames,
-            spatial_interpolation_scale,
-            temporal_interpolation_scale,
-        )
-        spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
-        pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
-        pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
-        self.register_buffer("pos_embedding", pos_embedding, persistent=False)
+        self._generate_or_reuse_positional_embeddings(max_text_seq_length, sample_height, sample_width, sample_frames)
 
         # 3. Time embeddings
         self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
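For reference, the patch-count arithmetic the constructor used to run eagerly is easy to check by hand. A minimal sketch, assuming CogVideoX-2b-style defaults (sample_height=60, sample_width=90, sample_frames=49, patch_size=2, temporal_compression_ratio=4, max_text_seq_length=226; these values are assumptions, not part of this diff):

```python
# Minimal sketch of the patch-count arithmetic, using assumed CogVideoX-2b-style defaults.
sample_height, sample_width, sample_frames = 60, 90, 49
patch_size, temporal_compression_ratio = 2, 4
max_text_seq_length = 226

post_patch_height = sample_height // patch_size                                       # 30
post_patch_width = sample_width // patch_size                                         # 45
post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1  # 13
num_patches = post_patch_height * post_patch_width * post_time_compression_frames     # 30 * 45 * 13 = 17550

# The registered buffer then has shape (1, max_text_seq_length + num_patches, inner_dim).
print(max_text_seq_length + num_patches)  # 17776
```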
@@ -268,8 +253,32 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
 
         self.gradient_checkpointing = False
 
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
         self.gradient_checkpointing = value
 
+    def _generate_or_reuse_positional_embeddings(self, max_text_seq_length: int, sample_height: int, sample_width: int, sample_frames: int) -> None:
+        inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        post_patch_height = sample_height // self.config.patch_size
+        post_patch_width = sample_width // self.config.patch_size
+        post_time_compression_frames = (sample_frames - 1) // self.config.temporal_compression_ratio + 1
+        num_patches = post_patch_height * post_patch_width * post_time_compression_frames
+
+        if getattr(self, "pos_embedding", None) is not None:
+            if tuple(self.pos_embedding.shape) == (1, max_text_seq_length + num_patches, inner_dim):
+                return
+            del self.pos_embedding
+
+        spatial_pos_embedding = get_3d_sincos_pos_embed(
+            inner_dim,
+            (post_patch_width, post_patch_height),
+            post_time_compression_frames,
+            self.config.spatial_interpolation_scale,
+            self.config.temporal_interpolation_scale,
+        )
+        spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
+        pos_embedding = torch.zeros(1, max_text_seq_length + num_patches, inner_dim, requires_grad=False)
+        pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
+        self.register_buffer("pos_embedding", pos_embedding, persistent=False)
+
     def forward(
         self,
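The new helper rebuilds `pos_embedding` only when the required shape changes; otherwise it returns early and keeps the registered buffer. A standalone sketch of that reuse-or-rebuild pattern (the function, cache, and shapes below are illustrative, not taken from the diff):

```python
import torch

def reuse_or_rebuild(cache: dict, key: str, shape: tuple) -> torch.Tensor:
    """Return the cached tensor if its shape still matches, otherwise build a fresh one."""
    cached = cache.get(key)
    if cached is not None and tuple(cached.shape) == shape:
        return cached                    # same text length / resolution / frame count: reuse
    cache[key] = torch.zeros(shape)      # shape changed: drop the old table and rebuild
    return cache[key]

cache = {}
a = reuse_or_rebuild(cache, "pos_embedding", (1, 320, 64))  # first call: builds
b = reuse_or_rebuild(cache, "pos_embedding", (1, 320, 64))  # same shape: reused
c = reuse_or_rebuild(cache, "pos_embedding", (1, 480, 64))  # new resolution: rebuilt
assert a is b and b is not c
```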
@@ -295,14 +304,21 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
 
         # 3. Position embedding
-        seq_length = height * width * num_frames // (self.config.patch_size**2)
+        text_seq_length = encoder_hidden_states.size(1)
+        video_seq_length = height * width * num_frames // (self.config.patch_size**2)
+        # We need to do this because the original config `sample_frames` incorrectly used 49 instead of 13. Ideally,
+        # we should just pass num_frames here
+        sample_frames = (num_frames - 1) * self.config.temporal_compression_ratio + 1
 
-        pos_embeds = self.pos_embedding[:, : self.config.max_text_seq_length + seq_length]
+        self._generate_or_reuse_positional_embeddings(text_seq_length, height, width, sample_frames)
+        self.pos_embedding = self.pos_embedding.to(device=hidden_states.device, dtype=hidden_states.dtype)
+        print(self.pos_embedding.shape, text_seq_length, video_seq_length)
+        pos_embeds = self.pos_embedding[:, : text_seq_length + video_seq_length]
         hidden_states = hidden_states + pos_embeds
         hidden_states = self.embedding_dropout(hidden_states)
 
-        encoder_hidden_states = hidden_states[:, : self.config.max_text_seq_length]
-        hidden_states = hidden_states[:, self.config.max_text_seq_length :]
+        encoder_hidden_states = hidden_states[:, : text_seq_length]
+        hidden_states = hidden_states[:, text_seq_length :]
 
         # 5. Transformer blocks
         for i, block in enumerate(self.transformer_blocks):
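The `sample_frames` recomputed inside `forward` is intended to round-trip with the `(sample_frames - 1) // temporal_compression_ratio + 1` formula used by the helper, so the regenerated buffer matches the latent frame count the model actually sees. A quick arithmetic check (13 latent frames and a compression ratio of 4 are assumed example values):

```python
temporal_compression_ratio = 4
num_frames = 13  # latent frames seen by the transformer

# forward() reconstructs the pixel-space frame count from the latent frame count...
sample_frames = (num_frames - 1) * temporal_compression_ratio + 1            # 49

# ...and the helper compresses it back, so the buffer it builds matches num_frames.
post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
assert post_time_compression_frames == num_frames                            # 13 == 13
```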
@@ -534,10 +534,6 @@ class CogVideoXPipeline(DiffusionPipeline):
                 `tuple`. When returning a tuple, the first element is a list with the generated images.
         """
 
-        assert (
-            num_frames <= 48 and num_frames % fps == 0 and fps == 8
-        ), f"The number of frames must be divisible by {fps=} and less than 48 frames (for now). Other values are not supported in CogVideoX."
-
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
 
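For context, the guard removed here only accepted multiples of a fixed fps of 8, up to 48 frames. As a standalone check it amounts to the following sketch (not actual pipeline code):

```python
def old_frame_count_check(num_frames: int, fps: int = 8) -> None:
    # The guard removed by this commit: only multiples of a fixed fps of 8, at most 48 frames.
    assert (
        num_frames <= 48 and num_frames % fps == 0 and fps == 8
    ), f"The number of frames must be divisible by {fps=} and less than 48 frames (for now)."

old_frame_count_check(48)    # passes
# old_frame_count_check(49)  # would raise: 49 is not a multiple of 8 and exceeds the cap
```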
@@ -593,7 +589,6 @@ class CogVideoXPipeline(DiffusionPipeline):
 
         # 5. Prepare latents.
         latent_channels = self.transformer.config.in_channels
-        num_frames += 1
         latents = self.prepare_latents(
             batch_size * num_videos_per_prompt,
             latent_channels,
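With the `num_frames += 1` bump gone, the requested frame count reaches `prepare_latents` unchanged. Assuming the latent frame count is derived with the same compression formula as in the transformer helper, the before/after arithmetic looks like this (the helper below is illustrative, not the pipeline's actual `prepare_latents`):

```python
temporal_compression_ratio = 4

def latent_frame_count(num_frames: int) -> int:
    # Assumed pixel-frame -> latent-frame mapping, mirroring the transformer helper above.
    return (num_frames - 1) // temporal_compression_ratio + 1

# Before this change the pipeline bumped the request by one frame before preparing latents:
print(latent_frame_count(48 + 1))  # 13
# After the change the requested count is used as-is, e.g. a 49-frame request:
print(latent_frame_count(49))      # 13
```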