@@ -22,7 +22,7 @@ import torch.nn.functional as F
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders.single_file_model import FromOriginalModelMixin
-from ...utils import logging
+from ...utils import deprecate, logging
 from ...utils.accelerate_utils import apply_forward_hook
 from ..activations import get_activation
 from ..downsampling import CogVideoXDownsample3D
@@ -1086,9 +1086,23 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None
         self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None
 
+        self.spatial_compression_ratio = 2 ** (len(block_out_channels) - 1)
+        self.temporal_compression_ratio = temporal_compression_ratio
+
+        # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+        # to perform decoding of a single video latent at a time.
         self.use_slicing = False
+
+        # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+        # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+        # intermediate tiles together, the memory requirement can be lowered.
         self.use_tiling = False
+
+        # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
+        # at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered.
+        self.use_framewise_encoding = True
+        self.use_framewise_decoding = True
 
         # Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
         # recommended because the temporal parts of the VAE, here, are tricky to understand.
         # If you decode X latent frames together, the number of output frames is:
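
The four attributes introduced above are plain flags meant to be flipped through the user-facing helpers (or directly), not through the config. A minimal usage sketch, assuming a stock checkpoint; the model id below is illustrative, not taken from this diff:

```python
# Sketch: toggling the memory-saving modes on a loaded CogVideoX VAE.
import torch
from diffusers import AutoencoderKLCogVideoX

vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16
)
vae.enable_slicing()               # decode one video latent at a time across the batch dim
vae.enable_tiling()                # split large frames into overlapping spatial tiles
vae.use_framewise_decoding = True  # decode a few latent frames per forward pass
```

Slicing and tiling trade latency for peak memory; framewise decoding additionally bounds the temporal extent of each forward pass.
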
@@ -1109,18 +1123,11 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         self.num_sample_frames_batch_size = 8
 
-        # We make the minimum height and width of sample for tiling half that of the generally supported
-        self.tile_sample_min_height = sample_height // 2
-        self.tile_sample_min_width = sample_width // 2
-        self.tile_latent_min_height = int(
-            self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
-        )
-        self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
+        self.tile_sample_min_height = 256
+        self.tile_sample_min_width = 256
 
-        # These are experimental overlap factors that were chosen based on experimentation and seem to work best for
-        # 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
-        # and so the tiling implementation has only been tested on those specific resolutions.
-        self.tile_overlap_factor_height = 1 / 6
-        self.tile_overlap_factor_width = 1 / 5
+        self.tile_sample_stride_height = 192
+        self.tile_sample_stride_width = 192
 
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
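
A quick sanity check of what the new constants mean in latent space, assuming the stock CogVideoX config where `len(block_out_channels) == 4` and hence a spatial compression ratio of 8:

```python
# Quick arithmetic check of the new tiling constants.
spatial_compression_ratio = 2 ** (4 - 1)  # 8 for the stock config

tile_sample_min = 256     # pixels per tile edge
tile_sample_stride = 192  # pixels between tile origins

tile_latent_min = tile_sample_min // spatial_compression_ratio        # 32 latents
tile_latent_stride = tile_sample_stride // spatial_compression_ratio  # 24 latents
overlap = tile_sample_min - tile_sample_stride                        # 64 px blended

assert (tile_latent_min, tile_latent_stride, overlap) == (32, 24, 64)
```
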
@@ -1132,6 +1139,8 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         tile_sample_min_width: Optional[int] = None,
         tile_overlap_factor_height: Optional[float] = None,
         tile_overlap_factor_width: Optional[float] = None,
+        tile_sample_stride_height: Optional[float] = None,
+        tile_sample_stride_width: Optional[float] = None,
     ) -> None:
         r"""
         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
@@ -1143,24 +1152,36 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 The minimum height required for a sample to be separated into tiles across the height dimension.
             tile_sample_min_width (`int`, *optional*):
                 The minimum width required for a sample to be separated into tiles across the width dimension.
-            tile_overlap_factor_height (`int`, *optional*):
-                The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
-                no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
-                value might cause more tiles to be processed leading to slow down of the decoding process.
-            tile_overlap_factor_width (`int`, *optional*):
-                The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
-                are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
-                value might cause more tiles to be processed leading to slow down of the decoding process.
+            tile_sample_stride_height (`int`, *optional*):
+                The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
+                no tiling artifacts produced across the height dimension.
+            tile_sample_stride_width (`int`, *optional*):
+                The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+                artifacts produced across the width dimension.
         """
+        if tile_overlap_factor_height is not None or tile_overlap_factor_width is not None:
+            deprecate(
+                "tile_overlap_factor",
+                "1.0.0",
+                "The parameters `tile_overlap_factor_height` and `tile_overlap_factor_width` are deprecated and will be ignored. Please use `tile_sample_stride_height` and `tile_sample_stride_width` instead. For now, we will use these flags automatically, if passed, without breaking the existing behaviour.",
+            )
+
+            tile_sample_stride_height = (
+                int((1 - tile_overlap_factor_height) * self.tile_sample_min_height)
+                // self.spatial_compression_ratio
+                * self.spatial_compression_ratio
+            )
+            tile_sample_stride_width = (
+                int((1 - tile_overlap_factor_width) * self.tile_sample_min_width)
+                // self.spatial_compression_ratio
+                * self.spatial_compression_ratio
+            )
+
         self.use_tiling = True
         self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
         self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
-        self.tile_latent_min_height = int(
-            self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
-        )
-        self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
-        self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
-        self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
+        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
 
     def disable_tiling(self) -> None:
         r"""
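
The back-compat shim converts a legacy overlap factor into a stride and snaps the result down to a multiple of the spatial compression ratio, so tile boundaries stay aligned with latent pixels. Worked through with the old defaults (1/6 and 1/5) against the new 256-pixel minimum tile, assuming a compression ratio of 8:

```python
# How the deprecation shim maps the old overlap factors onto strides.
ratio = 8
tile_sample_min_height = tile_sample_min_width = 256

stride_h = int((1 - 1 / 6) * tile_sample_min_height) // ratio * ratio
stride_w = int((1 - 1 / 5) * tile_sample_min_width) // ratio * ratio

# int(213.33) == 213 -> 213 // 8 * 8 == 208; int(204.8) == 204 -> 204 // 8 * 8 == 200
assert (stride_h, stride_w) == (208, 200)
```
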
@@ -1189,24 +1210,23 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
             return self.tiled_encode(x)
 
-        frame_batch_size = self.num_sample_frames_batch_size
-        # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
-        # As the extra single frame is handled inside the loop, it is not required to round up here.
-        num_batches = max(num_frames // frame_batch_size, 1)
-        conv_cache = None
-        enc = []
+        if self.use_framewise_encoding:
+            enc = []
+            conv_cache = None
 
-        for i in range(num_batches):
-            remaining_frames = num_frames % frame_batch_size
-            start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
-            end_frame = frame_batch_size * (i + 1) + remaining_frames
-            x_intermediate = x[:, :, start_frame:end_frame]
-            x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
-            if self.quant_conv is not None:
-                x_intermediate = self.quant_conv(x_intermediate)
-            enc.append(x_intermediate)
+            for i in range(0, num_frames, self.num_sample_frames_batch_size):
+                x_intermediate = x[:, :, i : i + self.num_sample_frames_batch_size]
+                x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
+                if self.quant_conv is not None:
+                    x_intermediate = self.quant_conv(x_intermediate)
+                enc.append(x_intermediate)
 
-        enc = torch.cat(enc, dim=2)
+            enc = torch.cat(enc, dim=2)
+        else:
+            enc, _ = self.encoder(x)
+            if self.quant_conv is not None:
+                enc = self.quant_conv(enc)
+
         return enc
 
     @apply_forward_hook
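
The new branch threads a causal convolution cache through fixed-size frame slices, so slicing along the frame axis reproduces a single full pass. A self-contained sketch of the pattern with the encoder stubbed out (`encoder` below is a stand-in, not the module from this file):

```python
# Sketch of the framewise slicing pattern used by the new `_encode` branch.
import torch

def encoder(x, conv_cache=None):
    # Stand-in for CogVideoXEncoder3D: returns the input unchanged plus a dummy cache.
    return x, conv_cache

def framewise_apply(x: torch.Tensor, frame_batch_size: int = 8) -> torch.Tensor:
    conv_cache = None
    chunks = []
    for i in range(0, x.shape[2], frame_batch_size):  # dim 2 is the frame axis
        chunk, conv_cache = encoder(x[:, :, i : i + frame_batch_size], conv_cache=conv_cache)
        chunks.append(chunk)
    return torch.cat(chunks, dim=2)

# 49 sample frames (8 * 6 + 1): slices of 8, 8, 8, 8, 8, 8, 1 frames
out = framewise_apply(torch.randn(1, 3, 49, 64, 64))
assert out.shape[2] == 49
```
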
@@ -1239,26 +1259,28 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         batch_size, num_channels, num_frames, height, width = z.shape
 
-        if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+
+        if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
             return self.tiled_decode(z, return_dict=return_dict)
 
-        frame_batch_size = self.num_latent_frames_batch_size
-        num_batches = max(num_frames // frame_batch_size, 1)
-        conv_cache = None
-        dec = []
+        if self.use_framewise_decoding:
+            dec = []
+            conv_cache = None
 
-        for i in range(num_batches):
-            remaining_frames = num_frames % frame_batch_size
-            start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
-            end_frame = frame_batch_size * (i + 1) + remaining_frames
-            z_intermediate = z[:, :, start_frame:end_frame]
-            if self.post_quant_conv is not None:
-                z_intermediate = self.post_quant_conv(z_intermediate)
-            z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
-            dec.append(z_intermediate)
+            for i in range(0, num_frames, self.num_latent_frames_batch_size):
+                z_intermediate = z[:, :, i : i + self.num_latent_frames_batch_size]
+                if self.post_quant_conv is not None:
+                    z_intermediate = self.post_quant_conv(z_intermediate)
+                z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
+                dec.append(z_intermediate)
 
-        dec = torch.cat(dec, dim=2)
+            dec = torch.cat(dec, dim=2)
+        else:
+            if self.post_quant_conv is not None:
+                z = self.post_quant_conv(z)
+            dec, _ = self.decoder(z)
 
         if not return_dict:
             return (dec,)
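
For orientation, the shape bookkeeping behind the decode path, assuming the stock configuration (spatial compression 8, temporal compression 4, first frame kept uncompressed):

```python
# Latent-to-sample shape bookkeeping for the decode path (stock config assumed).
spatial_ratio, temporal_ratio = 8, 4

latent_frames, latent_height, latent_width = 13, 60, 90
sample_height = latent_height * spatial_ratio              # 480
sample_width = latent_width * spatial_ratio                # 720
sample_frames = (latent_frames - 1) * temporal_ratio + 1   # 49

assert (sample_frames, sample_height, sample_width) == (49, 480, 720)
```
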
@@ -1324,44 +1346,48 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         """
         # For a rough memory estimate, take a look at the `tiled_decode` method.
         batch_size, num_channels, num_frames, height, width = x.shape
+        latent_height = height // self.spatial_compression_ratio
+        latent_width = width // self.spatial_compression_ratio
 
-        overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
-        overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
-        blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
-        blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
-        row_limit_height = self.tile_latent_min_height - blend_extent_height
-        row_limit_width = self.tile_latent_min_width - blend_extent_width
-        frame_batch_size = self.num_sample_frames_batch_size
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = tile_latent_min_height - tile_latent_stride_height
+        blend_width = tile_latent_min_width - tile_latent_stride_width
 
         # Split x into overlapping tiles and encode them separately.
         # The tiles have an overlap to avoid seams between tiles.
         rows = []
-        for i in range(0, height, overlap_height):
+        for i in range(0, height, self.tile_sample_stride_height):
             row = []
-            for j in range(0, width, overlap_width):
-                # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
-                # As the extra single frame is handled inside the loop, it is not required to round up here.
-                num_batches = max(num_frames // frame_batch_size, 1)
-                conv_cache = None
-                time = []
+            for j in range(0, width, self.tile_sample_stride_width):
+                if self.use_framewise_encoding:
+                    time = []
+                    conv_cache = None
 
-                for k in range(num_batches):
-                    remaining_frames = num_frames % frame_batch_size
-                    start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
-                    end_frame = frame_batch_size * (k + 1) + remaining_frames
-                    tile = x[
-                        :,
-                        :,
-                        start_frame:end_frame,
-                        i : i + self.tile_sample_min_height,
-                        j : j + self.tile_sample_min_width,
-                    ]
-                    tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
-                    if self.quant_conv is not None:
-                        tile = self.quant_conv(tile)
-                    time.append(tile)
+                    for k in range(0, num_frames, self.num_sample_frames_batch_size):
+                        tile = x[
+                            :,
+                            :,
+                            k : k + self.num_sample_frames_batch_size,
+                            i : i + self.tile_sample_min_height,
+                            j : j + self.tile_sample_min_width,
+                        ]
+                        tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
+                        if self.quant_conv is not None:
+                            tile = self.quant_conv(tile)
+                        time.append(tile)
 
-                row.append(torch.cat(time, dim=2))
+                    time = torch.cat(time, dim=2)
+                else:
+                    tile = x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                    time, _ = self.encoder(tile)
+                    if self.quant_conv is not None:
+                        time = self.quant_conv(time)
+
+                row.append(time)
             rows.append(row)
 
         result_rows = []
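
Because the new loops walk the input at stride intervals, the tile-grid size is a simple ceiling division. For the 720x480 (WxH) resolution referenced in the removed comments:

```python
# Number of spatial tiles produced by the stride-based loops for a 720x480 input,
# with tile_sample_min = 256 and tile_sample_stride = 192.
import math

height, width = 480, 720
stride = 192

tiles_h = math.ceil(height / stride)  # range(0, 480, 192) -> 3 rows
tiles_w = math.ceil(width / stride)   # range(0, 720, 192) -> 4 columns

assert (tiles_h, tiles_w) == (3, 4)
```
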
@@ -1371,13 +1397,13 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 # blend the above tile and the left tile
                 # to the current tile and add the current tile to the result row
                 if i > 0:
-                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                 if j > 0:
-                    tile = self.blend_h(row[j - 1], tile, blend_extent_width)
-                result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
             result_rows.append(torch.cat(result_row, dim=4))
 
-        enc = torch.cat(result_rows, dim=3)
+        enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
         return enc
 
     def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
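
`blend_v` and `blend_h` crossfade the overlap region linearly so seams vanish. A sketch of the vertical rule, matching this file's 5D layout (batch, channel, frame, height, width) and the in-place convention of the existing helpers:

```python
import torch

def blend_v(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Linearly crossfade tile `b` into tile `a` over the first `blend_extent`
    # rows of `b` (height is dim 3 for 5D video tensors). Mutates `b` in place.
    blend_extent = min(a.shape[3], b.shape[3], blend_extent)
    for y in range(blend_extent):
        w = y / blend_extent
        b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - w) + b[:, :, :, y, :] * w
    return b
```
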
@@ -1405,58 +1431,63 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB
         """
         batch_size, num_channels, num_frames, height, width = z.shape
+        sample_height = height * self.spatial_compression_ratio
+        sample_width = width * self.spatial_compression_ratio
 
-        overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
-        overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
-        blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
-        blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
-        row_limit_height = self.tile_sample_min_height - blend_extent_height
-        row_limit_width = self.tile_sample_min_width - blend_extent_width
-        frame_batch_size = self.num_latent_frames_batch_size
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
 
         # Split z into overlapping tiles and decode them separately.
         # The tiles have an overlap to avoid seams between tiles.
         rows = []
-        for i in range(0, height, overlap_height):
+        for i in range(0, height, tile_latent_stride_height):
             row = []
-            for j in range(0, width, overlap_width):
-                num_batches = max(num_frames // frame_batch_size, 1)
-                conv_cache = None
-                time = []
+            for j in range(0, width, tile_latent_stride_width):
+                if self.use_framewise_decoding:
+                    time = []
+                    conv_cache = None
 
-                for k in range(num_batches):
-                    remaining_frames = num_frames % frame_batch_size
-                    start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
-                    end_frame = frame_batch_size * (k + 1) + remaining_frames
-                    tile = z[
-                        :,
-                        :,
-                        start_frame:end_frame,
-                        i : i + self.tile_latent_min_height,
-                        j : j + self.tile_latent_min_width,
-                    ]
-                    if self.post_quant_conv is not None:
-                        tile = self.post_quant_conv(tile)
-                    tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
-                    time.append(tile)
+                    for k in range(0, num_frames, self.num_latent_frames_batch_size):
+                        tile = z[
+                            :,
+                            :,
+                            k : k + self.num_latent_frames_batch_size,
+                            i : i + tile_latent_min_height,
+                            j : j + tile_latent_min_width,
+                        ]
+                        if self.post_quant_conv is not None:
+                            tile = self.post_quant_conv(tile)
+                        tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
+                        time.append(tile)
 
-                row.append(torch.cat(time, dim=2))
+                    time = torch.cat(time, dim=2)
+                else:
+                    tile = z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+                    if self.post_quant_conv is not None:
+                        tile = self.post_quant_conv(tile)
+                    time, _ = self.decoder(tile)
+
+                row.append(time)
             rows.append(row)
 
         result_rows = []
         for i, row in enumerate(rows):
             result_row = []
             for j, tile in enumerate(row):
-                # blend the above tile and the left tile
-                # to the current tile and add the current tile to the result row
+                # Blend the above tile and the left tile to the current tile and add the current tile to the result row
                 if i > 0:
-                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                 if j > 0:
-                    tile = self.blend_h(row[j - 1], tile, blend_extent_width)
-                result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
             result_rows.append(torch.cat(result_row, dim=4))
 
-        dec = torch.cat(result_rows, dim=3)
+        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
 
         if not return_dict:
             return (dec,)
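
Putting the pieces together, a minimal end-to-end sketch of tiled, framewise decoding after this change; the checkpoint id and latent shape below are illustrative:

```python
# End-to-end sketch: tiled decoding of a CogVideoX latent.
import torch
from diffusers import AutoencoderKLCogVideoX

vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16
).to("cuda")
vae.enable_tiling()  # tile_sample_min 256, tile_sample_stride 192 by default

latents = torch.randn(1, 16, 13, 60, 90, dtype=torch.float16, device="cuda")
with torch.no_grad():
    frames = vae.decode(latents).sample  # -> (1, 3, 49, 480, 720)
```
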