Compare commits

..

3 Commits

Author SHA1 Message Date
Dhruv Nair
96f97f8214 update 2026-02-24 09:45:26 +01:00
Dhruv Nair
f8b95ff263 Merge branch 'main' into custom-device-map 2026-02-24 09:20:16 +01:00
DN6
cfff46069d add custom mesh support 2026-02-02 13:12:09 +05:30
14 changed files with 464 additions and 668 deletions

View File

@@ -46,20 +46,6 @@ output = pipe(
output.save("output.png")
```
## Cosmos2_5_TransferPipeline
[[autodoc]] Cosmos2_5_TransferPipeline
- all
- __call__
## Cosmos2_5_PredictBasePipeline
[[autodoc]] Cosmos2_5_PredictBasePipeline
- all
- __call__
## CosmosTextToWorldPipeline
[[autodoc]] CosmosTextToWorldPipeline
@@ -84,6 +70,12 @@ output.save("output.png")
- all
- __call__
## Cosmos2_5_PredictBasePipeline
[[autodoc]] Cosmos2_5_PredictBasePipeline
- all
- __call__
## CosmosPipelineOutput
[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput

View File

@@ -111,7 +111,7 @@ if __name__ == "__main__":
Call `torchrun` to run the inference script and use the `--nproc_per_node` argument to set the number of GPUs to use.
```bash
torchrun --nproc_per_node=2 run_distributed.py
torchrun run_distributed.py --nproc_per_node=2
```
## device_map

View File

@@ -94,15 +94,9 @@ python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/depth/pipeline \
--output_path converted/transfer/2b/general/depth \
--save_pipeline
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/depth/models
# edge
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/edge/61f5694b-0ad5-4ecd-8ad7-c8545627d125_ema_bf16.pt
@@ -126,15 +120,9 @@ python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/blur/pipeline \
--output_path converted/transfer/2b/general/blur \
--save_pipeline
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/blur/models
# seg
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/seg/5136ef49-6d8d-42e8-8abf-7dac722a304a_ema_bf16.pt
@@ -142,14 +130,8 @@ python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/seg/pipeline \
--output_path converted/transfer/2b/general/seg \
--save_pipeline
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/seg/models
```
"""

View File

@@ -60,6 +60,16 @@ class ContextParallelConfig:
rotate_method (`str`, *optional*, defaults to `"allgather"`):
Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"`
is supported.
ulysses_anything (`bool`, *optional*, defaults to `False`):
Whether to enable "Ulysses Anything" mode, which supports arbitrary sequence lengths and head counts that
are not evenly divisible by `ulysses_degree`. When enabled, `ulysses_degree` must be greater than 1 and
`ring_degree` must be 1.
mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of
creating a new one. This is useful when combining context parallelism with other parallelism strategies
(e.g., FSDP, tensor parallelism) that share the same device mesh. The mesh must have both "ring" and
"ulysses" dimensions. Use size 1 for dimensions not being used (e.g., `mesh_shape=(2, 1, 4)` with
`mesh_dim_names=("ring", "ulysses", "fsdp")` for ring attention only with FSDP).
"""
@@ -68,6 +78,7 @@ class ContextParallelConfig:
convert_to_fp32: bool = True
# TODO: support alltoall
rotate_method: Literal["allgather", "alltoall"] = "allgather"
mesh: torch.distributed.device_mesh.DeviceMesh | None = None
# Whether to enable ulysses anything attention to support
# any sequence lengths and any head numbers.
ulysses_anything: bool = False
@@ -124,7 +135,7 @@ class ContextParallelConfig:
f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
)
self._flattened_mesh = self._mesh._flatten()
self._flattened_mesh = self._mesh["ring", "ulysses"]._flatten()
self._ring_mesh = self._mesh["ring"]
self._ulysses_mesh = self._mesh["ulysses"]
self._ring_local_rank = self._ring_mesh.get_local_rank()

View File

@@ -62,8 +62,6 @@ _REQUIRED_FLEX_VERSION = "2.5.0"
_REQUIRED_XLA_VERSION = "2.2"
_REQUIRED_XFORMERS_VERSION = "0.0.29"
logger = get_logger(__name__) # pylint: disable=invalid-name
_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
@@ -75,18 +73,8 @@ _CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _
if _CAN_USE_FLASH_ATTN:
try:
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
except (ImportError, OSError, RuntimeError) as e:
# Handle ABI mismatch or other import failures gracefully.
# This can happen when flash_attn was compiled against a different PyTorch version.
logger.warning(f"flash_attn is installed but failed to import: {e}. Falling back to native PyTorch attention.")
_CAN_USE_FLASH_ATTN = False
flash_attn_func = None
flash_attn_varlen_func = None
_wrapped_flash_attn_backward = None
_wrapped_flash_attn_forward = None
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
else:
flash_attn_func = None
flash_attn_varlen_func = None
@@ -95,47 +83,26 @@ else:
if _CAN_USE_FLASH_ATTN_3:
try:
from flash_attn_interface import flash_attn_func as flash_attn_3_func
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"flash_attn_3 failed to import: {e}. Falling back to native attention.")
_CAN_USE_FLASH_ATTN_3 = False
flash_attn_3_func = None
flash_attn_3_varlen_func = None
from flash_attn_interface import flash_attn_func as flash_attn_3_func
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
else:
flash_attn_3_func = None
flash_attn_3_varlen_func = None
if _CAN_USE_AITER_ATTN:
try:
from aiter import flash_attn_func as aiter_flash_attn_func
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"aiter failed to import: {e}. Falling back to native attention.")
_CAN_USE_AITER_ATTN = False
aiter_flash_attn_func = None
from aiter import flash_attn_func as aiter_flash_attn_func
else:
aiter_flash_attn_func = None
if _CAN_USE_SAGE_ATTN:
try:
from sageattention import (
sageattn,
sageattn_qk_int8_pv_fp8_cuda,
sageattn_qk_int8_pv_fp8_cuda_sm90,
sageattn_qk_int8_pv_fp16_cuda,
sageattn_qk_int8_pv_fp16_triton,
sageattn_varlen,
)
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"sageattention failed to import: {e}. Falling back to native attention.")
_CAN_USE_SAGE_ATTN = False
sageattn = None
sageattn_qk_int8_pv_fp8_cuda = None
sageattn_qk_int8_pv_fp8_cuda_sm90 = None
sageattn_qk_int8_pv_fp16_cuda = None
sageattn_qk_int8_pv_fp16_triton = None
sageattn_varlen = None
from sageattention import (
sageattn,
sageattn_qk_int8_pv_fp8_cuda,
sageattn_qk_int8_pv_fp8_cuda_sm90,
sageattn_qk_int8_pv_fp16_cuda,
sageattn_qk_int8_pv_fp16_triton,
sageattn_varlen,
)
else:
sageattn = None
sageattn_qk_int8_pv_fp16_cuda = None
@@ -146,48 +113,26 @@ else:
if _CAN_USE_FLEX_ATTN:
try:
# We cannot import the flex_attention function from the package directly because it is expected (from the
# pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
# compiled function.
import torch.nn.attention.flex_attention as flex_attention
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"flex_attention failed to import: {e}. Falling back to native attention.")
_CAN_USE_FLEX_ATTN = False
flex_attention = None
else:
flex_attention = None
# We cannot import the flex_attention function from the package directly because it is expected (from the
# pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
# compiled function.
import torch.nn.attention.flex_attention as flex_attention
if _CAN_USE_NPU_ATTN:
try:
from torch_npu import npu_fusion_attention
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"torch_npu failed to import: {e}. Falling back to native attention.")
_CAN_USE_NPU_ATTN = False
npu_fusion_attention = None
from torch_npu import npu_fusion_attention
else:
npu_fusion_attention = None
if _CAN_USE_XLA_ATTN:
try:
from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"torch_xla failed to import: {e}. Falling back to native attention.")
_CAN_USE_XLA_ATTN = False
xla_flash_attention = None
from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
else:
xla_flash_attention = None
if _CAN_USE_XFORMERS_ATTN:
try:
import xformers.ops as xops
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"xformers failed to import: {e}. Falling back to native attention.")
_CAN_USE_XFORMERS_ATTN = False
xops = None
import xformers.ops as xops
else:
xops = None
@@ -213,6 +158,8 @@ else:
_register_fake = register_fake_no_op
logger = get_logger(__name__) # pylint: disable=invalid-name
# TODO(aryan): Add support for the following:
# - Sage Attention++
# - block sparse, radial and other attention methods
@@ -329,11 +276,7 @@ class _HubKernelConfig:
_HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
# TODO: temporary revision for now. Remove when merged upstream into `main`.
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn3",
function_attr="flash_attn_func",
revision="fake-ops-return-probs",
wrapped_forward_attr="flash_attn_interface._flash_attn_forward",
wrapped_backward_attr="flash_attn_interface._flash_attn_backward",
repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
),
AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn3",
@@ -733,7 +676,7 @@ def _wrapped_flash_attn_3(
) -> tuple[torch.Tensor, torch.Tensor]:
# Hardcoded for now because pytorch does not support tuple/int type hints
window_size = (-1, -1)
result = flash_attn_3_func(
out, lse, *_ = flash_attn_3_func(
q=q,
k=k,
v=v,
@@ -750,9 +693,7 @@ def _wrapped_flash_attn_3(
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=True,
)
out, lse, *_ = result
lse = lse.permute(0, 2, 1)
return out, lse
@@ -1296,62 +1237,36 @@ def _flash_attention_3_hub_forward_op(
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 hub kernels.")
config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB]
wrapped_forward_fn = config.wrapped_forward_fn
if wrapped_forward_fn is None:
raise RuntimeError(
"Flash attention 3 hub kernels must expose `flash_attn_interface._flash_attn_forward` "
"for context parallel execution."
)
if scale is None:
scale = query.shape[-1] ** (-0.5)
out, softmax_lse, *_ = wrapped_forward_fn(
query,
key,
value,
None,
None, # k_new, v_new
None, # qv
None, # out
None,
None,
None, # cu_seqlens_q/k/k_new
None,
None, # seqused_q/k
None,
None, # max_seqlen_q/k
None,
None,
None, # page_table, kv_batch_idx, leftpad_k
None,
None,
None, # rotary_cos/sin, seqlens_rotary
None,
None,
None, # q_descale, k_descale, v_descale
scale,
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
window_size_left=window_size[0],
window_size_right=window_size[1],
attention_chunk=0,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=num_splits,
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=return_lse,
)
lse = softmax_lse.permute(0, 2, 1).contiguous() if return_lse else None
lse = None
if return_lse:
out, lse = out
lse = lse.permute(0, 2, 1).contiguous()
if _save_ctx:
ctx.save_for_backward(query, key, value, out, softmax_lse)
ctx.save_for_backward(query, key, value)
ctx.scale = scale
ctx.is_causal = is_causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.deterministic = deterministic
ctx.sm_margin = sm_margin
ctx._hub_kernel = func
return (out, lse) if return_lse else out
@@ -1360,49 +1275,54 @@ def _flash_attention_3_hub_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
**kwargs,
window_size: tuple[int, int] = (-1, -1),
softcap: float = 0.0,
num_splits: int = 1,
pack_gqa: bool | None = None,
deterministic: bool = False,
sm_margin: int = 0,
):
config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB]
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention 3 hub kernels must expose `flash_attn_interface._flash_attn_backward` "
"for context parallel execution."
query, key, value = ctx.saved_tensors
kernel_fn = ctx._hub_kernel
# NOTE: Unlike the FA2 hub kernel, the FA3 hub kernel does not expose separate wrapped forward/backward
# primitives (no `wrapped_forward_attr`/`wrapped_backward_attr` in its `_HubKernelConfig`). We
# therefore rerun the forward pass under `torch.enable_grad()` and differentiate through it with
# `torch.autograd.grad()`. This is a second forward pass during backward; it can be avoided once
# the FA3 hub exposes a dedicated fused backward kernel (analogous to `_wrapped_flash_attn_backward`
# in the FA2 hub), at which point this can be refactored to match `_flash_attention_hub_backward_op`.
with torch.enable_grad():
query_r = query.detach().requires_grad_(True)
key_r = key.detach().requires_grad_(True)
value_r = value.detach().requires_grad_(True)
out = kernel_fn(
q=query_r,
k=key_r,
v=value_r,
softmax_scale=ctx.scale,
causal=ctx.is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=num_splits,
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=False,
)
if isinstance(out, tuple):
out = out[0]
query, key, value, out, softmax_lse = ctx.saved_tensors
grad_query = torch.empty_like(query)
grad_key = torch.empty_like(key)
grad_value = torch.empty_like(value)
wrapped_backward_fn(
grad_out,
query,
key,
value,
out,
softmax_lse,
None,
None, # cu_seqlens_q, cu_seqlens_k
None,
None, # seqused_q, seqused_k
None,
None, # max_seqlen_q, max_seqlen_k
grad_query,
grad_key,
grad_value,
ctx.scale,
ctx.is_causal,
ctx.window_size[0],
ctx.window_size[1],
ctx.softcap,
ctx.deterministic,
ctx.sm_margin,
)
grad_query = grad_query[..., : grad_out.shape[-1]]
grad_key = grad_key[..., : grad_out.shape[-1]]
grad_value = grad_value[..., : grad_out.shape[-1]]
grad_query, grad_key, grad_value = torch.autograd.grad(
out,
(query_r, key_r, value_r),
grad_out,
retain_graph=False,
allow_unused=False,
)
return grad_query, grad_key, grad_value
@@ -2703,7 +2623,7 @@ def _flash_varlen_attention_3(
key_packed = torch.cat(key_valid, dim=0)
value_packed = torch.cat(value_valid, dim=0)
result = flash_attn_3_varlen_func(
out, lse, *_ = flash_attn_3_varlen_func(
q=query_packed,
k=key_packed,
v=value_packed,
@@ -2713,13 +2633,7 @@ def _flash_varlen_attention_3(
max_seqlen_k=max_seqlen_k,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if isinstance(result, tuple):
out, lse, *_ = result
else:
out = result
lse = None
out = out.unflatten(0, (batch_size, -1))
return (out, lse) if return_lse else out

View File

@@ -191,12 +191,7 @@ class CosmosControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
dim=1,
)
if condition_mask is not None:
control_hidden_states = torch.cat([control_hidden_states, condition_mask], dim=1)
else:
control_hidden_states = torch.cat(
[control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1
)
control_hidden_states = torch.cat([control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1)
padding_mask_resized = transforms.functional.resize(
padding_mask, list(control_hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST

View File

@@ -1567,7 +1567,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
mesh = None
if config.context_parallel_config is not None:
cp_config = config.context_parallel_config
mesh = torch.distributed.device_mesh.init_device_mesh(
mesh = cp_config.mesh or torch.distributed.device_mesh.init_device_mesh(
device_type=device_type,
mesh_shape=cp_config.mesh_shape,
mesh_dim_names=cp_config.mesh_dim_names,

View File

@@ -1633,14 +1633,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
blocks_class_name = self.default_blocks_name
if blocks_class_name is not None:
diffusers_module = importlib.import_module("diffusers")
blocks_class = getattr(diffusers_module, blocks_class_name, None)
# If the blocks_class is not found or is a base class (e.g. SequentialPipelineBlocks saved by from_blocks_dict) with empty block_classes
# fall back to default_blocks_name
if blocks_class is None or not blocks_class.block_classes:
blocks_class_name = self.default_blocks_name
blocks_class = getattr(diffusers_module, blocks_class_name)
if blocks_class is not None:
blocks_class = getattr(diffusers_module, blocks_class_name)
blocks = blocks_class()
else:
logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}")

View File

@@ -17,6 +17,9 @@ from typing import Callable, Dict, List, Optional, Union
import numpy as np
import PIL.Image
import torch
import torchvision
import torchvision.transforms
import torchvision.transforms.functional
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
@@ -51,13 +54,11 @@ else:
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def _maybe_pad_or_trim_video(video: torch.Tensor, num_frames: int):
def _maybe_pad_video(video: torch.Tensor, num_frames: int):
n_pad_frames = num_frames - video.shape[2]
if n_pad_frames > 0:
last_frame = video[:, :, -1:, :, :]
video = torch.cat((video, last_frame.repeat(1, 1, n_pad_frames, 1, 1)), dim=2)
elif num_frames < video.shape[2]:
video = video[:, :, :num_frames, :, :]
return video
@@ -133,8 +134,8 @@ EXAMPLE_DOC_STRING = """
>>> controls = [Image.fromarray(x.numpy()) for x in controls.permute(1, 2, 3, 0)]
>>> export_to_video(controls, "edge_controlled_video_edge.mp4", fps=30)
>>> # Transfer inference with controls.
>>> video = pipe(
... video=input_video[:num_frames],
... controls=controls,
... controls_conditioning_scale=1.0,
... prompt=prompt,
@@ -148,7 +149,7 @@ EXAMPLE_DOC_STRING = """
class Cosmos2_5_TransferPipeline(DiffusionPipeline):
r"""
Pipeline for Cosmos Transfer2.5, supporting auto-regressive inference.
Pipeline for Cosmos Transfer2.5 base model.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
@@ -165,14 +166,12 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
controlnet ([`CosmosControlNetModel`]):
ControlNet used to condition generation on control inputs.
"""
model_cpu_offload_seq = "text_encoder->transformer->controlnet->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
# We mark safety_checker as optional here to get around some test failures, but it is not really optional
_optional_components = ["safety_checker"]
_optional_components = ["safety_checker", "controlnet"]
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
@@ -182,8 +181,8 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
transformer: CosmosTransformer3DModel,
vae: AutoencoderKLWan,
scheduler: UniPCMultistepScheduler,
controlnet: CosmosControlNetModel,
safety_checker: Optional[CosmosSafetyChecker] = None,
controlnet: Optional[CosmosControlNetModel],
safety_checker: CosmosSafetyChecker = None,
):
super().__init__()
@@ -385,11 +384,10 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
num_frames_in: int = 93,
num_frames_out: int = 93,
do_classifier_free_guidance: bool = True,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
num_cond_latent_frames: int = 0,
dtype: torch.dtype | None = None,
device: torch.device | None = None,
generator: torch.Generator | list[torch.Generator] | None = None,
latents: torch.Tensor | None = None,
) -> torch.Tensor:
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
@@ -404,14 +402,10 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
W = width // self.vae_scale_factor_spatial
shape = (B, C, T, H, W)
if latents is not None:
if latents.shape[1:] != shape[1:]:
raise ValueError(f"Unexpected `latents` shape, got {latents.shape}, expected {shape}.")
latents = latents.to(device=device, dtype=dtype)
else:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
if num_frames_in == 0:
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
cond_mask = torch.zeros((B, 1, T, H, W), dtype=latents.dtype, device=latents.device)
cond_indicator = torch.zeros((B, 1, T, 1, 1), dtype=latents.dtype, device=latents.device)
@@ -441,12 +435,16 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
latents_std = self.latents_std.to(device=device, dtype=dtype)
cond_latents = (cond_latents - latents_mean) / latents_std
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device=device, dtype=dtype)
padding_shape = (B, 1, T, H, W)
ones_padding = latents.new_ones(padding_shape)
zeros_padding = latents.new_zeros(padding_shape)
cond_indicator = latents.new_zeros(B, 1, latents.size(2), 1, 1)
cond_indicator[:, :, 0:num_cond_latent_frames, :, :] = 1.0
cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
return (
@@ -456,7 +454,34 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
cond_indicator,
)
# Modified from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs
def _encode_controls(
self,
controls: Optional[torch.Tensor],
height: int,
width: int,
num_frames: int,
dtype: torch.dtype,
device: torch.device,
generator: torch.Generator | list[torch.Generator] | None,
) -> Optional[torch.Tensor]:
if controls is None:
return None
control_video = self.video_processor.preprocess_video(controls, height, width)
control_video = _maybe_pad_video(control_video, num_frames)
control_video = control_video.to(device=device, dtype=self.vae.dtype)
control_latents = [
retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator=generator) for vid in control_video
]
control_latents = torch.cat(control_latents, dim=0).to(dtype)
latents_mean = self.latents_mean.to(device=device, dtype=dtype)
latents_std = self.latents_std.to(device=device, dtype=dtype)
control_latents = (control_latents - latents_mean) / latents_std
return control_latents
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs
def check_inputs(
self,
prompt,
@@ -464,25 +489,9 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
width,
prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
num_ar_conditional_frames=None,
num_ar_latent_conditional_frames=None,
num_frames_per_chunk=None,
num_frames=None,
conditional_frame_timestep=0.1,
):
if width <= 0 or height <= 0 or height % 16 != 0 or width % 16 != 0:
raise ValueError(
f"`height` and `width` have to be divisible by 16 (& positive) but are {height} and {width}."
)
if num_frames is not None and num_frames <= 0:
raise ValueError(f"`num_frames` has to be a positive integer when provided but is {num_frames}.")
if conditional_frame_timestep < 0 or conditional_frame_timestep > 1:
raise ValueError(
"`conditional_frame_timestep` has to be a float in the [0, 1] interval but is "
f"{conditional_frame_timestep}."
)
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -503,46 +512,6 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if num_ar_latent_conditional_frames is not None and num_ar_conditional_frames is not None:
raise ValueError(
"Provide only one of `num_ar_conditional_frames` or `num_ar_latent_conditional_frames`, not both."
)
if num_ar_latent_conditional_frames is None and num_ar_conditional_frames is None:
raise ValueError("Provide either `num_ar_conditional_frames` or `num_ar_latent_conditional_frames`.")
if num_ar_latent_conditional_frames is not None and num_ar_latent_conditional_frames < 0:
raise ValueError("`num_ar_latent_conditional_frames` must be >= 0.")
if num_ar_conditional_frames is not None and num_ar_conditional_frames < 0:
raise ValueError("`num_ar_conditional_frames` must be >= 0.")
if num_ar_latent_conditional_frames is not None:
num_ar_conditional_frames = max(
0, (num_ar_latent_conditional_frames - 1) * self.vae_scale_factor_temporal + 1
)
min_chunk_len = self.vae_scale_factor_temporal + 1
if num_frames_per_chunk < min_chunk_len:
logger.warning(f"{num_frames_per_chunk=} must be larger than {min_chunk_len=}, setting to min_chunk_len")
num_frames_per_chunk = min_chunk_len
max_frames_by_rope = None
if getattr(self.transformer.config, "max_size", None) is not None:
max_frames_by_rope = max(
size // patch
for size, patch in zip(self.transformer.config.max_size, self.transformer.config.patch_size)
)
if num_frames_per_chunk > max_frames_by_rope:
raise ValueError(
f"{num_frames_per_chunk=} is too large for RoPE setting ({max_frames_by_rope=}). "
"Please reduce `num_frames_per_chunk`."
)
if num_ar_conditional_frames >= num_frames_per_chunk:
raise ValueError(
f"{num_ar_conditional_frames=} must be smaller than {num_frames_per_chunk=} for chunked generation."
)
return num_frames_per_chunk
@property
def guidance_scale(self):
return self._guidance_scale
@@ -567,22 +536,23 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
controls: PipelineImageInput | List[PipelineImageInput],
controls_conditioning_scale: Union[float, List[float]] = 1.0,
image: PipelineImageInput | None = None,
video: List[PipelineImageInput] | None = None,
prompt: Union[str, List[str]] | None = None,
negative_prompt: Union[str, List[str]] = DEFAULT_NEGATIVE_PROMPT,
height: int = 704,
width: Optional[int] = None,
num_frames: Optional[int] = None,
num_frames_per_chunk: int = 93,
width: int | None = None,
num_frames: int = 93,
num_inference_steps: int = 36,
guidance_scale: float = 3.0,
num_videos_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
num_videos_per_prompt: Optional[int] = 1,
generator: torch.Generator | list[torch.Generator] | None = None,
latents: torch.Tensor | None = None,
controls: Optional[PipelineImageInput | List[PipelineImageInput]] = None,
controls_conditioning_scale: float | list[float] = 1.0,
prompt_embeds: torch.Tensor | None = None,
negative_prompt_embeds: torch.Tensor | None = None,
output_type: str = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
@@ -590,26 +560,24 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
conditional_frame_timestep: float = 0.1,
num_ar_conditional_frames: Optional[int] = 1,
num_ar_latent_conditional_frames: Optional[int] = None,
):
r"""
`controls` drive the conditioning through ControlNet. Controls are assumed to be pre-processed, e.g. edge maps
are pre-computed.
The call function to the pipeline for generation. Supports three modes:
Setting `num_frames` will restrict the total number of frames output, if not provided or assigned to None
(default) then the number of output frames will match the input `controls`.
- **Text2World**: `image=None`, `video=None`, `prompt` provided. Generates a world clip.
- **Image2World**: `image` provided, `video=None`, `prompt` provided. Conditions on a single frame.
- **Video2World**: `video` provided, `image=None`, `prompt` provided. Conditions on an input clip.
Auto-regressive inference is supported and thus a sliding window of `num_frames_per_chunk` frames are used per
denoising loop. In addition, when auto-regressive inference is performed, the previous
`num_ar_latent_conditional_frames` or `num_ar_conditional_frames` are used to condition the following denoising
inference loops.
Set `num_frames=93` (default) to produce a world video, or `num_frames=1` to produce a single image frame (the
above in "*2Image mode").
Outputs follow `output_type` (e.g., `"pil"` returns a list of `num_frames` PIL images per prompt).
Args:
controls (`PipelineImageInput`, `List[PipelineImageInput]`):
Control image or video input used by the ControlNet.
controls_conditioning_scale (`float` or `List[float]`, *optional*, defaults to `1.0`):
The scale factor(s) for the ControlNet outputs. A single float is broadcast to all control blocks.
image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, *optional*):
Optional single image for Image2World conditioning. Must be `None` when `video` is provided.
video (`List[PIL.Image.Image]`, `np.ndarray`, `torch.Tensor`, *optional*):
Optional input video for Video2World conditioning. Must be `None` when `image` is provided.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide generation. Required unless `prompt_embeds` is supplied.
height (`int`, defaults to `704`):
@@ -617,10 +585,9 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
width (`int`, *optional*):
The width in pixels of the generated image. If not provided, this will be determined based on the
aspect ratio of the input and the provided height.
num_frames (`int`, *optional*):
Number of output frames. Defaults to `None` to output the same number of frames as the input
`controls`.
num_inference_steps (`int`, defaults to `36`):
num_frames (`int`, defaults to `93`):
Number of output frames. Use `93` for world (video) generation; set to `1` to return a single frame.
num_inference_steps (`int`, defaults to `35`):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to `3.0`):
@@ -634,9 +601,13 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs. Can be used to
tweak the same generation with different prompts. If not provided, a latents tensor is generated by
sampling using the supplied random `generator`.
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
controls (`PipelineImageInput`, `List[PipelineImageInput]`, *optional*):
Control image or video input used by the ControlNet. If `None`, ControlNet is skipped.
controls_conditioning_scale (`float` or `List[float]`, *optional*, defaults to `1.0`):
The scale factor(s) for the ControlNet outputs. A single float is broadcast to all control blocks.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -659,18 +630,7 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
max_sequence_length (`int`, defaults to `512`):
The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If
the prompt is shorter than this length, it will be padded.
num_ar_conditional_frames (`int`, *optional*, defaults to `1`):
Number of frames to condition on subsequent inference loops in auto-regressive inference, i.e. for the
second chunk and onwards. Only used if `num_ar_latent_conditional_frames` is `None`.
This is only used when auto-regressive inference is performed, i.e. when the number of frames in
controls is > num_frames_per_chunk
num_ar_latent_conditional_frames (`int`, *optional*):
Number of latent frames to condition on subsequent inference loops in auto-regressive inference, i.e.
for the second chunk and onwards. Only used if `num_ar_conditional_frames` is `None`.
This is only used when auto-regressive inference is performed, i.e. when the number of frames in
controls is > num_frames_per_chunk
Examples:
Returns:
@@ -690,40 +650,21 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
if width is None:
frame = controls[0] if isinstance(controls, list) else controls
if isinstance(frame, list):
frame = frame[0]
if isinstance(frame, (torch.Tensor, np.ndarray)):
if frame.ndim == 5:
frame = frame[0, 0]
elif frame.ndim == 4:
frame = frame[0]
frame = image or video[0] if image or video else None
if frame is None and controls is not None:
frame = controls[0] if isinstance(controls, list) else controls
if isinstance(frame, (torch.Tensor, np.ndarray)) and len(frame.shape) == 4:
frame = controls[0]
if isinstance(frame, PIL.Image.Image):
if frame is None:
width = int((height + 16) * (1280 / 720))
elif isinstance(frame, PIL.Image.Image):
width = int((height + 16) * (frame.width / frame.height))
else:
if frame.ndim != 3:
raise ValueError("`controls` must contain 3D frames in CHW format.")
width = int((height + 16) * (frame.shape[2] / frame.shape[1])) # NOTE: assuming C H W
num_frames_per_chunk = self.check_inputs(
prompt,
height,
width,
prompt_embeds,
callback_on_step_end_tensor_inputs,
num_ar_conditional_frames,
num_ar_latent_conditional_frames,
num_frames_per_chunk,
num_frames,
conditional_frame_timestep,
)
if num_ar_latent_conditional_frames is not None:
num_cond_latent_frames = num_ar_latent_conditional_frames
num_ar_conditional_frames = max(0, (num_cond_latent_frames - 1) * self.vae_scale_factor_temporal + 1)
else:
num_cond_latent_frames = max(0, (num_ar_conditional_frames - 1) // self.vae_scale_factor_temporal + 1)
# Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)
self._guidance_scale = guidance_scale
self._current_timestep = None
@@ -768,137 +709,102 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
vae_dtype = self.vae.dtype
transformer_dtype = self.transformer.dtype
if getattr(self.transformer.config, "img_context_dim_in", None):
img_context = torch.zeros(
batch_size,
self.transformer.config.img_context_num_tokens,
self.transformer.config.img_context_dim_in,
device=prompt_embeds.device,
img_context = torch.zeros(
batch_size,
self.transformer.config.img_context_num_tokens,
self.transformer.config.img_context_dim_in,
device=prompt_embeds.device,
dtype=transformer_dtype,
)
encoder_hidden_states = (prompt_embeds, img_context)
neg_encoder_hidden_states = (negative_prompt_embeds, img_context)
num_frames_in = None
if image is not None:
if batch_size != 1:
raise ValueError(f"batch_size must be 1 for image input (given {batch_size})")
image = torchvision.transforms.functional.to_tensor(image).unsqueeze(0)
video = torch.cat([image, torch.zeros_like(image).repeat(num_frames - 1, 1, 1, 1)], dim=0)
video = video.unsqueeze(0)
num_frames_in = 1
elif video is None:
video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=torch.uint8)
num_frames_in = 0
else:
num_frames_in = len(video)
if batch_size != 1:
raise ValueError(f"batch_size must be 1 for video input (given {batch_size})")
assert video is not None
video = self.video_processor.preprocess_video(video, height, width)
# pad with last frame (for video2world)
num_frames_out = num_frames
video = _maybe_pad_video(video, num_frames_out)
assert num_frames_in <= num_frames_out, f"expected ({num_frames_in=}) <= ({num_frames_out=})"
video = video.to(device=device, dtype=vae_dtype)
num_channels_latents = self.transformer.config.in_channels - 1
latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents(
video=video,
batch_size=batch_size * num_videos_per_prompt,
num_channels_latents=num_channels_latents,
height=height,
width=width,
num_frames_in=num_frames_in,
num_frames_out=num_frames,
do_classifier_free_guidance=self.do_classifier_free_guidance,
dtype=torch.float32,
device=device,
generator=generator,
latents=latents,
)
cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep
cond_mask = cond_mask.to(transformer_dtype)
controls_latents = None
if controls is not None:
controls_latents = self._encode_controls(
controls,
height=height,
width=width,
num_frames=num_frames,
dtype=transformer_dtype,
device=device,
generator=generator,
)
if num_videos_per_prompt > 1:
img_context = img_context.repeat_interleave(num_videos_per_prompt, dim=0)
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
encoder_hidden_states = (prompt_embeds, img_context)
neg_encoder_hidden_states = (negative_prompt_embeds, img_context)
else:
encoder_hidden_states = prompt_embeds
neg_encoder_hidden_states = negative_prompt_embeds
# Denoising loop
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
self._num_timesteps = len(timesteps)
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
control_video = self.video_processor.preprocess_video(controls, height, width)
if control_video.shape[0] != batch_size:
if control_video.shape[0] == 1:
control_video = control_video.repeat(batch_size, 1, 1, 1, 1)
else:
raise ValueError(
f"Expected controls batch size {batch_size} to match prompt batch size, but got {control_video.shape[0]}."
gt_velocity = (latents - cond_latent) * cond_mask
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t.cpu().item()
# NOTE: assumes sigma(t) \in [0, 1]
sigma_t = (
torch.tensor(self.scheduler.sigmas[i].item())
.unsqueeze(0)
.to(device=device, dtype=transformer_dtype)
)
num_frames_out = control_video.shape[2]
if num_frames is not None:
num_frames_out = min(num_frames_out, num_frames)
control_video = _maybe_pad_or_trim_video(control_video, num_frames_out)
# chunk information
num_latent_frames_per_chunk = (num_frames_per_chunk - 1) // self.vae_scale_factor_temporal + 1
chunk_stride = num_frames_per_chunk - num_ar_conditional_frames
chunk_idxs = [
(start_idx, min(start_idx + num_frames_per_chunk, num_frames_out))
for start_idx in range(0, num_frames_out - num_ar_conditional_frames, chunk_stride)
]
video_chunks = []
latents_mean = self.latents_mean.to(dtype=vae_dtype, device=device)
latents_std = self.latents_std.to(dtype=vae_dtype, device=device)
def decode_latents(latents):
latents = latents * latents_std + latents_mean
video = self.vae.decode(latents.to(dtype=self.vae.dtype, device=device), return_dict=False)[0]
return video
latents_arg = latents
initial_num_cond_latent_frames = 0
latent_chunks = []
num_chunks = len(chunk_idxs)
total_steps = num_inference_steps * num_chunks
with self.progress_bar(total=total_steps) as progress_bar:
for chunk_idx, (start_idx, end_idx) in enumerate(chunk_idxs):
if chunk_idx == 0:
prev_output = torch.zeros((batch_size, num_frames_per_chunk, 3, height, width), dtype=vae_dtype)
prev_output = self.video_processor.preprocess_video(prev_output, height, width)
else:
prev_output = video_chunks[-1].clone()
if num_ar_conditional_frames > 0:
prev_output[:, :, :num_ar_conditional_frames] = prev_output[:, :, -num_ar_conditional_frames:]
prev_output[:, :, num_ar_conditional_frames:] = -1 # -1 == 0 in processed video space
else:
prev_output.fill_(-1)
chunk_video = prev_output.to(device=device, dtype=vae_dtype)
chunk_video = _maybe_pad_or_trim_video(chunk_video, num_frames_per_chunk)
latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents(
video=chunk_video,
batch_size=batch_size * num_videos_per_prompt,
num_channels_latents=self.transformer.config.in_channels - 1,
height=height,
width=width,
num_frames_in=chunk_video.shape[2],
num_frames_out=num_frames_per_chunk,
do_classifier_free_guidance=self.do_classifier_free_guidance,
dtype=torch.float32,
device=device,
generator=generator,
num_cond_latent_frames=initial_num_cond_latent_frames
if chunk_idx == 0
else num_cond_latent_frames,
latents=latents_arg,
)
cond_mask = cond_mask.to(transformer_dtype)
cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
chunk_control_video = control_video[:, :, start_idx:end_idx, ...].to(
device=device, dtype=self.vae.dtype
)
chunk_control_video = _maybe_pad_or_trim_video(chunk_control_video, num_frames_per_chunk)
if isinstance(generator, list):
controls_latents = [
retrieve_latents(self.vae.encode(chunk_control_video[i].unsqueeze(0)), generator=generator[i])
for i in range(chunk_control_video.shape[0])
]
else:
controls_latents = [
retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator=generator)
for vid in chunk_control_video
]
controls_latents = torch.cat(controls_latents, dim=0).to(transformer_dtype)
controls_latents = (controls_latents - latents_mean) / latents_std
# Denoising loop
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
self._num_timesteps = len(timesteps)
gt_velocity = (latents - cond_latent) * cond_mask
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t.cpu().item()
# NOTE: assumes sigma(t) \in [0, 1]
sigma_t = (
torch.tensor(self.scheduler.sigmas[i].item())
.unsqueeze(0)
.to(device=device, dtype=transformer_dtype)
)
in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents
in_latents = in_latents.to(transformer_dtype)
in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t
in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents
in_latents = in_latents.to(transformer_dtype)
in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t
control_blocks = None
if controls_latents is not None and self.controlnet is not None:
control_output = self.controlnet(
controls_latents=controls_latents,
latents=in_latents,
@@ -911,18 +817,20 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
)
control_blocks = control_output[0]
noise_pred = self.transformer(
hidden_states=in_latents,
timestep=in_timestep,
encoder_hidden_states=encoder_hidden_states,
block_controlnet_hidden_states=control_blocks,
condition_mask=cond_mask,
padding_mask=padding_mask,
return_dict=False,
)[0]
noise_pred = gt_velocity + noise_pred * (1 - cond_mask)
noise_pred = self.transformer(
hidden_states=in_latents,
timestep=in_timestep,
encoder_hidden_states=encoder_hidden_states,
block_controlnet_hidden_states=control_blocks,
condition_mask=cond_mask,
padding_mask=padding_mask,
return_dict=False,
)[0]
noise_pred = gt_velocity + noise_pred * (1 - cond_mask)
if self.do_classifier_free_guidance:
if self.do_classifier_free_guidance:
control_blocks = None
if controls_latents is not None and self.controlnet is not None:
control_output = self.controlnet(
controls_latents=controls_latents,
latents=in_latents,
@@ -935,50 +843,46 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
)
control_blocks = control_output[0]
noise_pred_neg = self.transformer(
hidden_states=in_latents,
timestep=in_timestep,
encoder_hidden_states=neg_encoder_hidden_states, # NOTE: negative prompt
block_controlnet_hidden_states=control_blocks,
condition_mask=cond_mask,
padding_mask=padding_mask,
return_dict=False,
)[0]
# NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only
noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask)
noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg)
noise_pred_neg = self.transformer(
hidden_states=in_latents,
timestep=in_timestep,
encoder_hidden_states=neg_encoder_hidden_states, # NOTE: negative prompt
block_controlnet_hidden_states=control_blocks,
condition_mask=cond_mask,
padding_mask=padding_mask,
return_dict=False,
)[0]
# NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only
noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask)
noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg)
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
# call the callback, if provided
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
if i == total_steps - 1 or ((i + 1) % self.scheduler.order == 0):
progress_bar.update()
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
video_chunks.append(decode_latents(latents).detach().cpu())
latent_chunks.append(latents.detach().cpu())
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if not output_type == "latent":
video_chunks = [
chunk[:, :, num_ar_conditional_frames:, ...] if chunk_idx != 0 else chunk
for chunk_idx, chunk in enumerate(video_chunks)
]
video = torch.cat(video_chunks, dim=2)
video = video[:, :, :num_frames_out, ...]
latents_mean = self.latents_mean.to(latents.device, latents.dtype)
latents_std = self.latents_std.to(latents.device, latents.dtype)
latents = latents * latents_std + latents_mean
video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
video = self._match_num_frames(video, num_frames)
assert self.safety_checker is not None
self.safety_checker.to(device)
@@ -995,13 +899,7 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
video = self.video_processor.postprocess_video(video, output_type=output_type)
else:
latent_T = (num_frames_out - 1) // self.vae_scale_factor_temporal + 1
latent_chunks = [
chunk[:, :, num_cond_latent_frames:, ...] if chunk_idx != 0 else chunk
for chunk_idx, chunk in enumerate(latent_chunks)
]
video = torch.cat(latent_chunks, dim=2)
video = video[:, :, :latent_T, ...]
video = latents
# Offload all models
self.maybe_free_model_hooks()
@@ -1010,3 +908,19 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
return (video,)
return CosmosPipelineOutput(frames=video)
def _match_num_frames(self, video: torch.Tensor, target_num_frames: int) -> torch.Tensor:
if target_num_frames <= 0 or video.shape[2] == target_num_frames:
return video
frames_per_latent = max(self.vae_scale_factor_temporal, 1)
video = torch.repeat_interleave(video, repeats=frames_per_latent, dim=2)
current_frames = video.shape[2]
if current_frames < target_num_frames:
pad = video[:, :, -1:, :, :].repeat(1, 1, target_num_frames - current_frames, 1, 1)
video = torch.cat([video, pad], dim=2)
elif current_frames > target_num_frames:
video = video[:, :, :target_num_frames]
return video

View File

@@ -276,7 +276,7 @@ class ZImagePipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMix
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 0
return self._guidance_scale > 1
@property
def joint_attention_kwargs(self):

View File

@@ -131,26 +131,6 @@ class CosmosControlNetModelTests(ModelTesterMixin, unittest.TestCase):
self.assertIsInstance(output[0], list)
self.assertEqual(len(output[0]), init_dict["n_controlnet_blocks"])
def test_condition_mask_changes_output(self):
"""Test that condition mask affects control outputs."""
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
inputs_no_mask = dict(inputs_dict)
inputs_no_mask["condition_mask"] = torch.zeros_like(inputs_dict["condition_mask"])
with torch.no_grad():
output_no_mask = model(**inputs_no_mask)
output_with_mask = model(**inputs_dict)
self.assertEqual(len(output_no_mask.control_block_samples), len(output_with_mask.control_block_samples))
for no_mask_tensor, with_mask_tensor in zip(
output_no_mask.control_block_samples, output_with_mask.control_block_samples
):
self.assertFalse(torch.allclose(no_mask_tensor, with_mask_tensor))
def test_conditioning_scale_single(self):
"""Test that a single conditioning scale is broadcast to all blocks."""
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

View File

@@ -60,12 +60,7 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
model.eval()
# Move inputs to device
inputs_on_device = {}
for key, value in inputs_dict.items():
if isinstance(value, torch.Tensor):
inputs_on_device[key] = value.to(device)
else:
inputs_on_device[key] = value
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
# Enable context parallelism
cp_config = ContextParallelConfig(**cp_dict)
@@ -89,6 +84,59 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
dist.destroy_process_group()
def _custom_mesh_worker(
rank,
world_size,
master_port,
model_class,
init_dict,
cp_dict,
mesh_shape,
mesh_dim_names,
inputs_dict,
return_dict,
):
"""Worker function for context parallel testing with a user-provided custom DeviceMesh."""
try:
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(master_port)
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
device = torch.device(f"cuda:{rank}")
model = model_class(**init_dict)
model.to(device)
model.eval()
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
# DeviceMesh must be created after init_process_group, inside each worker process.
mesh = torch.distributed.device_mesh.init_device_mesh(
"cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
)
cp_config = ContextParallelConfig(**cp_dict, mesh=mesh)
model.enable_parallelism(config=cp_config)
with torch.no_grad():
output = model(**inputs_on_device, return_dict=False)[0]
if rank == 0:
return_dict["status"] = "success"
return_dict["output_shape"] = list(output.shape)
except Exception as e:
if rank == 0:
return_dict["status"] = "error"
return_dict["error"] = str(e)
finally:
if dist.is_initialized():
dist.destroy_process_group()
@is_context_parallel
@require_torch_multi_accelerator
class ContextParallelTesterMixin:
@@ -126,3 +174,48 @@ class ContextParallelTesterMixin:
assert return_dict.get("status") == "success", (
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)
@pytest.mark.parametrize(
"cp_type,mesh_shape,mesh_dim_names",
[
("ring_degree", (2, 1, 1), ("ring", "ulysses", "fsdp")),
("ulysses_degree", (1, 2, 1), ("ring", "ulysses", "fsdp")),
],
ids=["ring-3d-fsdp", "ulysses-3d-fsdp"],
)
def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names):
if not torch.distributed.is_available():
pytest.skip("torch.distributed is not available.")
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
world_size = 2
init_dict = self.get_init_dict()
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()}
cp_dict = {cp_type: world_size}
master_port = _find_free_port()
manager = mp.Manager()
return_dict = manager.dict()
mp.spawn(
_custom_mesh_worker,
args=(
world_size,
master_port,
self.model_class,
init_dict,
cp_dict,
mesh_shape,
mesh_dim_names,
inputs_dict,
return_dict,
),
nprocs=world_size,
join=True,
)
assert return_dict.get("status") == "success", (
f"Custom mesh context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)

View File

@@ -699,27 +699,3 @@ class TestLoadComponentsSkipBehavior:
# Verify test_component was not loaded
assert not hasattr(pipe, "test_component") or pipe.test_component is None
class TestModularPipelineInitFallback:
"""Test that ModularPipeline.__init__ falls back to default_blocks_name when
_blocks_class_name is a base class (e.g. SequentialPipelineBlocks saved by from_blocks_dict)."""
def test_init_fallback_when_blocks_class_name_is_base_class(self, tmp_path):
# 1. Load pipeline and get a workflow (returns a base SequentialPipelineBlocks)
pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
t2i_blocks = pipe.blocks.get_workflow("text2image")
assert t2i_blocks.__class__.__name__ == "SequentialPipelineBlocks"
# 2. Use init_pipeline to create a new pipeline from the workflow blocks
t2i_pipe = t2i_blocks.init_pipeline("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
# 3. Save and reload — the saved config will have _blocks_class_name="SequentialPipelineBlocks"
save_dir = str(tmp_path / "pipeline")
t2i_pipe.save_pretrained(save_dir)
loaded_pipe = ModularPipeline.from_pretrained(save_dir)
# 4. Verify it fell back to default_blocks_name and has correct blocks
assert loaded_pipe.__class__.__name__ == pipe.__class__.__name__
assert loaded_pipe._blocks.__class__.__name__ == pipe._blocks.__class__.__name__
assert len(loaded_pipe._blocks.sub_blocks) == len(pipe._blocks.sub_blocks)

View File

@@ -55,7 +55,7 @@ class Cosmos2_5_TransferWrapper(Cosmos2_5_TransferPipeline):
class Cosmos2_5_TransferPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = Cosmos2_5_TransferWrapper
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"controls"})
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = frozenset(
@@ -176,19 +176,15 @@ class Cosmos2_5_TransferPipelineFastTests(PipelineTesterMixin, unittest.TestCase
else:
generator = torch.Generator(device=device).manual_seed(seed)
controls_generator = torch.Generator(device="cpu").manual_seed(seed)
inputs = {
"prompt": "dance monkey",
"negative_prompt": "bad quality",
"controls": [torch.randn(3, 32, 32, generator=controls_generator) for _ in range(5)],
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 3.0,
"height": 32,
"width": 32,
"num_frames": 3,
"num_frames_per_chunk": 16,
"max_sequence_length": 16,
"output_type": "pt",
}
@@ -216,56 +212,6 @@ class Cosmos2_5_TransferPipelineFastTests(PipelineTesterMixin, unittest.TestCase
self.assertEqual(generated_video.shape, (3, 3, 32, 32))
self.assertTrue(torch.isfinite(generated_video).all())
def test_inference_autoregressive_multi_chunk(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["num_frames"] = 5
inputs["num_frames_per_chunk"] = 3
inputs["num_ar_conditional_frames"] = 1
video = pipe(**inputs).frames
generated_video = video[0]
self.assertEqual(generated_video.shape, (5, 3, 32, 32))
self.assertTrue(torch.isfinite(generated_video).all())
def test_inference_autoregressive_multi_chunk_no_condition_frames(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["num_frames"] = 5
inputs["num_frames_per_chunk"] = 3
inputs["num_ar_conditional_frames"] = 0
video = pipe(**inputs).frames
generated_video = video[0]
self.assertEqual(generated_video.shape, (5, 3, 32, 32))
self.assertTrue(torch.isfinite(generated_video).all())
def test_num_frames_per_chunk_above_rope_raises(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["num_frames_per_chunk"] = 17
with self.assertRaisesRegex(ValueError, "too large for RoPE setting"):
pipe(**inputs)
def test_inference_with_controls(self):
"""Test inference with control inputs (ControlNet)."""
device = "cpu"
@@ -276,13 +222,13 @@ class Cosmos2_5_TransferPipelineFastTests(PipelineTesterMixin, unittest.TestCase
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["controls"] = [torch.randn(3, 32, 32) for _ in range(5)] # list of 5 frames (C, H, W)
# Add control video input - should be a video tensor
inputs["controls"] = [torch.randn(3, 3, 32, 32)] # num_frames, channels, height, width
inputs["controls_conditioning_scale"] = 1.0
inputs["num_frames"] = None
video = pipe(**inputs).frames
generated_video = video[0]
self.assertEqual(generated_video.shape, (5, 3, 32, 32))
self.assertEqual(generated_video.shape, (3, 3, 32, 32))
self.assertTrue(torch.isfinite(generated_video).all())
def test_callback_inputs(self):