Compare commits

...

2 Commits

Author SHA1 Message Date
Aryan
98fbe55419 update 2024-08-27 12:07:44 +02:00
Dhruv Nair
fef07128a8 update 2024-08-27 08:16:13 +00:00
3 changed files with 41 additions and 14 deletions

View File

@@ -691,7 +691,6 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
emb = self.time_embedding(t_emb, timestep_cond)
emb = emb.repeat_interleave(sample_num_frames, dim=0)
encoder_hidden_states = encoder_hidden_states.repeat_interleave(sample_num_frames, dim=0)
# 2. pre-process
batch_size, channels, num_frames, height, width = sample.shape

View File

@@ -38,6 +38,7 @@ from ...utils import (
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin
from ..free_noise_utils import AnimateDiffFreeNoiseMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput
@@ -127,6 +128,7 @@ class AnimateDiffSparseControlNetPipeline(
IPAdapterMixin,
StableDiffusionLoraLoaderMixin,
FreeInitMixin,
AnimateDiffFreeNoiseMixin,
):
r"""
Pipeline for controlled text-to-video generation using the method described in [SparseCtrl: Adding Sparse Controls
@@ -448,15 +450,21 @@ class AnimateDiffSparseControlNetPipeline(
return ip_adapter_image_embeds
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
def decode_latents(self, latents):
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
def decode_latents(self, latents, decode_chunk_size: int = 16):
latents = 1 / self.vae.config.scaling_factor * latents
batch_size, channels, num_frames, height, width = latents.shape
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
image = self.vae.decode(latents).sample
video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
video = []
for i in range(0, latents.shape[0], decode_chunk_size):
batch_latents = latents[i : i + decode_chunk_size]
batch_latents = self.vae.decode(batch_latents).sample
video.append(batch_latents)
video = torch.cat(video)
video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
video = video.float()
return video
@@ -611,10 +619,22 @@ class AnimateDiffSparseControlNetPipeline(
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
)
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.prepare_latents
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
# If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169)
if self.free_noise_enabled:
latents = self._prepare_latents_free_noise(
batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
shape = (
batch_size,
num_channels_latents,
@@ -622,11 +642,6 @@ class AnimateDiffSparseControlNetPipeline(
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
@@ -728,6 +743,7 @@ class AnimateDiffSparseControlNetPipeline(
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
decode_chunk_size: int = 16,
):
r"""
The call function to the pipeline for generation.
@@ -806,6 +822,8 @@ class AnimateDiffSparseControlNetPipeline(
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
decode_chunk_size (`int`, defaults to `16`):
The number of frames to decode at a time when calling `decode_latents` method.
Examples:
@@ -996,7 +1014,7 @@ class AnimateDiffSparseControlNetPipeline(
if output_type == "latent":
video = latents
else:
video_tensor = self.decode_latents(latents)
video_tensor = self.decode_latents(latents, decode_chunk_size)
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
# 12. Offload all models

View File

@@ -221,15 +221,25 @@ class AnimateDiffFreeNoiseMixin:
self._free_noise_noise_type = noise_type
blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
if hasattr(self, "controlnet"):
blocks.extend([*self.controlnet.down_blocks, self.controlnet.mid_block])
for block in blocks:
self._enable_free_noise_in_block(block)
if hasattr(block, "motion_modules"):
print(block.__class__.__name__)
self._enable_free_noise_in_block(block)
def disable_free_noise(self) -> None:
self._free_noise_context_length = None
blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
if hasattr(self, "controlnet"):
blocks.extend([*self.controlnet.down_blocks, self.controlnet.mid_block])
for block in blocks:
self._disable_free_noise_in_block(block)
if hasattr(block, "motion_modules"):
print(block.__class__.__name__)
self._disable_free_noise_in_block(block)
@property
def free_noise_enabled(self):