Compare commits

...

2 Commits

Author      SHA1        Message                   Date
sayakpaul   f57d6ace0a  remove guidance_scale     2024-01-06 14:58:06 +05:30
sayakpaul   9af7cce713  print important shapes    2024-01-06 14:54:23 +05:30


@@ -422,7 +422,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
         # 3. Encode input image
         image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
+        print(f"image embeddings: {image_embeddings.shape}")
         # NOTE: Stable Diffusion Video was conditioned on fps - 1, which
         # is why it is reduced here.
         # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
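
For context on the first printed shape: with classifier-free guidance enabled, the pipeline stacks an unconditional embedding on top of the conditional one, so the batch dimension doubles. A minimal sketch of that doubling, assuming a hypothetical 1024-dim CLIP image embedding and a zeroed negative embedding:

import torch

cond = torch.randn(1, 1, 1024)                # hypothetical CLIP image embedding
uncond = torch.zeros_like(cond)               # unconditional counterpart for CFG
image_embeddings = torch.cat([uncond, cond])  # batch doubles: shape (2, 1, 1024)
print(f"image embeddings: {image_embeddings.shape}")
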
@@ -439,6 +439,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
         image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
         image_latents = image_latents.to(image_embeddings.dtype)
+        print(f"Image latents: {image_latents.shape}")
         # cast back to fp16 if needed
         if needs_upcasting:
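
The needs_upcasting branch exists because the fp16 VAE can overflow, so encoding is done in fp32 and the weights are cast back afterwards. A rough sketch of the usual pattern, assuming a diffusers-style VAE whose config carries a force_upcast flag:

import torch
from diffusers import AutoencoderKL  # assumed: any diffusers VAE with a force_upcast config flag

vae = AutoencoderKL().to(dtype=torch.float16)
needs_upcasting = vae.dtype == torch.float16 and vae.config.force_upcast
if needs_upcasting:
    vae.to(dtype=torch.float32)  # encode in fp32 for numerical stability
# ... run vae.encode(...) here ...
if needs_upcasting:
    vae.to(dtype=torch.float16)  # cast back to fp16 afterwards
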
@@ -447,6 +448,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
         # Repeat the image latents for each frame so we can concatenate them with the noise
         # image_latents [batch, channels, height, width] -> [batch, num_frames, channels, height, width]
         image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
+        print(f"Image latents after reshape: {image_latents.shape}")
         # 5. Get Added Time IDs
         added_time_ids = self._get_add_time_ids(
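
The reshape being printed is plain shape arithmetic: a frame axis is inserted and tiled num_frames times. A toy example with made-up sizes (batch 2, 25 frames, 4 latent channels):

import torch

image_latents = torch.randn(2, 4, 72, 128)  # (batch, channels, height, width)
image_latents = image_latents.unsqueeze(1).repeat(1, 25, 1, 1, 1)
print(image_latents.shape)  # torch.Size([2, 25, 4, 72, 128])
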
@@ -459,6 +461,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
             self.do_classifier_free_guidance,
         )
         added_time_ids = added_time_ids.to(device)
+        print(f"added_time_ids: {added_time_ids.shape}")
         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
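
added_time_ids is a small tensor: it packs the scalar video conditions (fps, motion bucket id, noise augmentation strength) into one row per sample, doubled under classifier-free guidance. A sketch with assumed SVD-style default values:

import torch

add_time_ids = torch.tensor([[6, 127, 0.02]])  # assumed (fps - 1, motion_bucket_id, noise_aug_strength)
add_time_ids = torch.cat([add_time_ids] * 2)   # doubled for classifier-free guidance
print(add_time_ids.shape)  # torch.Size([2, 3])
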
@@ -477,6 +480,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
             generator,
             latents,
         )
+        print(f"latents: {latents.shape}")
         # 7. Prepare guidance scale
         guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
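
Unlike image pipelines, SVD uses a per-frame guidance scale that ramps linearly from min_guidance_scale to max_guidance_scale across the frame axis. A standalone illustration with assumed values (min 1.0, max 3.0, 25 frames):

import torch

guidance_scale = torch.linspace(1.0, 3.0, 25).unsqueeze(0)  # shape (1, 25)
print(guidance_scale[0, :3], guidance_scale[0, -1])  # early frames are guided less than late ones
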
@@ -494,9 +498,11 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+                print(f"initial latent_model_input: {latent_model_input.shape}")
                 # Concatenate image_latents over channels dimension
                 latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
+                print(f"concatenated latent_model_input: {latent_model_input.shape}")
                 # predict the noise residual
                 noise_pred = self.unet(
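
The two prints inside the loop bracket the channel-wise conditioning: the noisy latents are concatenated with the repeated image latents along dim=2, so the UNet sees twice the latent channels. Continuing the toy shapes from above:

import torch

latents = torch.randn(2, 25, 4, 72, 128)        # CFG-doubled noisy latents
image_latents = torch.randn(2, 25, 4, 72, 128)  # repeated image latents
latent_model_input = torch.cat([latents, image_latents], dim=2)
print(latent_model_input.shape)  # torch.Size([2, 25, 8, 72, 128]): 4 + 4 channels
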
@@ -506,6 +512,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
                     added_time_ids=added_time_ids,
                     return_dict=False,
                 )[0]
+                print(f"noise_pred: {noise_pred.shape}")
                 # perform guidance
                 if self.do_classifier_free_guidance:
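
The diff cuts off right before the guidance step. For reference, the standard classifier-free-guidance combination that follows this branch is a split of the doubled batch plus a blend with the per-frame scale; a self-contained sketch (not copied from this diff), with the toy shapes used above:

import torch

noise_pred = torch.randn(2, 25, 4, 72, 128)                         # (uncond, cond) stacked on dim 0
guidance_scale = torch.linspace(1.0, 3.0, 25).view(1, 25, 1, 1, 1)  # per-frame scale, broadcastable
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
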