Compare commits

...

5 Commits

Author  SHA1        Message                                          Date
Aryan   da9e4babe5  Merge branch 'main' into refactor-cogvideox-vae  2024-11-22 20:33:03 +05:30
Aryan   b49ae8cac2  apply suggestions from review                    2024-11-13 00:41:54 +01:00
Aryan   2f1f43c59d  fight tests                                      2024-11-11 19:47:21 +01:00
Aryan   af00830898  make style                                       2024-11-11 03:50:55 +01:00
Aryan   01d6a1b89f  refactor                                         2024-11-11 03:50:35 +01:00
5 changed files with 159 additions and 128 deletions

View File

@@ -22,7 +22,7 @@ import torch.nn.functional as F
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders.single_file_model import FromOriginalModelMixin
-from ...utils import logging
+from ...utils import deprecate, logging
 from ...utils.accelerate_utils import apply_forward_hook
 from ..activations import get_activation
 from ..downsampling import CogVideoXDownsample3D
@@ -1086,9 +1086,23 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None
         self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None
 
+        self.spatial_compression_ratio = 2 ** (len(block_out_channels) - 1)
+        self.temporal_compression_ratio = temporal_compression_ratio
+
+        # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+        # to perform decoding of a single video latent at a time.
         self.use_slicing = False
+
+        # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+        # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+        # intermediate tiles together, the memory requirement can be lowered.
         self.use_tiling = False
+
+        # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
+        # at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered.
+        self.use_framewise_encoding = True
+        self.use_framewise_decoding = True
 
         # Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
         # recommended because the temporal parts of the VAE, here, are tricky to understand.
         # If you decode X latent frames together, the number of output frames is:
@@ -1109,18 +1123,11 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         self.num_sample_frames_batch_size = 8
 
         # We make the minimum height and width of sample for tiling half that of the generally supported
-        self.tile_sample_min_height = sample_height // 2
-        self.tile_sample_min_width = sample_width // 2
-        self.tile_latent_min_height = int(
-            self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
-        )
-        self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
-
-        # These are experimental overlap factors that were chosen based on experimentation and seem to work best for
-        # 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
-        # and so the tiling implementation has only been tested on those specific resolutions.
-        self.tile_overlap_factor_height = 1 / 6
-        self.tile_overlap_factor_width = 1 / 5
+        self.tile_sample_min_height = 256
+        self.tile_sample_min_width = 256
+        self.tile_sample_stride_height = 192
+        self.tile_sample_stride_width = 192
 
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
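After this hunk the memory-saving behaviours are plain instance attributes with fixed defaults rather than values derived from the sample size. A minimal usage sketch of how downstream code can toggle them (the checkpoint id, dtype, and the `enable_slicing` call are illustrative and not part of this diff; only `enable_tiling`/`disable_tiling` appear here):

```python
import torch
from diffusers import AutoencoderKLCogVideoX

# Illustrative checkpoint; any CogVideoX VAE checkpoint should behave the same way.
vae = AutoencoderKLCogVideoX.from_pretrained("THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16)

vae.enable_tiling()   # spatial tiling with the new 256 px tile size and 192 px stride defaults
vae.enable_slicing()  # decode one video latent at a time across the batch dimension
# Framewise encoding/decoding defaults to True in this refactor, so long videos are already
# processed in fixed-size frame batches without any extra call.
```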
@@ -1132,6 +1139,8 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         tile_sample_min_width: Optional[int] = None,
         tile_overlap_factor_height: Optional[float] = None,
         tile_overlap_factor_width: Optional[float] = None,
+        tile_sample_stride_height: Optional[float] = None,
+        tile_sample_stride_width: Optional[float] = None,
     ) -> None:
         r"""
         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
@@ -1143,24 +1152,36 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 The minimum height required for a sample to be separated into tiles across the height dimension.
             tile_sample_min_width (`int`, *optional*):
                 The minimum width required for a sample to be separated into tiles across the width dimension.
-            tile_overlap_factor_height (`int`, *optional*):
+            tile_sample_stride_height (`int`, *optional*):
                 The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
-                no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
-                value might cause more tiles to be processed leading to slow down of the decoding process.
-            tile_overlap_factor_width (`int`, *optional*):
-                The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
-                are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
-                value might cause more tiles to be processed leading to slow down of the decoding process.
+                no tiling artifacts produced across the height dimension.
+            tile_sample_stride_width (`int`, *optional*):
+                The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+                artifacts produced across the width dimension.
         """
+        if tile_overlap_factor_height is not None or tile_overlap_factor_width is not None:
+            deprecate(
+                "tile_overlap_factor",
+                "1.0.0",
+                "The parameters `tile_overlap_factor_height` and `tile_overlap_factor_width` are deprecated and will be ignored. Please use `tile_sample_stride_height` and `tile_sample_stride_width` instead. For now, we will use these flags automatically, if passed, without breaking the existing behaviour.",
+            )
+            tile_sample_stride_height = (
+                int((1 - tile_overlap_factor_height) * self.tile_sample_min_height)
+                // self.spatial_compression_ratio
+                * self.spatial_compression_ratio
+            )
+            tile_sample_stride_width = (
+                int((1 - tile_overlap_factor_width) * self.tile_sample_min_width)
+                // self.spatial_compression_ratio
+                * self.spatial_compression_ratio
+            )
+
         self.use_tiling = True
         self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
         self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
-        self.tile_latent_min_height = int(
-            self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
-        )
-        self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
-        self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
-        self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
+        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
 
     def disable_tiling(self) -> None:
         r"""
@@ -1189,24 +1210,23 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
             return self.tiled_encode(x)
 
-        frame_batch_size = self.num_sample_frames_batch_size
-        # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
-        # As the extra single frame is handled inside the loop, it is not required to round up here.
-        num_batches = max(num_frames // frame_batch_size, 1)
-        conv_cache = None
-        enc = []
-
-        for i in range(num_batches):
-            remaining_frames = num_frames % frame_batch_size
-            start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
-            end_frame = frame_batch_size * (i + 1) + remaining_frames
-            x_intermediate = x[:, :, start_frame:end_frame]
-            x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
-            if self.quant_conv is not None:
-                x_intermediate = self.quant_conv(x_intermediate)
-            enc.append(x_intermediate)
-
-        enc = torch.cat(enc, dim=2)
+        if self.use_framewise_encoding:
+            enc = []
+            conv_cache = None
+
+            for i in range(0, num_frames, self.num_sample_frames_batch_size):
+                x_intermediate = x[:, :, i : i + self.num_sample_frames_batch_size]
+                x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
+                if self.quant_conv is not None:
+                    x_intermediate = self.quant_conv(x_intermediate)
+                enc.append(x_intermediate)
+
+            enc = torch.cat(enc, dim=2)
+        else:
+            enc, _ = self.encoder(x)
+            if self.quant_conv is not None:
+                enc = self.quant_conv(enc)
+
         return enc
 
     @apply_forward_hook
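The old index bookkeeping (`remaining_frames`, `start_frame`, `end_frame`) is replaced by a plain stride over the time axis, with the convolution cache carrying causal context between chunks. A standalone sketch of the pattern (the `encoder` argument is a stand-in that returns `(output, conv_cache)` like the module in the diff, not the actual CogVideoX encoder):

```python
import torch

def encode_framewise(x: torch.Tensor, encoder, frame_batch_size: int) -> torch.Tensor:
    # x has shape (B, C, T, H, W); slice T in fixed-size chunks and thread the cache through them.
    num_frames = x.shape[2]
    chunks, conv_cache = [], None
    for i in range(0, num_frames, frame_batch_size):
        chunk = x[:, :, i : i + frame_batch_size]
        chunk, conv_cache = encoder(chunk, conv_cache=conv_cache)
        chunks.append(chunk)
    return torch.cat(chunks, dim=2)
```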
@@ -1239,26 +1259,28 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         batch_size, num_channels, num_frames, height, width = z.shape
 
-        if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
             return self.tiled_decode(z, return_dict=return_dict)
 
-        frame_batch_size = self.num_latent_frames_batch_size
-        num_batches = max(num_frames // frame_batch_size, 1)
-        conv_cache = None
-        dec = []
-
-        for i in range(num_batches):
-            remaining_frames = num_frames % frame_batch_size
-            start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
-            end_frame = frame_batch_size * (i + 1) + remaining_frames
-            z_intermediate = z[:, :, start_frame:end_frame]
-            if self.post_quant_conv is not None:
-                z_intermediate = self.post_quant_conv(z_intermediate)
-            z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
-            dec.append(z_intermediate)
-
-        dec = torch.cat(dec, dim=2)
+        if self.use_framewise_decoding:
+            dec = []
+            conv_cache = None
+
+            for i in range(0, num_frames, self.num_latent_frames_batch_size):
+                z_intermediate = z[:, :, i : i + self.num_latent_frames_batch_size]
+                if self.post_quant_conv is not None:
+                    z_intermediate = self.post_quant_conv(z_intermediate)
+                z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
+                dec.append(z_intermediate)
+
+            dec = torch.cat(dec, dim=2)
+        else:
+            if self.post_quant_conv is not None:
+                z = self.post_quant_conv(z)
+            dec, _ = self.decoder(z)
 
         if not return_dict:
             return (dec,)
@@ -1324,44 +1346,48 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         """
         # For a rough memory estimate, take a look at the `tiled_decode` method.
         batch_size, num_channels, num_frames, height, width = x.shape
+        latent_height = height // self.spatial_compression_ratio
+        latent_width = width // self.spatial_compression_ratio
 
-        overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
-        overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
-        blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
-        blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
-        row_limit_height = self.tile_latent_min_height - blend_extent_height
-        row_limit_width = self.tile_latent_min_width - blend_extent_width
-        frame_batch_size = self.num_sample_frames_batch_size
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = tile_latent_min_height - tile_latent_stride_height
+        blend_width = tile_latent_min_width - tile_latent_stride_width
 
         # Split x into overlapping tiles and encode them separately.
         # The tiles have an overlap to avoid seams between tiles.
         rows = []
-        for i in range(0, height, overlap_height):
+        for i in range(0, height, self.tile_sample_stride_height):
             row = []
-            for j in range(0, width, overlap_width):
-                # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
-                # As the extra single frame is handled inside the loop, it is not required to round up here.
-                num_batches = max(num_frames // frame_batch_size, 1)
-                conv_cache = None
-                time = []
-
-                for k in range(num_batches):
-                    remaining_frames = num_frames % frame_batch_size
-                    start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
-                    end_frame = frame_batch_size * (k + 1) + remaining_frames
-                    tile = x[
-                        :,
-                        :,
-                        start_frame:end_frame,
-                        i : i + self.tile_sample_min_height,
-                        j : j + self.tile_sample_min_width,
-                    ]
-                    tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
-                    if self.quant_conv is not None:
-                        tile = self.quant_conv(tile)
-                    time.append(tile)
-
-                row.append(torch.cat(time, dim=2))
+            for j in range(0, width, self.tile_sample_stride_width):
+                if self.use_framewise_encoding:
+                    time = []
+                    conv_cache = None
+
+                    for k in range(0, num_frames, self.num_sample_frames_batch_size):
+                        tile = x[
+                            :,
+                            :,
+                            k : k + self.num_sample_frames_batch_size,
+                            i : i + self.tile_sample_min_height,
+                            j : j + self.tile_sample_min_width,
+                        ]
+                        tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
+                        if self.quant_conv is not None:
+                            tile = self.quant_conv(tile)
+                        time.append(tile)
+
+                    time = torch.cat(time, dim=2)
+                else:
+                    tile = x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                    time, _ = self.encoder(tile)
+                    if self.quant_conv is not None:
+                        time = self.quant_conv(time)
+
+                row.append(time)
             rows.append(row)
 
         result_rows = []
@@ -1371,13 +1397,13 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 # blend the above tile and the left tile
                 # to the current tile and add the current tile to the result row
                 if i > 0:
-                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                 if j > 0:
-                    tile = self.blend_h(row[j - 1], tile, blend_extent_width)
-                result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
             result_rows.append(torch.cat(result_row, dim=4))
 
-        enc = torch.cat(result_rows, dim=3)
+        enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
         return enc
 
     def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
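With strides in hand, the blend extent is simply `tile_latent_min - tile_latent_stride`, i.e. the overlap region shared by two neighbouring tiles. `blend_v`/`blend_h` themselves are untouched by this diff; a sketch of what such vertical blending typically looks like, assuming `(B, C, T, H, W)` tensors and a linear cross-fade over the overlapping rows:

```python
import torch

def blend_v(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Cross-fade the last `blend_extent` rows of the upper tile `a` into the first rows of the lower tile `b`.
    blend_extent = min(a.shape[3], b.shape[3], blend_extent)
    for y in range(blend_extent):
        w = y / blend_extent
        b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - w) + b[:, :, :, y, :] * w
    return b
```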
@@ -1405,58 +1431,63 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB
         batch_size, num_channels, num_frames, height, width = z.shape
+        sample_height = height * self.spatial_compression_ratio
+        sample_width = width * self.spatial_compression_ratio
 
-        overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
-        overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
-        blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
-        blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
-        row_limit_height = self.tile_sample_min_height - blend_extent_height
-        row_limit_width = self.tile_sample_min_width - blend_extent_width
-        frame_batch_size = self.num_latent_frames_batch_size
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
 
         # Split z into overlapping tiles and decode them separately.
         # The tiles have an overlap to avoid seams between tiles.
         rows = []
-        for i in range(0, height, overlap_height):
+        for i in range(0, height, tile_latent_stride_height):
             row = []
-            for j in range(0, width, overlap_width):
-                num_batches = max(num_frames // frame_batch_size, 1)
-                conv_cache = None
-                time = []
-
-                for k in range(num_batches):
-                    remaining_frames = num_frames % frame_batch_size
-                    start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
-                    end_frame = frame_batch_size * (k + 1) + remaining_frames
-                    tile = z[
-                        :,
-                        :,
-                        start_frame:end_frame,
-                        i : i + self.tile_latent_min_height,
-                        j : j + self.tile_latent_min_width,
-                    ]
-                    if self.post_quant_conv is not None:
-                        tile = self.post_quant_conv(tile)
-                    tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
-                    time.append(tile)
-
-                row.append(torch.cat(time, dim=2))
+            for j in range(0, width, tile_latent_stride_width):
+                if self.use_framewise_decoding:
+                    time = []
+                    conv_cache = None
+
+                    for k in range(0, num_frames, self.num_latent_frames_batch_size):
+                        tile = z[
+                            :,
+                            :,
+                            k : k + self.num_latent_frames_batch_size,
+                            i : i + tile_latent_min_height,
+                            j : j + tile_latent_min_width,
+                        ]
+                        if self.post_quant_conv is not None:
+                            tile = self.post_quant_conv(tile)
+                        tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
+                        time.append(tile)
+
+                    time = torch.cat(time, dim=2)
+                else:
+                    tile = z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+                    if self.post_quant_conv is not None:
+                        tile = self.post_quant_conv(tile)
+                    time, _ = self.decoder(tile)
+
+                row.append(time)
             rows.append(row)
 
         result_rows = []
         for i, row in enumerate(rows):
             result_row = []
             for j, tile in enumerate(row):
-                # blend the above tile and the left tile
-                # to the current tile and add the current tile to the result row
+                # Blend the above tile and the left tile to the current tile and add the current tile to the result row
                 if i > 0:
-                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                 if j > 0:
-                    tile = self.blend_h(row[j - 1], tile, blend_extent_width)
-                result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
             result_rows.append(torch.cat(result_row, dim=4))
 
-        dec = torch.cat(result_rows, dim=3)
+        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
 
         if not return_dict:
             return (dec,)
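For the recommended 720x480 (WxH) generation resolution, the new defaults work out to a small tile grid. A quick back-of-the-envelope check (assuming a spatial compression ratio of 8, so the latent grid is 60x90):

```python
tile_sample_min, tile_sample_stride, compression = 256, 192, 8
tile_latent_min = tile_sample_min // compression        # 32
tile_latent_stride = tile_sample_stride // compression  # 24

latent_height, latent_width = 480 // compression, 720 // compression  # 60, 90
rows = len(range(0, latent_height, tile_latent_stride))  # 3 tile rows
cols = len(range(0, latent_width, tile_latent_stride))   # 4 tile columns
blend = tile_latent_min - tile_latent_stride             # 8 latent rows/cols cross-faded between neighbours
# The final `[:, :, :, :sample_height, :sample_width]` crop trims the ragged edge left by the last partial tiles.
```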

View File

@@ -268,8 +268,8 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         pipe.vae.enable_tiling(
             tile_sample_min_height=96,
             tile_sample_min_width=96,
-            tile_overlap_factor_height=1 / 12,
-            tile_overlap_factor_width=1 / 12,
+            tile_sample_stride_height=64,
+            tile_sample_stride_width=64,
         )
 
         inputs = self.get_dummy_inputs(generator_device)
         inputs["height"] = inputs["width"] = 128

View File

@@ -272,8 +272,8 @@ class CogVideoXFunControlPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         pipe.vae.enable_tiling(
             tile_sample_min_height=96,
             tile_sample_min_width=96,
-            tile_overlap_factor_height=1 / 12,
-            tile_overlap_factor_width=1 / 12,
+            tile_sample_stride_height=64,
+            tile_sample_stride_width=64,
         )
 
         inputs = self.get_dummy_inputs(generator_device)
         inputs["height"] = inputs["width"] = 128

View File

@@ -291,8 +291,8 @@ class CogVideoXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         pipe.vae.enable_tiling(
             tile_sample_min_height=96,
             tile_sample_min_width=96,
-            tile_overlap_factor_height=1 / 12,
-            tile_overlap_factor_width=1 / 12,
+            tile_sample_stride_height=64,
+            tile_sample_stride_width=64,
         )
 
         inputs = self.get_dummy_inputs(generator_device)
         inputs["height"] = inputs["width"] = 128

View File

@@ -273,8 +273,8 @@ class CogVideoXVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         pipe.vae.enable_tiling(
             tile_sample_min_height=96,
             tile_sample_min_width=96,
-            tile_overlap_factor_height=1 / 12,
-            tile_overlap_factor_width=1 / 12,
+            tile_sample_stride_height=64,
+            tile_sample_stride_width=64,
        )
 
         inputs = self.get_dummy_inputs(generator_device)
         inputs["height"] = inputs["width"] = 128