Compare commits

...

2 Commits

Author SHA1 Message Date
Patrick von Platen
434dab4a2f make style 2023-12-18 13:15:00 +00:00
Patrick von Platen
dbcbfb3118 [SVD] Fix guidance scale 2023-11-30 16:17:53 +00:00

View File

@@ -290,7 +290,9 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
    """Report whether classifier-free guidance should be applied.

    Reads ``self.guidance_scale``, which may be a plain number or an
    object exposing ``.max()`` (e.g. a tensor of per-frame scales).

    Returns:
        A truthy value when the (largest) guidance scale is greater than 1,
        and a falsy value otherwise.
    """
    scale = self.guidance_scale
    if isinstance(scale, (int, float)):
        # Compare against 1 instead of returning the raw scale: any non-zero
        # number is truthy, so scales <= 1 would otherwise enable guidance.
        return scale > 1
    # Non-scalar scale: guidance is needed if any entry exceeds 1.
    return scale.max() > 1
@property
def num_timesteps(self):
@@ -415,10 +417,10 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = max_guidance_scale > 1.0
self._guidance_scale = max_guidance_scale
# 3. Encode input image
image_embeddings = self._encode_image(image, device, num_videos_per_prompt, do_classifier_free_guidance)
image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
# NOTE: Stable Video Diffusion was conditioned on fps - 1, which
# is why it is reduced here.
@@ -434,7 +436,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
if needs_upcasting:
self.vae.to(dtype=torch.float32)
image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, do_classifier_free_guidance)
image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
image_latents = image_latents.to(image_embeddings.dtype)
# cast back to fp16 if needed
@@ -453,7 +455,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
image_embeddings.dtype,
batch_size,
num_videos_per_prompt,
do_classifier_free_guidance,
self.do_classifier_free_guidance,
)
added_time_ids = added_time_ids.to(device)
@@ -489,7 +491,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# Concatenate image_latents over channels dimension
@@ -505,7 +507,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
)[0]
# perform guidance
if do_classifier_free_guidance:
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)