Compare commits

...

3 Commits

Author                SHA1        Message                     Date
Patrick von Platen    a25ea49ed9  [SVD] Fix guidance scale    2023-11-30 16:59:30 +00:00
Patrick von Platen    74d7e20a5e  [SVD] Fix guidance scale    2023-11-30 16:58:55 +00:00
Patrick von Platen    18e94e121f  add text encoder            2023-11-30 16:08:47 +00:00


@@ -21,7 +21,7 @@ import PIL.Image
 import torch
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 
-from ...image_processor import VaeImageProcessor
+from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
 from ...schedulers import EulerDiscreteScheduler
 from ...utils import BaseOutput, logging
@@ -111,6 +111,9 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.mask_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+        )
 
     def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance):
         dtype = next(self.image_encoder.parameters()).dtype
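The added `mask_processor` is a `VaeImageProcessor` configured to skip normalization, binarize the mask, and convert it to grayscale. A minimal standalone sketch of what that preprocessing does to a user-supplied PIL mask (the 8x scale factor and the toy mask are illustrative assumptions, not values taken from the diff):

```python
# Sketch of the added mask_processor's preprocessing, outside the pipeline.
# Assumes diffusers and Pillow are installed; vae_scale_factor=8 mirrors SVD's VAE config.
import PIL.Image
from diffusers.image_processor import VaeImageProcessor

mask_processor = VaeImageProcessor(
    vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)

# Toy 1024x576 mask: white rectangle = region to regenerate, black = keep.
mask = PIL.Image.new("L", (1024, 576), 0)
mask.paste(255, (256, 128, 512, 384))

mask_tensor = mask_processor.preprocess(mask, height=576, width=1024)
print(mask_tensor.shape)    # torch.Size([1, 1, 576, 1024])
print(mask_tensor.unique()) # tensor([0., 1.]) -- binarized, not rescaled to [-1, 1]
```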
@@ -296,10 +299,38 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
     def num_timesteps(self):
         return self._num_timesteps
 
+    def prepare_mask_latents(
+        self, mask, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+    ):
+        # resize the mask to latents shape as we concatenate the mask to the latents
+        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+        # and half precision
+        mask = torch.nn.functional.interpolate(
+            mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
+        )
+        mask = mask.to(device=device, dtype=dtype)
+
+        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+        if mask.shape[0] < batch_size:
+            if not batch_size % mask.shape[0] == 0:
+                raise ValueError(
+                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+                    " of masks that you pass is divisible by the total requested batch size."
+                )
+            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+
+        mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+
+        # aligning device to prevent device errors when concating it with the latent model input
+        return mask
+
     @torch.no_grad()
     def __call__(
         self,
-        image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
+        image: PipelineImageInput,
+        mask_image: PipelineImageInput = None,
         height: int = 576,
         width: int = 1024,
         num_frames: Optional[int] = None,
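The new `prepare_mask_latents` helper downsamples the pixel-space mask to the latent grid, tiles it to the requested batch size, and duplicates it for classifier-free guidance. A self-contained sketch of that resize-and-duplicate logic (function name, shapes, and the scale factor are illustrative assumptions, not pipeline internals):

```python
# Standalone sketch of the resize + CFG-duplication steps in prepare_mask_latents.
import torch
import torch.nn.functional as F

def prepare_mask_latents_sketch(mask, batch_size, height, width, vae_scale_factor, dtype, device, do_cfg):
    # downsample the pixel-space mask to the latent resolution
    mask = F.interpolate(mask, size=(height // vae_scale_factor, width // vae_scale_factor))
    mask = mask.to(device=device, dtype=dtype)

    # tile the mask so there is one copy per requested sample
    if mask.shape[0] < batch_size:
        if batch_size % mask.shape[0] != 0:
            raise ValueError("batch_size must be divisible by the number of masks")
        mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)

    # duplicate for the unconditional / conditional branches of classifier-free guidance
    return torch.cat([mask] * 2) if do_cfg else mask

mask = torch.zeros(1, 1, 576, 1024)
mask[..., 128:384, 256:512] = 1.0
latent_mask = prepare_mask_latents_sketch(
    mask, batch_size=1, height=576, width=1024,
    vae_scale_factor=8, dtype=torch.float16, device="cpu", do_cfg=True,
)
print(latent_mask.shape)  # torch.Size([2, 1, 72, 128])
```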
@@ -427,7 +458,10 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
         # 4. Encode input image using VAE
         image = self.image_processor.preprocess(image, height=height, width=width)
-        noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype)
+        image = image.to(self.device)
+        init_image = image
+        noise = randn_tensor(image.shape, generator=generator, device=self.device, dtype=image.dtype)
         image = image + noise_aug_strength * noise
 
         needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
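This hunk keeps an un-augmented copy of the conditioning frame (`init_image`) before the usual noise augmentation is applied. A short sketch of the augmentation step itself (the tensor shape and strength value are illustrative assumptions):

```python
# Sketch of the noise-augmentation applied to the conditioning image before VAE encoding.
import torch
from diffusers.utils.torch_utils import randn_tensor

generator = torch.Generator().manual_seed(0)
image = torch.rand(1, 3, 576, 1024) * 2 - 1   # preprocessed image in [-1, 1]
init_image = image                            # un-augmented copy kept for inpainting

noise_aug_strength = 0.02
noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype)
image = image + noise_aug_strength * noise    # only the VAE-conditioning copy is perturbed
```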
@@ -437,6 +471,9 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
         image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, do_classifier_free_guidance)
         image_latents = image_latents.to(image_embeddings.dtype)
+        init_image_latents = self._encode_vae_image(init_image, device, num_videos_per_prompt, False)
+        init_image_latents = init_image_latents.to(image_embeddings.dtype)
 
         # cast back to fp16 if needed
         if needs_upcasting:
             self.vae.to(dtype=torch.float16)
@@ -444,6 +481,10 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
         # Repeat the image latents for each frame so we can concatenate them with the noise
         # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
         image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
+        init_image_latents = init_image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
+
+        # for inpainting
+        mask_noise = randn_tensor(init_image_latents.shape, generator=generator, device=self.device, dtype=image_latents.dtype)
 
         # 5. Get Added Time IDs
         added_time_ids = self._get_add_time_ids(
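Here both the image latents and the new `init_image_latents` are repeated along a new frame axis, and `mask_noise` is drawn once with the same 5D shape for the later re-noising step. A shape-level sketch (the latent channel count and resolution are illustrative assumptions):

```python
# Shape-level sketch of the per-frame repetition of the image latents.
import torch

batch, channels, latent_h, latent_w, num_frames = 1, 4, 72, 128, 25

image_latents = torch.randn(batch, channels, latent_h, latent_w)

# [batch, channels, h, w] -> [batch, num_frames, channels, h, w]
image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
print(image_latents.shape)  # torch.Size([1, 25, 4, 72, 128])

# noise drawn once with the same 5D shape, later used to re-noise the init latents
mask_noise = torch.randn_like(image_latents)
```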
@@ -483,6 +524,20 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
         self._guidance_scale = guidance_scale
 
+        # 7. Prepare mask latent variables
+        mask_condition = self.mask_processor.preprocess(mask_image, height=height, width=width)
+        mask = self.prepare_mask_latents(
+            mask_condition,
+            batch_size * num_videos_per_prompt,
+            height,
+            width,
+            latents.dtype,
+            device,
+            generator,
+            do_classifier_free_guidance,
+        )
+
         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
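With `mask_image` added to `__call__`, invoking the patched pipeline from this branch would look roughly like the sketch below. The checkpoint id, file names, and fps are placeholders, and `StableVideoDiffusionPipeline` here means the modified class from this branch, not a released diffusers version:

```python
# Hypothetical usage of the patched pipeline; paths and settings are placeholders.
import torch
from diffusers import StableVideoDiffusionPipeline  # the class as modified in this branch
from diffusers.utils import export_to_video, load_image

pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
).to("cuda")

image = load_image("input_frame.png")              # conditioning frame
mask_image = load_image("region_to_animate.png")   # white = regenerate, black = keep

frames = pipe(image=image, mask_image=mask_image, height=576, width=1024).frames[0]
export_to_video(frames, "generated.mp4", fps=7)
```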
@@ -512,6 +567,22 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents).prev_sample
 
+                if do_classifier_free_guidance:
+                    init_mask, _ = mask.chunk(2)
+                    init_mask = init_mask[None, :]
+                    mask_noise = mask_noise[:1]
+                else:
+                    init_mask = mask[None, :]
+
+                # let's make sure init latents correspond all to the first frame
+                if i < len(timesteps) - 1:
+                    noise_timestep = timesteps[i + 1]
+                    init_latents_proper = self.scheduler.add_noise(
+                        init_image_latents, mask_noise, torch.tensor([noise_timestep])
+                    )
+
+                    latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
                 if callback_on_step_end is not None:
                     callback_kwargs = {}
                     for k in callback_on_step_end_tensor_inputs:
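The loop mirrors the image-inpainting pipelines: at each step the unmasked region of the latents is overwritten with the first-frame latents re-noised to the next timestep, so only the masked region is denoised freely. A minimal numerical sketch of that blend (shapes are illustrative, and the "re-noised" tensor stands in for `scheduler.add_noise`):

```python
# Sketch of the masked blend in the new denoising-loop branch:
# keep re-noised init latents where mask == 0, keep the model's latents where mask == 1.
import torch

batch, num_frames, channels, h, w = 1, 25, 4, 72, 128

latents = torch.randn(batch, num_frames, channels, h, w)               # current denoised sample
init_latents_proper = torch.randn(batch, num_frames, channels, h, w)   # stand-in for init latents + noise at t_{i+1}

init_mask = torch.zeros(1, batch, 1, h, w)   # mask[None, :] as in the diff; 1 = region to regenerate
init_mask[..., 20:50, 40:90] = 1.0

# broadcasts against the [batch, frames, channels, h, w] latents because the batch here is 1
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
```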