Compare commits

...

2 Commits

Author         SHA1          Message            Date
patil-suraj    a44e825b18    add code           2024-01-04 16:22:47 +05:30
patil-suraj    9eceb07f8d    add second loop    2024-01-04 14:53:33 +05:30


@@ -309,6 +309,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
        num_inference_steps: int = 25,
        min_guidance_scale: float = 1.0,
        max_guidance_scale: float = 3.0,
        rec_guidance_scale: float = 1.0,
        fps: int = 7,
        motion_bucket_id: int = 127,
        noise_aug_strength: float = 0.02,
@@ -534,6 +535,136 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
            frames = tensor2vid(frames, self.image_processor, output_type=output_type)
        else:
            frames = latents
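
        # Second loop: generate a follow-up clip conditioned on the last frame of
        # the clip decoded above, mirroring the image-encoding steps of the first pass.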
        # 3. Encode input image
        next_image = frames[0][-1]
        next_image_embeddings = self._encode_image(next_image, device, num_videos_per_prompt, self.do_classifier_free_guidance)

        # 4. Encode input image using VAE
        next_image = self.image_processor.preprocess(next_image, height=height, width=width)
        noise = randn_tensor(next_image.shape, generator=generator, device=next_image.device, dtype=next_image.dtype)
        next_image = next_image + noise_aug_strength * noise

        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        if needs_upcasting:
            self.vae.to(dtype=torch.float32)

        next_image_latents = self._encode_vae_image(next_image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
        next_image_latents = next_image_latents.to(next_image_embeddings.dtype)

        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)

        # Repeat the image latents for each frame so we can concatenate them with the noise
        # image_latents [batch, channels, height, width] -> [batch, num_frames, channels, height, width]
        next_image_latents = next_image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        next_latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_frames,
            num_channels_latents,
            height,
            width,
            next_image_embeddings.dtype,
            device,
            generator,
            None,
        )
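
        # The embeddings and latents from the first pass were duplicated for
        # classifier-free guidance; keep only the conditional half for the
        # reconstruction pass in the loop below.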
        image_embeddings = image_embeddings.chunk(2)[1]
        image_latents = image_latents.chunk(2)[1]

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                next_latent_model_input = torch.cat([next_latents] * 2) if self.do_classifier_free_guidance else next_latents
                next_latent_model_input = self.scheduler.scale_model_input(next_latent_model_input, t)

                # Concatenate image_latents over channels dimension
                next_latent_model_input = torch.cat([next_latent_model_input, next_image_latents], dim=2)

                # predict the noise residual
                noise_pred = self.unet(
                    next_latent_model_input,
                    t,
                    encoder_hidden_states=next_image_embeddings,
                    added_time_ids=added_time_ids,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
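
                # Reconstruction guidance: re-noise the first clip's latents to the
                # current timestep, denoise them with the UNet, and measure how well
                # the original latents are recovered.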
                with torch.enable_grad():
                    self.unet.train()
                    self.unet.enable_gradient_checkpointing()
                    self.unet.requires_grad_(True)
                    latents.requires_grad_(True)

                    # Add Gaussian noise to the latents
                    noise = torch.randn_like(latents.flatten(0, 1))
                    timestep = torch.ones(noise.shape[0]).to(noise.device) * t
                    prev_noised_latents = self.scheduler.add_noise(latents.flatten(0, 1), noise, timestep)
                    # [batch*frames, channels, height, width] -> [batch, frames, channels, height, width]
                    prev_noised_latents = prev_noised_latents.reshape(-1, num_frames, *prev_noised_latents.shape[1:])

                    scaled_prev_noised_latents = self.scheduler.scale_model_input(prev_noised_latents, t)
                    scaled_prev_noised_latents = torch.cat([scaled_prev_noised_latents, image_latents], dim=2)

                    rec_noise_pred = self.unet(
                        scaled_prev_noised_latents,
                        t,
                        encoder_hidden_states=image_embeddings,
                        added_time_ids=added_time_ids.chunk(2)[1],
                        return_dict=False,
                    )[0]
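
                    # Recover a clean-sample estimate from the model output. The
                    # coefficients below match the v-prediction conversion used by
                    # EulerDiscreteScheduler: x0 = c_out * model_output + c_skip * sample,
                    # with c_out = -sigma / sqrt(sigma**2 + 1) and c_skip = 1 / (sigma**2 + 1).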
                    sigma = self.scheduler.sigmas[self.scheduler.step_index]
                    rec_prev_latents = rec_noise_pred * (-sigma / (sigma**2 + 1) ** 0.5) + (prev_noised_latents / (sigma**2 + 1))
                    loss = torch.nn.functional.mse_loss(rec_prev_latents, latents)

                    # compute grads
                    grads = torch.autograd.grad(loss, latents)[0]

                self.unet.eval()
                self.unet.requires_grad_(False)
                latents.requires_grad_(False)
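
                # Steer the noise prediction for the second clip with the gradient of
                # the reconstruction loss, scaled by rec_guidance_scale.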
                noise_pred = noise_pred - grads * rec_guidance_scale

                # compute the previous noisy sample x_t -> x_t-1
                next_latents = self.scheduler.step(noise_pred, t, next_latents).prev_sample

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    next_latents = callback_outputs.pop("latents", next_latents)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)
        next_frames = self.decode_latents(next_latents, num_frames, decode_chunk_size)
        next_frames = tensor2vid(next_frames, self.image_processor, output_type=output_type)

        frames = frames + next_frames

        self.maybe_free_model_hooks()
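
A minimal usage sketch of the patched pipeline, assuming the class above is importable from a local diffusers checkout, that the new rec_guidance_scale argument is the only API change, and that the call still returns a StableVideoDiffusionPipelineOutput; the model id, conditioning image path, and export step are illustrative, not part of this diff.

import torch
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video

pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
)
pipe.to("cuda")

image = load_image("conditioning_frame.png")  # hypothetical conditioning image
generator = torch.manual_seed(42)

# rec_guidance_scale controls the reconstruction guidance added in this diff; the
# second denoising loop conditions on the last frame of the first clip and the
# decoded follow-up frames are appended to the returned frames.
output = pipe(
    image,
    num_frames=14,
    decode_chunk_size=8,
    rec_guidance_scale=2.0,
    generator=generator,
)

# With the default "pil" output and batch size 1, `frames = frames + next_frames`
# leaves the two clips as separate entries, so stitch them before export.
clips = output.frames
export_to_video(clips[0] + clips[1], "generated.mp4", fps=7)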