Compare commits

...

2 Commits

Author SHA1 Message Date
Sayak Paul
5f6a41ed02 Merge branch 'main' into cogvideox/one-shot-decoding 2024-10-06 10:19:06 +04:00
Aryan
91fbe16b63 update 2024-09-29 13:53:36 +02:00

View File

@@ -117,6 +117,8 @@ class CogVideoXCausalConv3d(nn.Module):
dilation=dilation,
)
self.return_conv_cache = True
def fake_context_parallel_forward(
self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
) -> torch.Tensor:
@@ -128,7 +130,10 @@ class CogVideoXCausalConv3d(nn.Module):
def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
inputs = self.fake_context_parallel_forward(inputs, conv_cache)
conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
if self.return_conv_cache:
conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
else:
conv_cache = None
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
@@ -1079,6 +1084,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.use_slicing = False
self.use_tiling = False
self.use_framewise_batching = True
# Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
# recommended because the temporal parts of the VAE, here, are tricky to understand.
@@ -1174,6 +1180,20 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
"""
self.use_slicing = False
def enable_framewise_batching(self) -> None:
self.use_framewise_batching = True
for name, module in self.named_modules():
if isinstance(module, CogVideoXCausalConv3d):
module.return_conv_cache = True
def disable_framewise_batching(self) -> None:
self.use_framewise_batching = False
for name, module in self.named_modules():
if isinstance(module, CogVideoXCausalConv3d):
module.return_conv_cache = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape
@@ -1184,19 +1204,26 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
# Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
conv_cache = None
enc = []
for i in range(num_batches):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
end_frame = frame_batch_size * (i + 1) + remaining_frames
x_intermediate = x[:, :, start_frame:end_frame]
x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
if self.use_framewise_batching:
enc = []
for i in range(num_batches):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
end_frame = frame_batch_size * (i + 1) + remaining_frames
x_intermediate = x[:, :, start_frame:end_frame]
x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
if self.quant_conv is not None:
x_intermediate = self.quant_conv(x_intermediate)
enc.append(x_intermediate)
enc = torch.cat(enc, dim=2)
else:
enc, _ = self.encoder(x, conv_cache=conv_cache)
if self.quant_conv is not None:
x_intermediate = self.quant_conv(x_intermediate)
enc.append(x_intermediate)
enc = self.quant_conv(enc)
enc = torch.cat(enc, dim=2)
return enc
@apply_forward_hook
@@ -1236,19 +1263,25 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
frame_batch_size = self.num_latent_frames_batch_size
num_batches = max(num_frames // frame_batch_size, 1)
conv_cache = None
dec = []
for i in range(num_batches):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
end_frame = frame_batch_size * (i + 1) + remaining_frames
z_intermediate = z[:, :, start_frame:end_frame]
if self.use_framewise_batching:
dec = []
for i in range(num_batches):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
end_frame = frame_batch_size * (i + 1) + remaining_frames
z_intermediate = z[:, :, start_frame:end_frame]
if self.post_quant_conv is not None:
z_intermediate = self.post_quant_conv(z_intermediate)
z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
dec.append(z_intermediate)
dec = torch.cat(dec, dim=2)
else:
if self.post_quant_conv is not None:
z_intermediate = self.post_quant_conv(z_intermediate)
z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
dec.append(z_intermediate)
dec = torch.cat(dec, dim=2)
dec = self.post_quant_conv(z)
dec, _ = self.decoder(z, conv_cache=conv_cache)
if not return_dict:
return (dec,)