Compare commits

...

5 Commits

Author  SHA1        Message                                           Date
Aryan   da9e4babe5  Merge branch 'main' into refactor-cogvideox-vae   2024-11-22 20:33:03 +05:30
Aryan   b49ae8cac2  apply suggestions from review                     2024-11-13 00:41:54 +01:00
Aryan   2f1f43c59d  fight tests                                       2024-11-11 19:47:21 +01:00
Aryan   af00830898  make style                                        2024-11-11 03:50:55 +01:00
Aryan   01d6a1b89f  refactor                                          2024-11-11 03:50:35 +01:00
5 changed files with 159 additions and 128 deletions

View File

@@ -22,7 +22,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import logging
from ...utils import deprecate, logging
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..downsampling import CogVideoXDownsample3D
@@ -1086,9 +1086,23 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None
self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None
self.spatial_compression_ratio = 2 ** (len(block_out_channels) - 1)
self.temporal_compression_ratio = temporal_compression_ratio
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
# to perform decoding of a single video latent at a time.
self.use_slicing = False
# When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
# frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
# intermediate tiles together, the memory requirement can be lowered.
self.use_tiling = False
# When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
# at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered.
self.use_framewise_encoding = True
self.use_framewise_decoding = True
# Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
# recommended because the temporal parts of the VAE, here, are tricky to understand.
# If you decode X latent frames together, the number of output frames is:
@@ -1109,18 +1123,11 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.num_sample_frames_batch_size = 8
# We make the minimum height and width of sample for tiling half that of the generally supported
self.tile_sample_min_height = sample_height // 2
self.tile_sample_min_width = sample_width // 2
self.tile_latent_min_height = int(
self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
)
self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_sample_min_height = 256
self.tile_sample_min_width = 256
# These are experimental overlap factors that were chosen based on experimentation and seem to work best for
# 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
# and so the tiling implementation has only been tested on those specific resolutions.
self.tile_overlap_factor_height = 1 / 6
self.tile_overlap_factor_width = 1 / 5
self.tile_sample_stride_height = 192
self.tile_sample_stride_width = 192
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
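With the tile values that appear in this hunk (`tile_sample_min_* = 256`, `tile_sample_stride_* = 192`), the latent-space geometry works out as below. The compression ratio of 8 is an assumption based on the usual CogVideoX configs (`2 ** (len(block_out_channels) - 1)`), not something stated in the hunk itself:

```python
# Sanity check of the tiling defaults in latent space.
# Assumption: spatial_compression_ratio == 8, as in the released CogVideoX configs.
tile_sample_min = 256
tile_sample_stride = 192
spatial_compression_ratio = 8

tile_latent_min = tile_sample_min // spatial_compression_ratio        # 32
tile_latent_stride = tile_sample_stride // spatial_compression_ratio  # 24
blend = tile_latent_min - tile_latent_stride                          # 8 latent rows/cols blended per seam

print(tile_latent_min, tile_latent_stride, blend)  # 32 24 8
```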
@@ -1132,6 +1139,8 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
tile_sample_min_width: Optional[int] = None,
tile_overlap_factor_height: Optional[float] = None,
tile_overlap_factor_width: Optional[float] = None,
tile_sample_stride_height: Optional[float] = None,
tile_sample_stride_width: Optional[float] = None,
) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
@@ -1143,24 +1152,36 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
The minimum height required for a sample to be separated into tiles across the height dimension.
tile_sample_min_width (`int`, *optional*):
The minimum width required for a sample to be separated into tiles across the width dimension.
tile_overlap_factor_height (`int`, *optional*):
tile_sample_stride_height (`int`, *optional*):
The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
value might cause more tiles to be processed leading to slow down of the decoding process.
tile_overlap_factor_width (`int`, *optional*):
The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
value might cause more tiles to be processed leading to slow down of the decoding process.
no tiling artifacts produced across the height dimension.
tile_sample_stride_width (`int`, *optional*):
The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
artifacts produced across the width dimension.
"""
if tile_overlap_factor_height is not None or tile_overlap_factor_width is not None:
deprecate(
"tile_overlap_factor",
"1.0.0",
"The parameters `tile_overlap_factor_height` and `tile_overlap_factor_width` are deprecated and will be ignored. Please use `tile_sample_stride_height` and `tile_sample_stride_width` instead. For now, we will use these flags automatically, if passed, without breaking the existing behaviour.",
)
tile_sample_stride_height = (
int((1 - tile_overlap_factor_height) * self.tile_sample_min_height)
// self.spatial_compression_ratio
* self.spatial_compression_ratio
)
tile_sample_stride_width = (
int((1 - tile_overlap_factor_width) * self.tile_sample_min_width)
// self.spatial_compression_ratio
* self.spatial_compression_ratio
)
self.use_tiling = True
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_latent_min_height = int(
self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
)
self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
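A minimal usage sketch of the updated `enable_tiling` API. The parameter names come from this diff; the checkpoint id and the `from_pretrained` call are assumptions about the surrounding diffusers API, not part of this change:

```python
from diffusers import AutoencoderKLCogVideoX

vae = AutoencoderKLCogVideoX.from_pretrained("THUDM/CogVideoX-2b", subfolder="vae")

# New-style call: tile sizes and strides are given directly in sample (pixel) space.
vae.enable_tiling(
    tile_sample_min_height=256,
    tile_sample_min_width=256,
    tile_sample_stride_height=192,
    tile_sample_stride_width=192,
)

# Old-style call: still accepted, but emits a deprecation warning and is converted to strides,
# e.g. int((1 - 1/6) * 256) // 8 * 8 == 208 for the height stride with an 8x compression ratio.
vae.enable_tiling(tile_overlap_factor_height=1 / 6, tile_overlap_factor_width=1 / 5)
```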
def disable_tiling(self) -> None:
r"""
@@ -1189,24 +1210,23 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
return self.tiled_encode(x)
frame_batch_size = self.num_sample_frames_batch_size
# Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
# As the extra single frame is handled inside the loop, it is not required to round up here.
num_batches = max(num_frames // frame_batch_size, 1)
conv_cache = None
enc = []
if self.use_framewise_encoding:
enc = []
conv_cache = None
for i in range(num_batches):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
end_frame = frame_batch_size * (i + 1) + remaining_frames
x_intermediate = x[:, :, start_frame:end_frame]
x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
for i in range(0, num_frames, self.num_sample_frames_batch_size):
x_intermediate = x[:, :, i : i + self.num_sample_frames_batch_size]
x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
if self.quant_conv is not None:
x_intermediate = self.quant_conv(x_intermediate)
enc.append(x_intermediate)
enc = torch.cat(enc, dim=2)
else:
enc, _ = self.encoder(x)
if self.quant_conv is not None:
x_intermediate = self.quant_conv(x_intermediate)
enc.append(x_intermediate)
enc = self.quant_conv(enc)
enc = torch.cat(enc, dim=2)
return enc
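The frame-wise branch above processes the video in chunks along the frame axis and threads `conv_cache` from one chunk to the next, so the causal 3D convolutions keep their temporal context across chunk boundaries. A toy sketch of that pattern; `encode_chunk` is a purely illustrative stand-in for `CogVideoXEncoder3D`:

```python
import torch


def encode_framewise(x: torch.Tensor, encode_chunk, frame_batch_size: int = 8) -> torch.Tensor:
    """Encode a (B, C, T, H, W) video in frame chunks, carrying a cache between chunks."""
    num_frames = x.shape[2]
    conv_cache = None
    chunks = []
    for i in range(0, num_frames, frame_batch_size):
        # Each chunk receives the cache produced by the previous chunk.
        chunk, conv_cache = encode_chunk(x[:, :, i : i + frame_batch_size], conv_cache=conv_cache)
        chunks.append(chunk)
    return torch.cat(chunks, dim=2)  # concatenate along the frame dimension, as in `_encode`
```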
@apply_forward_hook
@@ -1239,26 +1259,28 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
return self.tiled_decode(z, return_dict=return_dict)
frame_batch_size = self.num_latent_frames_batch_size
num_batches = max(num_frames // frame_batch_size, 1)
conv_cache = None
dec = []
if self.use_framewise_decoding:
dec = []
conv_cache = None
for i in range(num_batches):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
end_frame = frame_batch_size * (i + 1) + remaining_frames
z_intermediate = z[:, :, start_frame:end_frame]
for i in range(0, num_frames, self.num_latent_frames_batch_size):
z_intermediate = z[:, :, i : i + self.num_latent_frames_batch_size]
if self.post_quant_conv is not None:
z_intermediate = self.post_quant_conv(z_intermediate)
z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
dec.append(z_intermediate)
dec = torch.cat(dec, dim=2)
else:
if self.post_quant_conv is not None:
z_intermediate = self.post_quant_conv(z_intermediate)
z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
dec.append(z_intermediate)
dec = torch.cat(dec, dim=2)
z = self.post_quant_conv(z)
dec, _ = self.decoder(z)
if not return_dict:
return (dec,)
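`_decode` mirrors `_encode`: the latent tile thresholds are derived from the sample-space settings, and the frame-wise path batches latent frames through the decoder with a shared cache. A hedged end-to-end sketch of decoding with the memory-saving switches enabled; the checkpoint id and the latent shape are illustrative assumptions:

```python
import torch

from diffusers import AutoencoderKLCogVideoX

vae = AutoencoderKLCogVideoX.from_pretrained("THUDM/CogVideoX-2b", subfolder="vae")
vae.enable_slicing()  # decode one video of the batch at a time
vae.enable_tiling()   # spatial tiling with the defaults set in __init__

# 13 latent frames at 60x90 correspond to a 49-frame 480x720 video for the 2b model
# (illustrative shape; decoding at this size is slow on CPU, shrink it for a quick test).
latents = torch.randn(1, 16, 13, 60, 90)
with torch.no_grad():
    video = vae.decode(latents).sample
print(video.shape)
```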
@@ -1324,44 +1346,48 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
"""
# For a rough memory estimate, take a look at the `tiled_decode` method.
batch_size, num_channels, num_frames, height, width = x.shape
latent_height = height // self.spatial_compression_ratio
latent_width = width // self.spatial_compression_ratio
overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
row_limit_height = self.tile_latent_min_height - blend_extent_height
row_limit_width = self.tile_latent_min_width - blend_extent_width
frame_batch_size = self.num_sample_frames_batch_size
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
blend_height = tile_latent_min_height - tile_latent_stride_height
blend_width = tile_latent_min_width - tile_latent_stride_width
# Split x into overlapping tiles and encode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, overlap_height):
for i in range(0, height, self.tile_sample_stride_height):
row = []
for j in range(0, width, overlap_width):
# Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
# As the extra single frame is handled inside the loop, it is not required to round up here.
num_batches = max(num_frames // frame_batch_size, 1)
conv_cache = None
time = []
for j in range(0, width, self.tile_sample_stride_width):
if self.use_framewise_encoding:
time = []
conv_cache = None
for k in range(num_batches):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
end_frame = frame_batch_size * (k + 1) + remaining_frames
tile = x[
:,
:,
start_frame:end_frame,
i : i + self.tile_sample_min_height,
j : j + self.tile_sample_min_width,
]
tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
for k in range(0, num_frames, self.num_sample_frames_batch_size):
tile = x[
:,
:,
k : k + self.num_sample_frames_batch_size,
i : i + self.tile_sample_min_height,
j : j + self.tile_sample_min_width,
]
tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
if self.quant_conv is not None:
tile = self.quant_conv(tile)
time.append(tile)
time = torch.cat(time, dim=2)
else:
tile = x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
time, _ = self.encoder(tile)
if self.quant_conv is not None:
tile = self.quant_conv(tile)
time.append(tile)
time = self.quant_conv(time)
row.append(torch.cat(time, dim=2))
row.append(time)
rows.append(row)
result_rows = []
@@ -1371,13 +1397,13 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
result_rows.append(torch.cat(result_row, dim=4))
enc = torch.cat(result_rows, dim=3)
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
return enc
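The blend width used here is simply `tile_latent_min - tile_latent_stride`, and each tile is cropped to the stride before concatenation. `blend_v`/`blend_h` themselves are not part of this diff; for reference, a linear cross-fade over the overlap region of two (B, C, T, H, W) tiles looks roughly like this sketch of the common diffusers implementation (an assumption, not copied from this PR):

```python
import torch


def blend_v(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    """Cross-fade the bottom `blend_extent` rows of tile `a` into the top rows of tile `b`."""
    blend_extent = min(a.shape[3], b.shape[3], blend_extent)
    for y in range(blend_extent):
        weight = y / blend_extent  # 0 at the edge closest to `a`, approaching 1 deeper inside `b`
        b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - weight) + b[:, :, :, y, :] * weight
    return b
```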
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
@@ -1405,58 +1431,63 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
# Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB
batch_size, num_channels, num_frames, height, width = z.shape
sample_height = height * self.spatial_compression_ratio
sample_width = width * self.spatial_compression_ratio
overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
row_limit_height = self.tile_sample_min_height - blend_extent_height
row_limit_width = self.tile_sample_min_width - blend_extent_width
frame_batch_size = self.num_latent_frames_batch_size
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, overlap_height):
for i in range(0, height, tile_latent_stride_height):
row = []
for j in range(0, width, overlap_width):
num_batches = max(num_frames // frame_batch_size, 1)
conv_cache = None
time = []
for j in range(0, width, tile_latent_stride_width):
if self.use_framewise_decoding:
time = []
conv_cache = None
for k in range(num_batches):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
end_frame = frame_batch_size * (k + 1) + remaining_frames
tile = z[
:,
:,
start_frame:end_frame,
i : i + self.tile_latent_min_height,
j : j + self.tile_latent_min_width,
]
for k in range(0, num_frames, self.num_latent_frames_batch_size):
tile = z[
:,
:,
k : k + self.num_latent_frames_batch_size,
i : i + tile_latent_min_height,
j : j + tile_latent_min_width,
]
if self.post_quant_conv is not None:
tile = self.post_quant_conv(tile)
tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
time.append(tile)
time = torch.cat(time, dim=2)
else:
tile = z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
if self.post_quant_conv is not None:
tile = self.post_quant_conv(tile)
tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
time.append(tile)
time, _ = self.decoder(tile)
row.append(torch.cat(time, dim=2))
row.append(time)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
# Blend the above tile and the left tile to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
result_rows.append(torch.cat(result_row, dim=4))
dec = torch.cat(result_rows, dim=3)
dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
if not return_dict:
return (dec,)
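With the default settings, the tile grid for the recommended 720x480 (WxH) generation resolution works out as follows. This is a worked sketch assuming an 8x spatial compression ratio; the ranges mirror the loops in `tiled_decode` above:

```python
# Tile grid for a 480x720 (HxW) sample with the default tiling values (assumes 8x compression).
latent_height, latent_width = 60, 90  # 480 // 8, 720 // 8
tile_latent_min = 256 // 8            # 32
tile_latent_stride = 192 // 8         # 24

rows = list(range(0, latent_height, tile_latent_stride))  # [0, 24, 48]     -> 3 tile rows
cols = list(range(0, latent_width, tile_latent_stride))   # [0, 24, 48, 72] -> 4 tile columns

# Each decoded tile is cropped to the stride (192 pixels), so the stitched result is
# 3 * 192 = 576 by 4 * 192 = 768 before the final crop back to 480 x 720.
print(len(rows), len(cols))  # 3 4
```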

View File

@@ -268,8 +268,8 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipe.vae.enable_tiling(
tile_sample_min_height=96,
tile_sample_min_width=96,
tile_overlap_factor_height=1 / 12,
tile_overlap_factor_width=1 / 12,
tile_sample_stride_height=64,
tile_sample_stride_width=64,
)
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128

View File

@@ -272,8 +272,8 @@ class CogVideoXFunControlPipelineFastTests(PipelineTesterMixin, unittest.TestCas
pipe.vae.enable_tiling(
tile_sample_min_height=96,
tile_sample_min_width=96,
tile_overlap_factor_height=1 / 12,
tile_overlap_factor_width=1 / 12,
tile_sample_stride_height=64,
tile_sample_stride_width=64,
)
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128

View File

@@ -291,8 +291,8 @@ class CogVideoXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestC
pipe.vae.enable_tiling(
tile_sample_min_height=96,
tile_sample_min_width=96,
tile_overlap_factor_height=1 / 12,
tile_overlap_factor_width=1 / 12,
tile_sample_stride_height=64,
tile_sample_stride_width=64,
)
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128

View File

@@ -273,8 +273,8 @@ class CogVideoXVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestC
pipe.vae.enable_tiling(
tile_sample_min_height=96,
tile_sample_min_width=96,
tile_overlap_factor_height=1 / 12,
tile_overlap_factor_width=1 / 12,
tile_sample_stride_height=64,
tile_sample_stride_width=64,
)
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128