Mirror of https://github.com/huggingface/diffusers.git, synced 2025-12-16 01:14:47 +08:00

Compare commits: fix-bnb-te ... cogvideox/ (2 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 5f6a41ed02 |  |
|  | 91fbe16b63 |  |

The hunks in this compare touch `CogVideoXCausalConv3d` and `AutoencoderKLCogVideoX`:
```diff
@@ -117,6 +117,8 @@ class CogVideoXCausalConv3d(nn.Module):
             dilation=dilation,
         )
 
+        self.return_conv_cache = True
+
     def fake_context_parallel_forward(
         self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
```
```diff
@@ -128,7 +130,10 @@ class CogVideoXCausalConv3d(nn.Module):
 
     def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
         inputs = self.fake_context_parallel_forward(inputs, conv_cache)
-        conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
+        if self.return_conv_cache:
+            conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
+        else:
+            conv_cache = None
 
         padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
         inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
```
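The change above only gates whether the trailing frames of the input are kept as a cache for the next chunk. A minimal standalone sketch of that slice; the tensor shape and `time_kernel_size` below are illustrative assumptions, not values from this diff:

```python
import torch

# Stand-ins for the module state used by CogVideoXCausalConv3d.forward above;
# the shape and kernel size are illustrative only.
time_kernel_size = 3
return_conv_cache = True
inputs = torch.randn(1, 8, 5, 16, 16)  # (batch, channels, frames, height, width)

# Same gating as in the new forward(): keep the last `time_kernel_size - 1` frames
# as the cache for the next chunk, or skip the copy entirely when no cache is wanted.
if return_conv_cache:
    conv_cache = inputs[:, :, -time_kernel_size + 1 :].clone()
else:
    conv_cache = None

print(conv_cache.shape if conv_cache is not None else None)  # torch.Size([1, 8, 2, 16, 16])
```

When the cache is disabled, the `.clone()` of the trailing frames is skipped, so no extra copy is held per causal convolution.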
```diff
@@ -1079,6 +1084,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
 
         self.use_slicing = False
         self.use_tiling = False
+        self.use_framewise_batching = True
 
         # Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
         # recommended because the temporal parts of the VAE, here, are tricky to understand.
```
```diff
@@ -1174,6 +1180,20 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         """
         self.use_slicing = False
 
+    def enable_framewise_batching(self) -> None:
+        self.use_framewise_batching = True
+
+        for name, module in self.named_modules():
+            if isinstance(module, CogVideoXCausalConv3d):
+                module.return_conv_cache = True
+
+    def disable_framewise_batching(self) -> None:
+        self.use_framewise_batching = False
+
+        for name, module in self.named_modules():
+            if isinstance(module, CogVideoXCausalConv3d):
+                module.return_conv_cache = False
+
     def _encode(self, x: torch.Tensor) -> torch.Tensor:
         batch_size, num_channels, num_frames, height, width = x.shape
 
```
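A hypothetical usage sketch for the two methods added above; the checkpoint id, subfolder, and dtype are assumptions for illustration and are not part of this diff:

```python
import torch
from diffusers import AutoencoderKLCogVideoX

# Assumed checkpoint; any CogVideoX VAE with this layout would be toggled the same way.
vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16
)

# Flip the chunked path off and back on; each call also sets `return_conv_cache`
# on every CogVideoXCausalConv3d submodule, as the methods above show.
vae.disable_framewise_batching()
vae.enable_framewise_batching()
```

Because `use_framewise_batching` is initialised to `True` in the earlier `__init__` hunk, `enable_framewise_batching()` restores the default chunked behaviour.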
```diff
@@ -1184,19 +1204,26 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
         num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
         conv_cache = None
-        enc = []
-
-        for i in range(num_batches):
-            remaining_frames = num_frames % frame_batch_size
-            start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
-            end_frame = frame_batch_size * (i + 1) + remaining_frames
-            x_intermediate = x[:, :, start_frame:end_frame]
-            x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
+        if self.use_framewise_batching:
+            enc = []
+
+            for i in range(num_batches):
+                remaining_frames = num_frames % frame_batch_size
+                start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
+                end_frame = frame_batch_size * (i + 1) + remaining_frames
+                x_intermediate = x[:, :, start_frame:end_frame]
+                x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
+                if self.quant_conv is not None:
+                    x_intermediate = self.quant_conv(x_intermediate)
+                enc.append(x_intermediate)
+
+            enc = torch.cat(enc, dim=2)
+        else:
+            enc, _ = self.encoder(x, conv_cache=conv_cache)
-            if self.quant_conv is not None:
-                x_intermediate = self.quant_conv(x_intermediate)
-            enc.append(x_intermediate)
+            enc = self.quant_conv(enc)
-
-        enc = torch.cat(enc, dim=2)
         return enc
 
     @apply_forward_hook
```
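The chunk-boundary arithmetic in the framewise branch above is easiest to see with concrete numbers. A small worked example, assuming `num_frames = 49` and `frame_batch_size = 8` (illustrative values matching the `frame_batch_size * k + 1` case named in the comment):

```python
# Reproduces the start/end indices computed in the framewise `_encode` branch above.
num_frames, frame_batch_size = 49, 8
num_batches = num_frames // frame_batch_size if num_frames > 1 else 1  # 6

for i in range(num_batches):
    remaining_frames = num_frames % frame_batch_size  # 1 extra leading frame
    start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
    end_frame = frame_batch_size * (i + 1) + remaining_frames
    print(i, start_frame, end_frame)

# 0 0 9    <- the first chunk absorbs the extra frame (9 frames)
# 1 9 17
# 2 17 25
# 3 25 33
# 4 33 41
# 5 41 49  <- the chunks cover all 49 frames exactly once
```

The `else:` branch instead encodes all frames in a single call, with no chunk bookkeeping but a higher peak memory footprint.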
```diff
@@ -1236,19 +1263,25 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         frame_batch_size = self.num_latent_frames_batch_size
         num_batches = max(num_frames // frame_batch_size, 1)
         conv_cache = None
-        dec = []
-
-        for i in range(num_batches):
-            remaining_frames = num_frames % frame_batch_size
-            start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
-            end_frame = frame_batch_size * (i + 1) + remaining_frames
-            z_intermediate = z[:, :, start_frame:end_frame]
+        if self.use_framewise_batching:
+            dec = []
+
+            for i in range(num_batches):
+                remaining_frames = num_frames % frame_batch_size
+                start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
+                end_frame = frame_batch_size * (i + 1) + remaining_frames
+                z_intermediate = z[:, :, start_frame:end_frame]
+                if self.post_quant_conv is not None:
+                    z_intermediate = self.post_quant_conv(z_intermediate)
+                z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
+                dec.append(z_intermediate)
+
+            dec = torch.cat(dec, dim=2)
+        else:
-            if self.post_quant_conv is not None:
-                z_intermediate = self.post_quant_conv(z_intermediate)
-            z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
-            dec.append(z_intermediate)
-
-        dec = torch.cat(dec, dim=2)
+            dec = self.post_quant_conv(z)
+            dec, _ = self.decoder(z, conv_cache=conv_cache)
 
         if not return_dict:
             return (dec,)
```
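For the decode side, a hypothetical end-to-end sketch contrasting the two paths in `_decode` above; the checkpoint id, latent shape, device, and dtype are illustrative assumptions rather than values from this diff:

```python
import torch
from diffusers import AutoencoderKLCogVideoX

# Assumed checkpoint and an assumed latent layout of (batch, channels, latent_frames, height, width).
vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16
).to("cuda")
latents = torch.randn(1, 16, 13, 60, 90, dtype=torch.float16, device="cuda")

vae.enable_framewise_batching()   # chunked decode; conv_cache is carried between chunks
video = vae.decode(latents).sample

vae.disable_framewise_batching()  # single full pass through the decoder
video_full_pass = vae.decode(latents).sample
```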