Mirror of https://github.com/huggingface/diffusers.git, synced 2025-12-06 20:44:33 +08:00
Compare commits: token-drop...rec-guided (2 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | a44e825b18 |  |
|  | 9eceb07f8d |  |
@@ -309,6 +309,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
         num_inference_steps: int = 25,
         min_guidance_scale: float = 1.0,
         max_guidance_scale: float = 3.0,
+        rec_guidance_scale: float = 1.0,
         fps: int = 7,
         motion_bucket_id: int = 127,
         noise_aug_strength: float = 0.02,
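In the stock pipeline the per-frame classifier-free guidance scale ramps linearly from `min_guidance_scale` to `max_guidance_scale` (via `torch.linspace`), whereas the new `rec_guidance_scale` is a single scalar applied uniformly in the reconstruction-guidance step added below. A minimal sketch of that distinction (guidance defaults are from the signature above; the frame count is illustrative):

```python
import torch

num_frames = 25
min_guidance_scale, max_guidance_scale = 1.0, 3.0
rec_guidance_scale = 1.0  # new: one weight for the reconstruction gradient

# stock SVD: CFG strength ramps linearly across the frame axis
cfg_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames)
print(cfg_scale[0].item(), cfg_scale[-1].item())  # 1.0 ... 3.0
```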
@@ -534,6 +535,136 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
             frames = tensor2vid(frames, self.image_processor, output_type=output_type)
         else:
             frames = latents
 
+        # 3. Encode the last generated frame; it becomes the conditioning image for the next clip
+        next_image = frames[0][-1]
+        next_image_embeddings = self._encode_image(next_image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
+
+        # 4. Encode the conditioning image using the VAE
+        next_image = self.image_processor.preprocess(next_image, height=height, width=width)
+        noise = randn_tensor(next_image.shape, generator=generator, device=next_image.device, dtype=next_image.dtype)
+        next_image = next_image + noise_aug_strength * noise
+
+        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+        if needs_upcasting:
+            self.vae.to(dtype=torch.float32)
+
+        next_image_latents = self._encode_vae_image(next_image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
+        next_image_latents = next_image_latents.to(next_image_embeddings.dtype)
+
+        # cast back to fp16 if needed
+        if needs_upcasting:
+            self.vae.to(dtype=torch.float16)
+
+        # Repeat the image latents for each frame so we can concatenate them with the noise
+        # next_image_latents [batch, channels, height, width] -> [batch, num_frames, channels, height, width]
+        next_image_latents = next_image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
+
+        # 5. Prepare timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self.scheduler.timesteps
+
+        # 6. Prepare latent variables for the next clip
+        num_channels_latents = self.unet.config.in_channels
+        next_latents = self.prepare_latents(
+            batch_size * num_videos_per_prompt,
+            num_frames,
+            num_channels_latents,
+            height,
+            width,
+            next_image_embeddings.dtype,
+            device,
+            generator,
+            None,
+        )
+
+        # keep only the conditional halves of the first clip's embeddings and latents
+        image_embeddings = image_embeddings.chunk(2)[1]
+        image_latents = image_latents.chunk(2)[1]
+
+        # 7. Denoising loop for the next clip, with reconstruction guidance
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        self._num_timesteps = len(timesteps)
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier-free guidance
+                next_latent_model_input = torch.cat([next_latents] * 2) if self.do_classifier_free_guidance else next_latents
+                next_latent_model_input = self.scheduler.scale_model_input(next_latent_model_input, t)
+
+                # Concatenate image_latents over the channels dimension
+                next_latent_model_input = torch.cat([next_latent_model_input, next_image_latents], dim=2)
+
+                # predict the noise residual
+                noise_pred = self.unet(
+                    next_latent_model_input,
+                    t,
+                    encoder_hidden_states=next_image_embeddings,
+                    added_time_ids=added_time_ids,
+                    return_dict=False,
+                )[0]
+
+                # perform guidance
+                if self.do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
+
+                with torch.enable_grad():
+                    self.unet.train()
+                    self.unet.enable_gradient_checkpointing()
+                    self.unet.requires_grad_(True)
+                    latents.requires_grad_(True)
+
+                    # Add Gaussian noise to the first clip's latents at the current timestep
+                    # (randn_like: scheduler.add_noise expects standard normal noise)
+                    noise = torch.randn_like(latents.flatten(0, 1))
+                    timestep = torch.ones(noise.shape[0]).to(noise.device) * t
+                    prev_noised_latents = self.scheduler.add_noise(latents.flatten(0, 1), noise, timestep)
+                    # [batch*frames, channels, height, width] -> [batch, frames, channels, height, width]
+                    prev_noised_latents = prev_noised_latents.reshape(-1, num_frames, *prev_noised_latents.shape[1:])
+                    scaled_prev_noised_latents = self.scheduler.scale_model_input(prev_noised_latents, t)
+                    scaled_prev_noised_latents = torch.cat([scaled_prev_noised_latents, image_latents], dim=2)
+
+                    rec_noise_pred = self.unet(
+                        scaled_prev_noised_latents,
+                        t,
+                        encoder_hidden_states=image_embeddings,
+                        added_time_ids=added_time_ids.chunk(2)[1],
+                        return_dict=False,
+                    )[0]
+
+                    # convert the model output back to a denoised sample; these are the
+                    # coefficients EulerDiscreteScheduler uses for v-prediction
+                    sigma = self.scheduler.sigmas[self.scheduler.step_index]
+                    rec_prev_latents = rec_noise_pred * (-sigma / (sigma**2 + 1) ** 0.5) + (prev_noised_latents / (sigma**2 + 1))
+
+                    loss = torch.nn.functional.mse_loss(rec_prev_latents, latents)
+                    # compute grads of the reconstruction loss w.r.t. the first clip's latents
+                    grads = torch.autograd.grad(loss, latents)[0]
+
+                self.unet.eval()
+                self.unet.requires_grad_(False)
+                latents.requires_grad_(False)
+
+                # apply reconstruction guidance
+                noise_pred = noise_pred - grads * rec_guidance_scale
+
+                # compute the previous noisy sample x_t -> x_t-1
+                next_latents = self.scheduler.step(noise_pred, t, next_latents).prev_sample
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    next_latents = callback_outputs.pop("latents", next_latents)
+
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+
+        # cast back to fp16 if needed
+        if needs_upcasting:
+            self.vae.to(dtype=torch.float16)
+        next_frames = self.decode_latents(next_latents, num_frames, decode_chunk_size)
+        next_frames = tensor2vid(next_frames, self.image_processor, output_type=output_type)
+
+        # combine the first clip's output with the next clip's (plain concatenation for list outputs)
+        frames = frames + next_frames
 
         self.maybe_free_model_hooks()
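
Read as equations, the loop above adds a reconstruction-guidance term on top of classifier-free guidance: the first clip's clean latents z_prev are re-noised to the current noise level sigma, the UNet output is mapped back to a denoised estimate with the same coefficients EulerDiscreteScheduler uses for v-prediction, and the gradient of the reconstruction error is subtracted from the guided noise prediction. A sketch in that notation, writing s_rec for rec_guidance_scale:

```latex
\hat{z}_{\mathrm{prev}}
  = \frac{-\sigma}{\sqrt{\sigma^{2}+1}}\,
    v_\theta\bigl(z_{\mathrm{prev}} + \sigma\epsilon,\; t\bigr)
  + \frac{z_{\mathrm{prev}} + \sigma\epsilon}{\sigma^{2}+1},
\qquad \epsilon \sim \mathcal{N}(0, I)

\mathcal{L}_{\mathrm{rec}}
  = \bigl\lVert \hat{z}_{\mathrm{prev}} - z_{\mathrm{prev}} \bigr\rVert_2^2,
\qquad
\tilde{\epsilon}
  = \epsilon_{\mathrm{cfg}} - s_{\mathrm{rec}}\, \nabla_{z_{\mathrm{prev}}} \mathcal{L}_{\mathrm{rec}}
```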
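For reference, a minimal usage sketch of the patched pipeline. This assumes a checkout of the rec-guided branch; the model id, image path, and output handling are placeholders, and only `rec_guidance_scale` is new relative to the stock pipeline:

```python
import torch
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video

pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt",  # placeholder model id
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe.to("cuda")

image = load_image("conditioning_frame.png").resize((1024, 576))

# rec_guidance_scale weights the reconstruction-guidance gradient that the
# patched denoising loop subtracts from the noise prediction
result = pipe(
    image,
    num_inference_steps=25,
    min_guidance_scale=1.0,
    max_guidance_scale=3.0,
    rec_guidance_scale=1.0,
    decode_chunk_size=8,
    generator=torch.manual_seed(0),
)

# with this patch the output holds the first clip plus its continuation
for clip_index, clip in enumerate(result.frames):
    export_to_video(clip, f"clip_{clip_index}.mp4", fps=7)
```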