Mirror of https://github.com/huggingface/diffusers.git
Compare commits: pipeline-s...rec-guided (2 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | a44e825b18 |  |
|  | 9eceb07f8d |  |
```diff
@@ -309,6 +309,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
         num_inference_steps: int = 25,
         min_guidance_scale: float = 1.0,
         max_guidance_scale: float = 3.0,
+        rec_guidance_scale: float = 1.0,
         fps: int = 7,
         motion_bucket_id: int = 127,
         noise_aug_strength: int = 0.02,
```
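For orientation, here is a minimal usage sketch of the extended `__call__` signature. It assumes this branch of the pipeline is installed; the checkpoint name, the input image path, and the chosen `rec_guidance_scale` value are illustrative only and are not taken from the commits above.

```python
import torch
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import export_to_video, load_image

# Standard SVD setup; only the rec_guidance_scale argument is new in this branch.
pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
)
pipe.to("cuda")

image = load_image("input.png")  # hypothetical conditioning image
generator = torch.manual_seed(42)

result = pipe(
    image,
    num_inference_steps=25,
    min_guidance_scale=1.0,
    max_guidance_scale=3.0,
    rec_guidance_scale=1.0,  # new argument added by this comparison
    fps=7,
    motion_bucket_id=127,
    noise_aug_strength=0.02,
    decode_chunk_size=8,
    generator=generator,
)
export_to_video(result.frames[0], "generated.mp4", fps=7)
```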
```diff
@@ -535,6 +536,136 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
         else:
             frames = latents
 
+        # 3. Encode input image
+        next_image = frames[0][-1]
+        next_image_embeddings = self._encode_image(next_image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
+
+        # 4. Encode input image using VAE
+        next_image = self.image_processor.preprocess(next_image, height=height, width=width)
+        noise = randn_tensor(next_image.shape, generator=generator, device=next_image.device, dtype=next_image.dtype)
+        next_image = next_image + noise_aug_strength * noise
+
+        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+        if needs_upcasting:
+            self.vae.to(dtype=torch.float32)
+
+        next_image_latents = self._encode_vae_image(next_image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
+        next_image_latents = next_image_latents.to(next_image_embeddings.dtype)
+
+        # cast back to fp16 if needed
+        if needs_upcasting:
+            self.vae.to(dtype=torch.float16)
+
+        # Repeat the image latents for each frame so we can concatenate them with the noise
+        # image_latents [batch, channels, height, width] -> [batch, num_frames, channels, height, width]
+        next_image_latents = next_image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
+
+        # 4. Prepare timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self.scheduler.timesteps
+
+        # 5. Prepare latent variables
+        num_channels_latents = self.unet.config.in_channels
+        next_latents = self.prepare_latents(
+            batch_size * num_videos_per_prompt,
+            num_frames,
+            num_channels_latents,
+            height,
+            width,
+            next_image_embeddings.dtype,
+            device,
+            generator,
+            None,
+        )
+
+        image_embeddings = image_embeddings.chunk(2)[1]
+        image_latents = image_latents.chunk(2)[1]
+
+        # 8. Denoising loop
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        self._num_timesteps = len(timesteps)
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                next_latent_model_input = torch.cat([next_latents] * 2) if self.do_classifier_free_guidance else next_latents
+                next_latent_model_input = self.scheduler.scale_model_input(next_latent_model_input, t)
+
+                # Concatenate image_latents over channels dimension
+                next_latent_model_input = torch.cat([next_latent_model_input, next_image_latents], dim=2)
+
+                # predict the noise residual
+                noise_pred = self.unet(
+                    next_latent_model_input,
+                    t,
+                    encoder_hidden_states=next_image_embeddings,
+                    added_time_ids=added_time_ids,
+                    return_dict=False,
+                )[0]
+
+                # perform guidance
+                if self.do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
+
+                with torch.enable_grad():
+                    self.unet.train()
+                    self.unet.enable_gradient_checkpointing()
+                    self.unet.requires_grad_(True)
+                    latents.requires_grad_(True)
+
+                    # Add noise to the latents
+                    noise = torch.rand_like(latents.flatten(0, 1))
+                    timestep = torch.ones(noise.shape[0]).to(noise.device) * t
+                    prev_noised_latents = self.scheduler.add_noise(latents.flatten(0, 1), noise, timestep)
+                    # [batch*frames, channels, height, width] -> [batch, frames, channels, height, width]
+                    prev_noised_latents = prev_noised_latents.reshape(-1, num_frames, *prev_noised_latents.shape[1:])
+                    scaled_prev_noised_latents = self.scheduler.scale_model_input(prev_noised_latents, t)
+                    scaled_prev_noised_latents = torch.cat([scaled_prev_noised_latents, image_latents], dim=2)
+
+                    rec_noise_pred = self.unet(
+                        scaled_prev_noised_latents,
+                        t,
+                        encoder_hidden_states=image_embeddings,
+                        added_time_ids=added_time_ids.chunk(2)[1],
+                        return_dict=False,
+                    )[0]
+
+                    sigma = self.scheduler.sigmas[self.scheduler.step_index]
+                    rec_prev_latents = rec_noise_pred * (-sigma / (sigma**2 + 1) ** 0.5) + (prev_noised_latents / (sigma**2 + 1))
+
+                    loss = torch.nn.functional.mse_loss(rec_prev_latents, latents)
+                    # compute grads
+                    grads = torch.autograd.grad(loss, latents)[0]
+
+                self.unet.eval()
+                self.unet.requires_grad_(False)
+                latents.requires_grad_(False)
+
+                noise_pred = noise_pred - grads * rec_guidance_scale
+
+                # compute the previous noisy sample x_t -> x_t-1
+                next_latents = self.scheduler.step(noise_pred, t, next_latents).prev_sample
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    next_latents = callback_outputs.pop("latents", next_latents)
+
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+
+        # cast back to fp16 if needed
+        if needs_upcasting:
+            self.vae.to(dtype=torch.float16)
+        next_frames = self.decode_latents(next_latents, num_frames, decode_chunk_size)
+        next_frames = tensor2vid(next_frames, self.image_processor, output_type=output_type)
+
+        frames = frames + next_frames
+
         self.maybe_free_model_hooks()
 
         if not return_dict:
```
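For context on the reconstruction-guidance step inside the new loop: the `rec_prev_latents` line converts the UNet output on the re-noised previous-clip latents back into an estimate of the clean latents, and the gradient of the reconstruction error then steers the next clip's noise prediction. Assuming the scheduler follows the v-prediction parameterization (the `-sigma / (sigma**2 + 1) ** 0.5` and `1 / (sigma**2 + 1)` coefficients match the `pred_original_sample` computation in diffusers' EulerDiscreteScheduler for that prediction type), the two steps can be written as:

```latex
% Denoised estimate of the re-noised previous-clip latents at noise level \sigma:
%   rec_prev_latents = rec_noise_pred * (-sigma / (sigma**2 + 1) ** 0.5)
%                      + prev_noised_latents / (sigma**2 + 1)
\hat{x}_0 = \frac{-\sigma}{\sqrt{\sigma^2 + 1}} \, v_\theta(x_\sigma, \sigma)
          + \frac{1}{\sigma^2 + 1} \, x_\sigma

% Guidance update applied to the next clip's prediction:
%   noise_pred = noise_pred - grads * rec_guidance_scale
\tilde{\epsilon} = \epsilon_\theta
  - s_{\mathrm{rec}} \, \nabla_{x_{\mathrm{prev}}}
    \bigl\lVert \hat{x}_0 - x_{\mathrm{prev}} \bigr\rVert_2^2
```

In effect this is a reconstruction-guided continuation: the already generated clip (`latents` / `frames`) serves as the target, and `rec_guidance_scale` controls how strongly the new clip's denoising trajectory is pulled toward staying consistent with it.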