@@ -17,9 +17,6 @@ from typing import Callable, Dict, List, Optional, Union
 import numpy as np
 import PIL.Image
 import torch
-import torchvision
-import torchvision.transforms
-import torchvision.transforms.functional
 from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
@@ -54,11 +51,13 @@ else:
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-def _maybe_pad_video(video: torch.Tensor, num_frames: int):
+def _maybe_pad_or_trim_video(video: torch.Tensor, num_frames: int):
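+    # Repeat the last frame to pad, or trim trailing frames, so the clip spans exactly `num_frames` frames.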
     n_pad_frames = num_frames - video.shape[2]
     if n_pad_frames > 0:
         last_frame = video[:, :, -1:, :, :]
         video = torch.cat((video, last_frame.repeat(1, 1, n_pad_frames, 1, 1)), dim=2)
+    elif num_frames < video.shape[2]:
+        video = video[:, :, :num_frames, :, :]
     return video


@@ -134,8 +133,8 @@ EXAMPLE_DOC_STRING = """
         >>> controls = [Image.fromarray(x.numpy()) for x in controls.permute(1, 2, 3, 0)]
         >>> export_to_video(controls, "edge_controlled_video_edge.mp4", fps=30)

         >>> # Transfer inference with controls.
         >>> video = pipe(
+        ...     video=input_video[:num_frames],
         ...     controls=controls,
         ...     controls_conditioning_scale=1.0,
         ...     prompt=prompt,
@@ -149,7 +148,7 @@ EXAMPLE_DOC_STRING = """

 class Cosmos2_5_TransferPipeline(DiffusionPipeline):
     r"""
-    Pipeline for Cosmos Transfer2.5 base model.
+    Pipeline for Cosmos Transfer2.5, supporting auto-regressive inference.

     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
     implemented for all pipelines (downloading, saving, running on a particular device, etc.).
@@ -166,12 +165,14 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
             A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
         vae ([`AutoencoderKLWan`]):
             Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+        controlnet ([`CosmosControlNetModel`]):
+            ControlNet used to condition generation on control inputs.
     """

     model_cpu_offload_seq = "text_encoder->transformer->controlnet->vae"
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
     # We mark safety_checker as optional here to get around some test failures, but it is not really optional
-    _optional_components = ["safety_checker", "controlnet"]
+    _optional_components = ["safety_checker"]
     _exclude_from_cpu_offload = ["safety_checker"]

     def __init__(
@@ -181,8 +182,8 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
         transformer: CosmosTransformer3DModel,
         vae: AutoencoderKLWan,
         scheduler: UniPCMultistepScheduler,
-        controlnet: Optional[CosmosControlNetModel],
-        safety_checker: CosmosSafetyChecker = None,
+        controlnet: CosmosControlNetModel,
+        safety_checker: Optional[CosmosSafetyChecker] = None,
     ):
         super().__init__()

@@ -384,10 +385,11 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
         num_frames_in: int = 93,
         num_frames_out: int = 93,
         do_classifier_free_guidance: bool = True,
-        dtype: torch.dtype | None = None,
-        device: torch.device | None = None,
-        generator: torch.Generator | list[torch.Generator] | None = None,
-        latents: torch.Tensor | None = None,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.Tensor] = None,
+        num_cond_latent_frames: int = 0,
     ) -> torch.Tensor:
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
@@ -402,10 +404,14 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
         W = width // self.vae_scale_factor_spatial
         shape = (B, C, T, H, W)

-        if num_frames_in == 0:
-            if latents is None:
-                latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        if latents is not None:
+            if latents.shape[1:] != shape[1:]:
+                raise ValueError(f"Unexpected `latents` shape, got {latents.shape}, expected {shape}.")
+            latents = latents.to(device=device, dtype=dtype)
+        else:
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

+        if num_frames_in == 0:
             cond_mask = torch.zeros((B, 1, T, H, W), dtype=latents.dtype, device=latents.device)
             cond_indicator = torch.zeros((B, 1, T, 1, 1), dtype=latents.dtype, device=latents.device)
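+            # With no conditioning frames, the mask and indicator stay all-zero (pure generation).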

@@ -435,16 +441,12 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
         latents_std = self.latents_std.to(device=device, dtype=dtype)
         cond_latents = (cond_latents - latents_mean) / latents_std

-        if latents is None:
-            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        else:
-            latents = latents.to(device=device, dtype=dtype)
-
         padding_shape = (B, 1, T, H, W)
         ones_padding = latents.new_ones(padding_shape)
         zeros_padding = latents.new_zeros(padding_shape)

-        cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
+        cond_indicator = latents.new_zeros(B, 1, latents.size(2), 1, 1)
+        cond_indicator[:, :, 0:num_cond_latent_frames, :, :] = 1.0
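+        # Mark the first `num_cond_latent_frames` latent frames as conditioning; the rest are generated.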
         cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding

         return (
@@ -454,34 +456,7 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
             cond_indicator,
         )

-    def _encode_controls(
-        self,
-        controls: Optional[torch.Tensor],
-        height: int,
-        width: int,
-        num_frames: int,
-        dtype: torch.dtype,
-        device: torch.device,
-        generator: torch.Generator | list[torch.Generator] | None,
-    ) -> Optional[torch.Tensor]:
-        if controls is None:
-            return None
-
-        control_video = self.video_processor.preprocess_video(controls, height, width)
-        control_video = _maybe_pad_video(control_video, num_frames)
-
-        control_video = control_video.to(device=device, dtype=self.vae.dtype)
-        control_latents = [
-            retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator=generator) for vid in control_video
-        ]
-        control_latents = torch.cat(control_latents, dim=0).to(dtype)
-
-        latents_mean = self.latents_mean.to(device=device, dtype=dtype)
-        latents_std = self.latents_std.to(device=device, dtype=dtype)
-        control_latents = (control_latents - latents_mean) / latents_std
-        return control_latents
-
-    # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs
+    # Modified from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs
     def check_inputs(
         self,
         prompt,
@@ -489,9 +464,25 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
         width,
         prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
+        num_ar_conditional_frames=None,
+        num_ar_latent_conditional_frames=None,
+        num_frames_per_chunk=None,
+        num_frames=None,
+        conditional_frame_timestep=0.1,
     ):
-        if height % 16 != 0 or width % 16 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+        if width <= 0 or height <= 0 or height % 16 != 0 or width % 16 != 0:
+            raise ValueError(
+                f"`height` and `width` have to be positive and divisible by 16 but are {height} and {width}."
+            )
+
+        if num_frames is not None and num_frames <= 0:
+            raise ValueError(f"`num_frames` has to be a positive integer when provided but is {num_frames}.")
+
+        if conditional_frame_timestep < 0 or conditional_frame_timestep > 1:
+            raise ValueError(
+                "`conditional_frame_timestep` has to be a float in the [0, 1] interval but is "
+                f"{conditional_frame_timestep}."
+            )

         if callback_on_step_end_tensor_inputs is not None and not all(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -512,6 +503,46 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
         elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

+        if num_ar_latent_conditional_frames is not None and num_ar_conditional_frames is not None:
+            raise ValueError(
+                "Provide only one of `num_ar_conditional_frames` or `num_ar_latent_conditional_frames`, not both."
+            )
+        if num_ar_latent_conditional_frames is None and num_ar_conditional_frames is None:
+            raise ValueError("Provide either `num_ar_conditional_frames` or `num_ar_latent_conditional_frames`.")
+        if num_ar_latent_conditional_frames is not None and num_ar_latent_conditional_frames < 0:
+            raise ValueError("`num_ar_latent_conditional_frames` must be >= 0.")
+        if num_ar_conditional_frames is not None and num_ar_conditional_frames < 0:
+            raise ValueError("`num_ar_conditional_frames` must be >= 0.")
+
+        if num_ar_latent_conditional_frames is not None:
+            num_ar_conditional_frames = max(
+                0, (num_ar_latent_conditional_frames - 1) * self.vae_scale_factor_temporal + 1
+            )
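+            # Latent -> pixel frame count, e.g. with vae_scale_factor_temporal == 4: 1, 2, or 3 latent
+            # frames cover 1, 5, or 9 pixel frames respectively.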

+        min_chunk_len = self.vae_scale_factor_temporal + 1
+        if num_frames_per_chunk < min_chunk_len:
+            logger.warning(f"{num_frames_per_chunk=} must be at least {min_chunk_len=}, setting to min_chunk_len")
+            num_frames_per_chunk = min_chunk_len
+
+        max_frames_by_rope = None
+        if getattr(self.transformer.config, "max_size", None) is not None:
+            max_frames_by_rope = max(
+                size // patch
+                for size, patch in zip(self.transformer.config.max_size, self.transformer.config.patch_size)
+            )
+            if num_frames_per_chunk > max_frames_by_rope:
+                raise ValueError(
+                    f"{num_frames_per_chunk=} is too large for RoPE setting ({max_frames_by_rope=}). "
+                    "Please reduce `num_frames_per_chunk`."
+                )
+
+        if num_ar_conditional_frames >= num_frames_per_chunk:
+            raise ValueError(
+                f"{num_ar_conditional_frames=} must be smaller than {num_frames_per_chunk=} for chunked generation."
+            )
+
+        return num_frames_per_chunk
+
     @property
     def guidance_scale(self):
         return self._guidance_scale
@@ -536,23 +567,22 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
         image: PipelineImageInput | None = None,
         video: List[PipelineImageInput] | None = None,
+        controls: PipelineImageInput | List[PipelineImageInput],
+        controls_conditioning_scale: Union[float, List[float]] = 1.0,
         prompt: Union[str, List[str]] | None = None,
         negative_prompt: Union[str, List[str]] = DEFAULT_NEGATIVE_PROMPT,
         height: int = 704,
-        width: int | None = None,
-        num_frames: int = 93,
+        width: Optional[int] = None,
+        num_frames: Optional[int] = None,
+        num_frames_per_chunk: int = 93,
         num_inference_steps: int = 36,
         guidance_scale: float = 3.0,
-        num_videos_per_prompt: Optional[int] = 1,
-        generator: torch.Generator | list[torch.Generator] | None = None,
-        latents: torch.Tensor | None = None,
-        controls: Optional[PipelineImageInput | List[PipelineImageInput]] = None,
-        controls_conditioning_scale: float | list[float] = 1.0,
-        prompt_embeds: torch.Tensor | None = None,
-        negative_prompt_embeds: torch.Tensor | None = None,
-        output_type: str = "pil",
+        num_videos_per_prompt: int = 1,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback_on_step_end: Optional[
             Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
@@ -560,24 +590,26 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
         conditional_frame_timestep: float = 0.1,
+        num_ar_conditional_frames: Optional[int] = 1,
+        num_ar_latent_conditional_frames: Optional[int] = None,
     ):
         r"""
         The call function to the pipeline for generation. Supports three modes:
+        In all modes, `controls` drive the conditioning through ControlNet. Controls are assumed to be
+        pre-processed, e.g. edge maps are pre-computed.

         - **Text2World**: `image=None`, `video=None`, `prompt` provided. Generates a world clip.
         - **Image2World**: `image` provided, `video=None`, `prompt` provided. Conditions on a single frame.
         - **Video2World**: `video` provided, `image=None`, `prompt` provided. Conditions on an input clip.
+        Setting `num_frames` restricts the total number of output frames; if it is not provided or is `None`
+        (the default), the number of output frames matches the input `controls`.

-        Set `num_frames=93` (default) to produce a world video, or `num_frames=1` to produce a single image frame (the
-        above in "*2Image mode").
-
-        Outputs follow `output_type` (e.g., `"pil"` returns a list of `num_frames` PIL images per prompt).
+        Auto-regressive inference is supported: a sliding window of `num_frames_per_chunk` frames is used per
+        denoising loop, and the previous `num_ar_latent_conditional_frames` (or `num_ar_conditional_frames`)
+        frames are used to condition the following denoising loops.

         Args:
             image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, *optional*):
                 Optional single image for Image2World conditioning. Must be `None` when `video` is provided.
             video (`List[PIL.Image.Image]`, `np.ndarray`, `torch.Tensor`, *optional*):
                 Optional input video for Video2World conditioning. Must be `None` when `image` is provided.
+            controls (`PipelineImageInput`, `List[PipelineImageInput]`):
+                Control image or video input used by the ControlNet.
+            controls_conditioning_scale (`float` or `List[float]`, *optional*, defaults to `1.0`):
+                The scale factor(s) for the ControlNet outputs. A single float is broadcast to all control blocks.
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide generation. Required unless `prompt_embeds` is supplied.
             height (`int`, defaults to `704`):
@@ -585,9 +617,10 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
             width (`int`, *optional*):
                 The width in pixels of the generated image. If not provided, this will be determined based on the
                 aspect ratio of the input and the provided height.
-            num_frames (`int`, defaults to `93`):
-                Number of output frames. Use `93` for world (video) generation; set to `1` to return a single frame.
-            num_inference_steps (`int`, defaults to `35`):
+            num_frames (`int`, *optional*):
+                Number of output frames. Defaults to `None` to output the same number of frames as the input
+                `controls`.
+            num_inference_steps (`int`, defaults to `36`):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
             guidance_scale (`float`, defaults to `3.0`):
@@ -601,13 +634,9 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
                 A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                 generation deterministic.
             latents (`torch.Tensor`, *optional*):
-                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
-                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
-                tensor is generated by sampling using the supplied random `generator`.
-            controls (`PipelineImageInput`, `List[PipelineImageInput]`, *optional*):
-                Control image or video input used by the ControlNet. If `None`, ControlNet is skipped.
-            controls_conditioning_scale (`float` or `List[float]`, *optional*, defaults to `1.0`):
-                The scale factor(s) for the ControlNet outputs. A single float is broadcast to all control blocks.
+                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs. Can be used to
+                tweak the same generation with different prompts. If not provided, a latents tensor is generated by
+                sampling using the supplied random `generator`.
             prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
@@ -630,7 +659,18 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
             max_sequence_length (`int`, defaults to `512`):
                 The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If
                 the prompt is shorter than this length, it will be padded.
+            num_ar_conditional_frames (`int`, *optional*, defaults to `1`):
+                Number of frames to condition on in subsequent inference loops of auto-regressive inference, i.e.
+                for the second chunk and onwards. Only used if `num_ar_latent_conditional_frames` is `None`.
+
+                Only relevant when auto-regressive inference is performed, i.e. when the number of frames in
+                `controls` is greater than `num_frames_per_chunk`.
+            num_ar_latent_conditional_frames (`int`, *optional*):
+                Number of latent frames to condition on in subsequent inference loops of auto-regressive
+                inference, i.e. for the second chunk and onwards. Only used if `num_ar_conditional_frames` is `None`.
+
+                Only relevant when auto-regressive inference is performed, i.e. when the number of frames in
+                `controls` is greater than `num_frames_per_chunk`.

         Examples:

         Returns:
@@ -650,21 +690,40 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

         if width is None:
-            frame = image or video[0] if image or video else None
-            if frame is None and controls is not None:
-                frame = controls[0] if isinstance(controls, list) else controls
-                if isinstance(frame, (torch.Tensor, np.ndarray)) and len(frame.shape) == 4:
-                    frame = controls[0]
+            frame = controls[0] if isinstance(controls, list) else controls
+            if isinstance(frame, list):
+                frame = frame[0]
+            if isinstance(frame, (torch.Tensor, np.ndarray)):
+                if frame.ndim == 5:
+                    frame = frame[0, 0]
+                elif frame.ndim == 4:
+                    frame = frame[0]

-            if frame is None:
-                width = int((height + 16) * (1280 / 720))
-            elif isinstance(frame, PIL.Image.Image):
+            if isinstance(frame, PIL.Image.Image):
                 width = int((height + 16) * (frame.width / frame.height))
             else:
+                if frame.ndim != 3:
+                    raise ValueError("`controls` must contain 3D frames in CHW format.")
                 width = int((height + 16) * (frame.shape[2] / frame.shape[1]))  # NOTE: assuming C H W

         # Check inputs. Raise error if not correct
-        self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)
+        num_frames_per_chunk = self.check_inputs(
+            prompt,
+            height,
+            width,
+            prompt_embeds,
+            callback_on_step_end_tensor_inputs,
+            num_ar_conditional_frames,
+            num_ar_latent_conditional_frames,
+            num_frames_per_chunk,
+            num_frames,
+            conditional_frame_timestep,
+        )
+
+        if num_ar_latent_conditional_frames is not None:
+            num_cond_latent_frames = num_ar_latent_conditional_frames
+            num_ar_conditional_frames = max(0, (num_cond_latent_frames - 1) * self.vae_scale_factor_temporal + 1)
+        else:
+            num_cond_latent_frames = max(0, (num_ar_conditional_frames - 1) // self.vae_scale_factor_temporal + 1)
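+        # Pixel <-> latent frame counts follow the VAE's temporal stride: e.g. with
+        # vae_scale_factor_temporal == 4, num_ar_conditional_frames=1 maps to 1 latent frame,
+        # and 5 pixel frames map to 2 latent frames.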

         self._guidance_scale = guidance_scale
         self._current_timestep = None
@@ -709,102 +768,137 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
         vae_dtype = self.vae.dtype
         transformer_dtype = self.transformer.dtype

-        img_context = torch.zeros(
-            batch_size,
-            self.transformer.config.img_context_num_tokens,
-            self.transformer.config.img_context_dim_in,
-            device=prompt_embeds.device,
-            dtype=transformer_dtype,
-        )
-        encoder_hidden_states = (prompt_embeds, img_context)
-        neg_encoder_hidden_states = (negative_prompt_embeds, img_context)
-
-        num_frames_in = None
-        if image is not None:
-            if batch_size != 1:
-                raise ValueError(f"batch_size must be 1 for image input (given {batch_size})")
-
-            image = torchvision.transforms.functional.to_tensor(image).unsqueeze(0)
-            video = torch.cat([image, torch.zeros_like(image).repeat(num_frames - 1, 1, 1, 1)], dim=0)
-            video = video.unsqueeze(0)
-            num_frames_in = 1
-        elif video is None:
-            video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=torch.uint8)
-            num_frames_in = 0
-        else:
-            num_frames_in = len(video)
-
-            if batch_size != 1:
-                raise ValueError(f"batch_size must be 1 for video input (given {batch_size})")
-
-        assert video is not None
-        video = self.video_processor.preprocess_video(video, height, width)
-
-        # pad with last frame (for video2world)
-        num_frames_out = num_frames
-        video = _maybe_pad_video(video, num_frames_out)
-        assert num_frames_in <= num_frames_out, f"expected ({num_frames_in=}) <= ({num_frames_out=})"
-
-        video = video.to(device=device, dtype=vae_dtype)
-
-        num_channels_latents = self.transformer.config.in_channels - 1
-        latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents(
-            video=video,
-            batch_size=batch_size * num_videos_per_prompt,
-            num_channels_latents=num_channels_latents,
-            height=height,
-            width=width,
-            num_frames_in=num_frames_in,
-            num_frames_out=num_frames,
-            do_classifier_free_guidance=self.do_classifier_free_guidance,
-            dtype=torch.float32,
-            device=device,
-            generator=generator,
-            latents=latents,
-        )
-        cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep
-        cond_mask = cond_mask.to(transformer_dtype)
-
-        controls_latents = None
-        if controls is not None:
-            controls_latents = self._encode_controls(
-                controls,
-                height=height,
-                width=width,
-                num_frames=num_frames,
-                dtype=transformer_dtype,
-                device=device,
-                generator=generator,
-            )
-
-        padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
-
-        # Denoising loop
-        self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps = self.scheduler.timesteps
-        self._num_timesteps = len(timesteps)
-        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        if getattr(self.transformer.config, "img_context_dim_in", None):
+            img_context = torch.zeros(
+                batch_size,
+                self.transformer.config.img_context_num_tokens,
+                self.transformer.config.img_context_dim_in,
+                device=prompt_embeds.device,
+                dtype=transformer_dtype,
+            )
+
+            if num_videos_per_prompt > 1:
+                img_context = img_context.repeat_interleave(num_videos_per_prompt, dim=0)
+
+            encoder_hidden_states = (prompt_embeds, img_context)
+            neg_encoder_hidden_states = (negative_prompt_embeds, img_context)
+        else:
+            encoder_hidden_states = prompt_embeds
+            neg_encoder_hidden_states = negative_prompt_embeds
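+        # Zero-filled image-context tokens act as a placeholder when conditioning is text-only; their shape
+        # and dtype follow the transformer config.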

-        gt_velocity = (latents - cond_latent) * cond_mask
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
-                if self.interrupt:
-                    continue
-
-                self._current_timestep = t.cpu().item()
-
-                # NOTE: assumes sigma(t) \in [0, 1]
-                sigma_t = (
-                    torch.tensor(self.scheduler.sigmas[i].item())
-                    .unsqueeze(0)
-                    .to(device=device, dtype=transformer_dtype)
-                )
-
-                in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents
-                in_latents = in_latents.to(transformer_dtype)
-                in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t
-                control_blocks = None
-                if controls_latents is not None and self.controlnet is not None:
+        control_video = self.video_processor.preprocess_video(controls, height, width)
+        if control_video.shape[0] != batch_size:
+            if control_video.shape[0] == 1:
+                control_video = control_video.repeat(batch_size, 1, 1, 1, 1)
+            else:
+                raise ValueError(
+                    f"Expected controls batch size {batch_size} to match prompt batch size, but got {control_video.shape[0]}."
+                )
+
+        num_frames_out = control_video.shape[2]
+        if num_frames is not None:
+            num_frames_out = min(num_frames_out, num_frames)
+
+        control_video = _maybe_pad_or_trim_video(control_video, num_frames_out)
+
+        # chunk information
+        num_latent_frames_per_chunk = (num_frames_per_chunk - 1) // self.vae_scale_factor_temporal + 1
+        chunk_stride = num_frames_per_chunk - num_ar_conditional_frames
+        chunk_idxs = [
+            (start_idx, min(start_idx + num_frames_per_chunk, num_frames_out))
+            for start_idx in range(0, num_frames_out - num_ar_conditional_frames, chunk_stride)
+        ]
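+        # Overlapping windows: e.g. 189 control frames with num_frames_per_chunk=93 and
+        # num_ar_conditional_frames=1 give stride 92 and chunks (0, 93), (92, 185), (184, 189);
+        # the short last chunk is padded back to 93 frames below.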
+        video_chunks = []
+        latents_mean = self.latents_mean.to(dtype=vae_dtype, device=device)
+        latents_std = self.latents_std.to(dtype=vae_dtype, device=device)
+
+        def decode_latents(latents):
+            latents = latents * latents_std + latents_mean
+            video = self.vae.decode(latents.to(dtype=self.vae.dtype, device=device), return_dict=False)[0]
+            return video
+        latents_arg = latents
+        initial_num_cond_latent_frames = 0
+        latent_chunks = []
+        num_chunks = len(chunk_idxs)
+        total_steps = num_inference_steps * num_chunks
+        with self.progress_bar(total=total_steps) as progress_bar:
+            for chunk_idx, (start_idx, end_idx) in enumerate(chunk_idxs):
+                if chunk_idx == 0:
+                    prev_output = torch.zeros((batch_size, num_frames_per_chunk, 3, height, width), dtype=vae_dtype)
+                    prev_output = self.video_processor.preprocess_video(prev_output, height, width)
+                else:
+                    prev_output = video_chunks[-1].clone()
+                    if num_ar_conditional_frames > 0:
+                        prev_output[:, :, :num_ar_conditional_frames] = prev_output[:, :, -num_ar_conditional_frames:]
+                        prev_output[:, :, num_ar_conditional_frames:] = -1  # -1 == 0 in processed video space
+                    else:
+                        prev_output.fill_(-1)
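+                # At this point `prev_output` holds the recycled tail frames of the previous chunk (if any),
+                # with the remaining frames set to -1, i.e. black in the processed [-1, 1] video range.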
+
+                chunk_video = prev_output.to(device=device, dtype=vae_dtype)
+                chunk_video = _maybe_pad_or_trim_video(chunk_video, num_frames_per_chunk)
+                latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents(
+                    video=chunk_video,
+                    batch_size=batch_size * num_videos_per_prompt,
+                    num_channels_latents=self.transformer.config.in_channels - 1,
+                    height=height,
+                    width=width,
+                    num_frames_in=chunk_video.shape[2],
+                    num_frames_out=num_frames_per_chunk,
+                    do_classifier_free_guidance=self.do_classifier_free_guidance,
+                    dtype=torch.float32,
+                    device=device,
+                    generator=generator,
+                    num_cond_latent_frames=initial_num_cond_latent_frames
+                    if chunk_idx == 0
+                    else num_cond_latent_frames,
+                    latents=latents_arg,
+                )
+                cond_mask = cond_mask.to(transformer_dtype)
+                cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep
+                padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
+
+                chunk_control_video = control_video[:, :, start_idx:end_idx, ...].to(
+                    device=device, dtype=self.vae.dtype
+                )
+                chunk_control_video = _maybe_pad_or_trim_video(chunk_control_video, num_frames_per_chunk)
+                if isinstance(generator, list):
+                    controls_latents = [
+                        retrieve_latents(self.vae.encode(chunk_control_video[i].unsqueeze(0)), generator=generator[i])
+                        for i in range(chunk_control_video.shape[0])
+                    ]
+                else:
+                    controls_latents = [
+                        retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator=generator)
+                        for vid in chunk_control_video
+                    ]
+                controls_latents = torch.cat(controls_latents, dim=0).to(transformer_dtype)
+
+                controls_latents = (controls_latents - latents_mean) / latents_std
+
+                # Denoising loop
+                self.scheduler.set_timesteps(num_inference_steps, device=device)
+                timesteps = self.scheduler.timesteps
+                self._num_timesteps = len(timesteps)
+
+                gt_velocity = (latents - cond_latent) * cond_mask
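+                # Ground-truth velocity for the conditioning frames (noise sample minus clean latent); it
+                # replaces the model's prediction on those frames during the denoising steps below.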
+                for i, t in enumerate(timesteps):
+                    if self.interrupt:
+                        continue
+
+                    self._current_timestep = t.cpu().item()
+
+                    # NOTE: assumes sigma(t) \in [0, 1]
+                    sigma_t = (
+                        torch.tensor(self.scheduler.sigmas[i].item())
+                        .unsqueeze(0)
+                        .to(device=device, dtype=transformer_dtype)
+                    )
+
+                    in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents
+                    in_latents = in_latents.to(transformer_dtype)
+                    in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t
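+                    # Per-frame timesteps: conditioning frames keep the fixed `conditional_frame_timestep`,
+                    # generated frames follow the sampler's current sigma.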
                     control_output = self.controlnet(
                         controls_latents=controls_latents,
                         latents=in_latents,
@@ -817,20 +911,18 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
                     )
                     control_blocks = control_output[0]

-                noise_pred = self.transformer(
-                    hidden_states=in_latents,
-                    timestep=in_timestep,
-                    encoder_hidden_states=encoder_hidden_states,
-                    block_controlnet_hidden_states=control_blocks,
-                    condition_mask=cond_mask,
-                    padding_mask=padding_mask,
-                    return_dict=False,
-                )[0]
-                noise_pred = gt_velocity + noise_pred * (1 - cond_mask)
+                    noise_pred = self.transformer(
+                        hidden_states=in_latents,
+                        timestep=in_timestep,
+                        encoder_hidden_states=encoder_hidden_states,
+                        block_controlnet_hidden_states=control_blocks,
+                        condition_mask=cond_mask,
+                        padding_mask=padding_mask,
+                        return_dict=False,
+                    )[0]
+                    noise_pred = gt_velocity + noise_pred * (1 - cond_mask)

-                if self.do_classifier_free_guidance:
-                    control_blocks = None
-                    if controls_latents is not None and self.controlnet is not None:
+                    if self.do_classifier_free_guidance:
                         control_output = self.controlnet(
                             controls_latents=controls_latents,
                             latents=in_latents,
@@ -843,46 +935,50 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
                         )
                         control_blocks = control_output[0]

-                    noise_pred_neg = self.transformer(
-                        hidden_states=in_latents,
-                        timestep=in_timestep,
-                        encoder_hidden_states=neg_encoder_hidden_states,  # NOTE: negative prompt
-                        block_controlnet_hidden_states=control_blocks,
-                        condition_mask=cond_mask,
-                        padding_mask=padding_mask,
-                        return_dict=False,
-                    )[0]
-                    # NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only
-                    noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask)
-                    noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg)
+                        noise_pred_neg = self.transformer(
+                            hidden_states=in_latents,
+                            timestep=in_timestep,
+                            encoder_hidden_states=neg_encoder_hidden_states,  # NOTE: negative prompt
+                            block_controlnet_hidden_states=control_blocks,
+                            condition_mask=cond_mask,
+                            padding_mask=padding_mask,
+                            return_dict=False,
+                        )[0]
+                        # NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only
+                        noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask)
+                        noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg)

-                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+                    latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

-                if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
-                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+                    # call the callback, if provided
+                    if callback_on_step_end is not None:
+                        callback_kwargs = {}
+                        for k in callback_on_step_end_tensor_inputs:
+                            callback_kwargs[k] = locals()[k]
+                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                        latents = callback_outputs.pop("latents", latents)
+                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                        negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

-                # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
+                    if i == total_steps - 1 or ((i + 1) % self.scheduler.order == 0):
+                        progress_bar.update()

-                if XLA_AVAILABLE:
-                    xm.mark_step()
+                    if XLA_AVAILABLE:
+                        xm.mark_step()

+                video_chunks.append(decode_latents(latents).detach().cpu())
+                latent_chunks.append(latents.detach().cpu())
+
         self._current_timestep = None

         if not output_type == "latent":
-            latents_mean = self.latents_mean.to(latents.device, latents.dtype)
-            latents_std = self.latents_std.to(latents.device, latents.dtype)
-            latents = latents * latents_std + latents_mean
-            video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
-            video = self._match_num_frames(video, num_frames)
+            video_chunks = [
+                chunk[:, :, num_ar_conditional_frames:, ...] if chunk_idx != 0 else chunk
+                for chunk_idx, chunk in enumerate(video_chunks)
+            ]
+            video = torch.cat(video_chunks, dim=2)
+            video = video[:, :, :num_frames_out, ...]

             assert self.safety_checker is not None
             self.safety_checker.to(device)
@@ -899,7 +995,13 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
             video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
             video = self.video_processor.postprocess_video(video, output_type=output_type)
         else:
-            video = latents
+            latent_T = (num_frames_out - 1) // self.vae_scale_factor_temporal + 1
+            latent_chunks = [
+                chunk[:, :, num_cond_latent_frames:, ...] if chunk_idx != 0 else chunk
+                for chunk_idx, chunk in enumerate(latent_chunks)
+            ]
+            video = torch.cat(latent_chunks, dim=2)
+            video = video[:, :, :latent_T, ...]
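+            # As with the decoded video, drop the overlapping conditioning frames from every chunk after the
+            # first, then trim to the latent length corresponding to `num_frames_out` pixel frames.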

         # Offload all models
         self.maybe_free_model_hooks()

@@ -908,19 +1010,3 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
             return (video,)

         return CosmosPipelineOutput(frames=video)
-
-    def _match_num_frames(self, video: torch.Tensor, target_num_frames: int) -> torch.Tensor:
-        if target_num_frames <= 0 or video.shape[2] == target_num_frames:
-            return video
-
-        frames_per_latent = max(self.vae_scale_factor_temporal, 1)
-        video = torch.repeat_interleave(video, repeats=frames_per_latent, dim=2)
-
-        current_frames = video.shape[2]
-        if current_frames < target_num_frames:
-            pad = video[:, :, -1:, :, :].repeat(1, 1, target_num_frames - current_frames, 1, 1)
-            video = torch.cat([video, pad], dim=2)
-        elif current_frames > target_num_frames:
-            video = video[:, :, :target_num_frames]
-
-        return video