Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-08 21:44:27 +08:00)

Compare commits: custom-cod ... add_svd_in (3 commits)
| Author | SHA1 | Date |
|---|---|---|
| | a25ea49ed9 | |
| | 74d7e20a5e | |
| | 18e94e121f | |
```diff
@@ -21,7 +21,7 @@ import PIL.Image
 import torch
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 
-from ...image_processor import VaeImageProcessor
+from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
 from ...schedulers import EulerDiscreteScheduler
 from ...utils import BaseOutput, logging
```
```diff
@@ -111,6 +111,9 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.mask_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+        )
 
     def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance):
         dtype = next(self.image_encoder.parameters()).dtype
```
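For orientation, here is a minimal sketch of what this `mask_processor` configuration turns a PIL mask into. The VAE scale factor of 8 and the synthetic ellipse mask are illustrative assumptions, not values taken from this branch.

```python
from PIL import Image, ImageDraw
from diffusers.image_processor import VaeImageProcessor

# same flags as the added mask_processor; vae_scale_factor=8 is assumed
# (2 ** (len(block_out_channels) - 1) for the usual SVD VAE)
mask_processor = VaeImageProcessor(
    vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)

# hypothetical mask: black = keep, white = region to regenerate
mask_pil = Image.new("L", (1024, 576), 0)
ImageDraw.Draw(mask_pil).ellipse((300, 150, 700, 450), fill=255)

mask = mask_processor.preprocess(mask_pil, height=576, width=1024)
print(mask.shape, mask.min().item(), mask.max().item())  # torch.Size([1, 1, 576, 1024]) 0.0 1.0
```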
```diff
@@ -296,10 +299,38 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
     def num_timesteps(self):
         return self._num_timesteps
+
+    def prepare_mask_latents(
+        self, mask, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+    ):
+        # resize the mask to latents shape as we concatenate the mask to the latents
+        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+        # and half precision
+        mask = torch.nn.functional.interpolate(
+            mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
+        )
+        mask = mask.to(device=device, dtype=dtype)
+
+        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+        if mask.shape[0] < batch_size:
+            if not batch_size % mask.shape[0] == 0:
+                raise ValueError(
+                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+                    " of masks that you pass is divisible by the total requested batch size."
+                )
+            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+
+        mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+
+        # aligning device to prevent device errors when concating it with the latent model input
+        return mask
+
 
     @torch.no_grad()
     def __call__(
         self,
-        image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
+        image: PipelineImageInput,
+        mask_image: PipelineImageInput = None,
         height: int = 576,
         width: int = 1024,
         num_frames: Optional[int] = None,
```
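A standalone sketch of the tensor shapes the new `prepare_mask_latents` helper produces, using a dummy mask and an assumed `vae_scale_factor` of 8; it replays the same resize → repeat-to-batch → duplicate-for-CFG steps outside the pipeline.

```python
import torch
import torch.nn.functional as F

vae_scale_factor = 8                     # assumption: 2 ** (len(block_out_channels) - 1) == 8
batch_size, height, width = 2, 576, 1024
do_classifier_free_guidance = True

mask = torch.rand(1, 1, height, width)   # dummy [B, 1, H, W] mask in [0, 1]

# resize to latent resolution before touching dtype/device
mask = F.interpolate(mask, size=(height // vae_scale_factor, width // vae_scale_factor))

# duplicate to the requested batch size (batch_size must be divisible by the mask batch)
if mask.shape[0] < batch_size:
    mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)

# double along the batch dimension for classifier-free guidance
if do_classifier_free_guidance:
    mask = torch.cat([mask] * 2)

print(mask.shape)  # torch.Size([4, 1, 72, 128])
```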
```diff
@@ -427,7 +458,10 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
 
         # 4. Encode input image using VAE
         image = self.image_processor.preprocess(image, height=height, width=width)
-        noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype)
+        image = image.to(self.device)
+        init_image = image
+
+        noise = randn_tensor(image.shape, generator=generator, device=self.device, dtype=image.dtype)
         image = image + noise_aug_strength * noise
 
         needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
```
```diff
@@ -437,6 +471,9 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
         image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, do_classifier_free_guidance)
         image_latents = image_latents.to(image_embeddings.dtype)
 
+        init_image_latents = self._encode_vae_image(init_image, device, num_videos_per_prompt, False)
+        init_image_latents = init_image_latents.to(image_embeddings.dtype)
+
         # cast back to fp16 if needed
         if needs_upcasting:
             self.vae.to(dtype=torch.float16)
```
```diff
@@ -444,6 +481,10 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
         # Repeat the image latents for each frame so we can concatenate them with the noise
         # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
         image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
+        init_image_latents = init_image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
+
+        # for inpainting
+        mask_noise = randn_tensor(init_image_latents.shape, generator=generator, device=self.device, dtype=image_latents.dtype)
 
         # 5. Get Added Time IDs
         added_time_ids = self._get_add_time_ids(
```
```diff
@@ -483,6 +524,20 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
 
         self._guidance_scale = guidance_scale
 
+        # 7. Prepare mask latent variables
+        mask_condition = self.mask_processor.preprocess(mask_image, height=height, width=width)
+
+        mask = self.prepare_mask_latents(
+            mask_condition,
+            batch_size * num_videos_per_prompt,
+            height,
+            width,
+            latents.dtype,
+            device,
+            generator,
+            do_classifier_free_guidance,
+        )
+
         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
```
```diff
@@ -512,6 +567,22 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents).prev_sample
 
+                if do_classifier_free_guidance:
+                    init_mask, _ = mask.chunk(2)
+                    init_mask = init_mask[None, :]
+                    mask_noise = mask_noise[:1]
+                else:
+                    init_mask = mask[None, :]
+
+                # let's make sure init latents correspond all to the first frame
+                if i < len(timesteps) - 1:
+                    noise_timestep = timesteps[i + 1]
+                    init_latents_proper = self.scheduler.add_noise(
+                        init_image_latents, mask_noise, torch.tensor([noise_timestep])
+                    )
+
+                    latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
                 if callback_on_step_end is not None:
                     callback_kwargs = {}
                     for k in callback_on_step_end_tensor_inputs:
```
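The per-step update added here follows the familiar latent-inpainting pattern: outside the mask, the latents are overwritten with the initial image latents re-noised to the next timestep, so only the masked region is actually regenerated. A toy illustration of just that blend, with dummy shapes not taken from the pipeline:

```python
import torch

# dummy stand-ins for the tensors in the loop above: [batch, frames, channels, h, w]
latents = torch.randn(1, 14, 4, 72, 128)              # current denoised latents
init_latents_proper = torch.randn(1, 14, 4, 72, 128)  # init image latents re-noised to the next timestep

init_mask = torch.zeros(1, 1, 72, 128)                # 1 = regenerate, 0 = keep the init content
init_mask[..., 20:50, 40:90] = 1.0

# keep the unmasked region from the (re-noised) init latents, generate inside the mask
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
```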
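For context, a hypothetical invocation of the pipeline as modified on this branch. The checkpoint id, asset file names, and the white-means-regenerate mask convention are assumptions for illustration; `mask_image` only exists with these changes applied.

```python
import torch
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import export_to_video, load_image

# assumes the StableVideoDiffusionPipeline from this branch is the one installed
pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
)
pipe.to("cuda")

image = load_image("input.png").resize((1024, 576))   # placeholder asset
mask = load_image("mask.png").resize((1024, 576))     # placeholder mask: white = regenerate (assumed)

frames = pipe(image, mask_image=mask, decode_chunk_size=4).frames[0]
export_to_video(frames, "svd_inpaint.mp4", fps=7)
```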