Compare commits

..

1 Commits

Author SHA1 Message Date
Sayak Paul
022ac4ddf6 Fix torchrun command argument order in docs 2026-02-24 16:10:34 +05:30
14 changed files with 411 additions and 771 deletions

View File

@@ -46,20 +46,6 @@ output = pipe(
output.save("output.png")
```
## Cosmos2_5_TransferPipeline
[[autodoc]] Cosmos2_5_TransferPipeline
- all
- __call__
## Cosmos2_5_PredictBasePipeline
[[autodoc]] Cosmos2_5_PredictBasePipeline
- all
- __call__
## CosmosTextToWorldPipeline
[[autodoc]] CosmosTextToWorldPipeline
@@ -84,6 +70,12 @@ output.save("output.png")
- all
- __call__
## Cosmos2_5_PredictBasePipeline
[[autodoc]] Cosmos2_5_PredictBasePipeline
- all
- __call__
## CosmosPipelineOutput
[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput

View File

@@ -94,15 +94,9 @@ python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/depth/pipeline \
--output_path converted/transfer/2b/general/depth \
--save_pipeline
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/depth/models
# edge
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/edge/61f5694b-0ad5-4ecd-8ad7-c8545627d125_ema_bf16.pt
@@ -126,15 +120,9 @@ python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/blur/pipeline \
--output_path converted/transfer/2b/general/blur \
--save_pipeline
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/blur/models
# seg
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/seg/5136ef49-6d8d-42e8-8abf-7dac722a304a_ema_bf16.pt
@@ -142,14 +130,8 @@ python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/seg/pipeline \
--output_path converted/transfer/2b/general/seg \
--save_pipeline
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/seg/models
```
"""

View File

@@ -329,11 +329,7 @@ class _HubKernelConfig:
_HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
# TODO: temporary revision for now. Remove when merged upstream into `main`.
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn3",
function_attr="flash_attn_func",
revision="fake-ops-return-probs",
wrapped_forward_attr="flash_attn_interface._flash_attn_forward",
wrapped_backward_attr="flash_attn_interface._flash_attn_backward",
repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
),
AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn3",
@@ -733,7 +729,7 @@ def _wrapped_flash_attn_3(
) -> tuple[torch.Tensor, torch.Tensor]:
# Hardcoded for now because pytorch does not support tuple/int type hints
window_size = (-1, -1)
result = flash_attn_3_func(
out, lse, *_ = flash_attn_3_func(
q=q,
k=k,
v=v,
@@ -750,9 +746,7 @@ def _wrapped_flash_attn_3(
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=True,
)
out, lse, *_ = result
lse = lse.permute(0, 2, 1)
return out, lse
@@ -1296,62 +1290,36 @@ def _flash_attention_3_hub_forward_op(
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 hub kernels.")
config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB]
wrapped_forward_fn = config.wrapped_forward_fn
if wrapped_forward_fn is None:
raise RuntimeError(
"Flash attention 3 hub kernels must expose `flash_attn_interface._flash_attn_forward` "
"for context parallel execution."
)
if scale is None:
scale = query.shape[-1] ** (-0.5)
out, softmax_lse, *_ = wrapped_forward_fn(
query,
key,
value,
None,
None, # k_new, v_new
None, # qv
None, # out
None,
None,
None, # cu_seqlens_q/k/k_new
None,
None, # seqused_q/k
None,
None, # max_seqlen_q/k
None,
None,
None, # page_table, kv_batch_idx, leftpad_k
None,
None,
None, # rotary_cos/sin, seqlens_rotary
None,
None,
None, # q_descale, k_descale, v_descale
scale,
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
window_size_left=window_size[0],
window_size_right=window_size[1],
attention_chunk=0,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=num_splits,
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=return_lse,
)
lse = softmax_lse.permute(0, 2, 1).contiguous() if return_lse else None
lse = None
if return_lse:
out, lse = out
lse = lse.permute(0, 2, 1).contiguous()
if _save_ctx:
ctx.save_for_backward(query, key, value, out, softmax_lse)
ctx.save_for_backward(query, key, value)
ctx.scale = scale
ctx.is_causal = is_causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.deterministic = deterministic
ctx.sm_margin = sm_margin
ctx._hub_kernel = func
return (out, lse) if return_lse else out
@@ -1360,49 +1328,54 @@ def _flash_attention_3_hub_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
**kwargs,
window_size: tuple[int, int] = (-1, -1),
softcap: float = 0.0,
num_splits: int = 1,
pack_gqa: bool | None = None,
deterministic: bool = False,
sm_margin: int = 0,
):
config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB]
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention 3 hub kernels must expose `flash_attn_interface._flash_attn_backward` "
"for context parallel execution."
query, key, value = ctx.saved_tensors
kernel_fn = ctx._hub_kernel
# NOTE: Unlike the FA2 hub kernel, the FA3 hub kernel does not expose separate wrapped forward/backward
# primitives (no `wrapped_forward_attr`/`wrapped_backward_attr` in its `_HubKernelConfig`). We
# therefore rerun the forward pass under `torch.enable_grad()` and differentiate through it with
# `torch.autograd.grad()`. This is a second forward pass during backward; it can be avoided once
# the FA3 hub exposes a dedicated fused backward kernel (analogous to `_wrapped_flash_attn_backward`
# in the FA2 hub), at which point this can be refactored to match `_flash_attention_hub_backward_op`.
with torch.enable_grad():
query_r = query.detach().requires_grad_(True)
key_r = key.detach().requires_grad_(True)
value_r = value.detach().requires_grad_(True)
out = kernel_fn(
q=query_r,
k=key_r,
v=value_r,
softmax_scale=ctx.scale,
causal=ctx.is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=num_splits,
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=False,
)
if isinstance(out, tuple):
out = out[0]
query, key, value, out, softmax_lse = ctx.saved_tensors
grad_query = torch.empty_like(query)
grad_key = torch.empty_like(key)
grad_value = torch.empty_like(value)
wrapped_backward_fn(
grad_out,
query,
key,
value,
out,
softmax_lse,
None,
None, # cu_seqlens_q, cu_seqlens_k
None,
None, # seqused_q, seqused_k
None,
None, # max_seqlen_q, max_seqlen_k
grad_query,
grad_key,
grad_value,
ctx.scale,
ctx.is_causal,
ctx.window_size[0],
ctx.window_size[1],
ctx.softcap,
ctx.deterministic,
ctx.sm_margin,
)
grad_query = grad_query[..., : grad_out.shape[-1]]
grad_key = grad_key[..., : grad_out.shape[-1]]
grad_value = grad_value[..., : grad_out.shape[-1]]
grad_query, grad_key, grad_value = torch.autograd.grad(
out,
(query_r, key_r, value_r),
grad_out,
retain_graph=False,
allow_unused=False,
)
return grad_query, grad_key, grad_value
@@ -2703,7 +2676,7 @@ def _flash_varlen_attention_3(
key_packed = torch.cat(key_valid, dim=0)
value_packed = torch.cat(value_valid, dim=0)
result = flash_attn_3_varlen_func(
out, lse, *_ = flash_attn_3_varlen_func(
q=query_packed,
k=key_packed,
v=value_packed,
@@ -2713,13 +2686,7 @@ def _flash_varlen_attention_3(
max_seqlen_k=max_seqlen_k,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if isinstance(result, tuple):
out, lse, *_ = result
else:
out = result
lse = None
out = out.unflatten(0, (batch_size, -1))
return (out, lse) if return_lse else out

View File

@@ -191,12 +191,7 @@ class CosmosControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
dim=1,
)
if condition_mask is not None:
control_hidden_states = torch.cat([control_hidden_states, condition_mask], dim=1)
else:
control_hidden_states = torch.cat(
[control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1
)
control_hidden_states = torch.cat([control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1)
padding_mask_resized = transforms.functional.resize(
padding_mask, list(control_hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST

View File

@@ -1836,7 +1836,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
create_pr = kwargs.pop("create_pr", False)
token = kwargs.pop("token", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
update_model_card = kwargs.pop("update_model_card", False)
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
# Generate modular pipeline card content
@@ -1849,7 +1848,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
is_pipeline=True,
model_description=MODULAR_MODEL_CARD_TEMPLATE.format(**card_content),
is_modular=True,
update_model_card=update_model_card,
)
model_card = populate_model_card(model_card, tags=card_content["tags"])

View File

@@ -50,7 +50,11 @@ This modular pipeline is composed of the following blocks:
{components_description} {configs_section}
{io_specification_section}
## Input/Output Specification
### Inputs {inputs_description}
### Outputs {outputs_description}
"""
@@ -795,46 +799,6 @@ def format_output_params(output_params, indent_level=4, max_line_length=115):
return format_params(output_params, "Outputs", indent_level, max_line_length)
def format_params_markdown(params, header="Inputs"):
"""Format a list of InputParam or OutputParam objects as a markdown bullet-point list.
Suitable for model cards rendered on Hugging Face Hub.
Args:
params: list of InputParam or OutputParam objects to format
header: Header text (e.g. "Inputs" or "Outputs")
Returns:
A formatted markdown string, or empty string if params is empty.
"""
if not params:
return ""
def get_type_str(type_hint):
if isinstance(type_hint, UnionType) or get_origin(type_hint) is Union:
type_strs = [t.__name__ if hasattr(t, "__name__") else str(t) for t in get_args(type_hint)]
return " | ".join(type_strs)
return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint)
lines = [f"**{header}:**\n"] if header else []
for param in params:
type_str = get_type_str(param.type_hint) if param.type_hint != Any else ""
name = f"**{param.kwargs_type}" if param.name is None and param.kwargs_type is not None else param.name
param_str = f"- `{name}` (`{type_str}`"
if hasattr(param, "required") and not param.required:
param_str += ", *optional*"
if param.default is not None:
param_str += f", defaults to `{param.default}`"
param_str += ")"
desc = param.description if param.description else "No description provided"
param_str += f": {desc}"
lines.append(param_str)
return "\n".join(lines)
def format_components(components, indent_level=4, max_line_length=115, add_empty_lines=True):
"""Format a list of ComponentSpec objects into a readable string representation.
@@ -1091,7 +1055,8 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]:
- blocks_description: Detailed architecture of blocks
- components_description: List of required components
- configs_section: Configuration parameters section
- io_specification_section: Input/Output specification (per-workflow or unified)
- inputs_description: Input parameters specification
- outputs_description: Output parameters specification
- trigger_inputs_section: Conditional execution information
- tags: List of relevant tags for the model card
"""
@@ -1110,6 +1075,15 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]:
if block_desc:
blocks_desc_parts.append(f" - {block_desc}")
# add sub-blocks if any
if hasattr(block, "sub_blocks") and block.sub_blocks:
for sub_name, sub_block in block.sub_blocks.items():
sub_class = sub_block.__class__.__name__
sub_desc = sub_block.description.split("\n")[0] if getattr(sub_block, "description", "") else ""
blocks_desc_parts.append(f" - *{sub_name}*: `{sub_class}`")
if sub_desc:
blocks_desc_parts.append(f" - {sub_desc}")
blocks_description = "\n".join(blocks_desc_parts) if blocks_desc_parts else "No blocks defined."
components = getattr(blocks, "expected_components", [])
@@ -1135,76 +1109,63 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]:
if configs_description:
configs_section = f"\n\n## Configuration Parameters\n\n{configs_description}"
# Branch on whether workflows are defined
has_workflows = getattr(blocks, "_workflow_map", None) is not None
inputs = blocks.inputs
outputs = blocks.outputs
if has_workflows:
workflow_map = blocks._workflow_map
parts = []
# format inputs as markdown list
inputs_parts = []
required_inputs = [inp for inp in inputs if inp.required]
optional_inputs = [inp for inp in inputs if not inp.required]
# If blocks overrides outputs (e.g. to return just "images" instead of all intermediates),
# use that as the shared output for all workflows
blocks_outputs = blocks.outputs
blocks_intermediate = getattr(blocks, "intermediate_outputs", None)
shared_outputs = (
blocks_outputs if blocks_intermediate is not None and blocks_outputs != blocks_intermediate else None
)
if required_inputs:
inputs_parts.append("**Required:**\n")
for inp in required_inputs:
if hasattr(inp.type_hint, "__name__"):
type_str = inp.type_hint.__name__
elif inp.type_hint is not None:
type_str = str(inp.type_hint).replace("typing.", "")
else:
type_str = "Any"
desc = inp.description or "No description provided"
inputs_parts.append(f"- `{inp.name}` (`{type_str}`): {desc}")
parts.append("## Workflow Input Specification\n")
if optional_inputs:
if required_inputs:
inputs_parts.append("")
inputs_parts.append("**Optional:**\n")
for inp in optional_inputs:
if hasattr(inp.type_hint, "__name__"):
type_str = inp.type_hint.__name__
elif inp.type_hint is not None:
type_str = str(inp.type_hint).replace("typing.", "")
else:
type_str = "Any"
desc = inp.description or "No description provided"
default_str = f", default: `{inp.default}`" if inp.default is not None else ""
inputs_parts.append(f"- `{inp.name}` (`{type_str}`){default_str}: {desc}")
# Per-workflow details: show trigger inputs with full param descriptions
for wf_name, trigger_inputs in workflow_map.items():
trigger_input_names = set(trigger_inputs.keys())
try:
workflow_blocks = blocks.get_workflow(wf_name)
except Exception:
parts.append(f"<details>\n<summary><strong>{wf_name}</strong></summary>\n")
parts.append("*Could not resolve workflow blocks.*\n")
parts.append("</details>\n")
continue
inputs_description = "\n".join(inputs_parts) if inputs_parts else "No specific inputs defined."
wf_inputs = workflow_blocks.inputs
# Show only trigger inputs with full parameter descriptions
trigger_params = [p for p in wf_inputs if p.name in trigger_input_names]
# format outputs as markdown list
outputs_parts = []
for out in outputs:
if hasattr(out.type_hint, "__name__"):
type_str = out.type_hint.__name__
elif out.type_hint is not None:
type_str = str(out.type_hint).replace("typing.", "")
else:
type_str = "Any"
desc = out.description or "No description provided"
outputs_parts.append(f"- `{out.name}` (`{type_str}`): {desc}")
parts.append(f"<details>\n<summary><strong>{wf_name}</strong></summary>\n")
outputs_description = "\n".join(outputs_parts) if outputs_parts else "Standard pipeline outputs."
inputs_str = format_params_markdown(trigger_params, header=None)
parts.append(inputs_str if inputs_str else "No additional inputs required.")
parts.append("")
parts.append("</details>\n")
# Common Inputs & Outputs section (like non-workflow pipelines)
all_inputs = blocks.inputs
all_outputs = shared_outputs if shared_outputs is not None else blocks.outputs
inputs_str = format_params_markdown(all_inputs, "Inputs")
outputs_str = format_params_markdown(all_outputs, "Outputs")
inputs_description = inputs_str if inputs_str else "No specific inputs defined."
outputs_description = outputs_str if outputs_str else "Standard pipeline outputs."
parts.append(f"\n## Input/Output Specification\n\n{inputs_description}\n\n{outputs_description}")
io_specification_section = "\n".join(parts)
# Suppress trigger_inputs_section when workflows are shown (it's redundant)
trigger_inputs_section = ""
else:
# Unified I/O section (original behavior)
inputs = blocks.inputs
outputs = blocks.outputs
inputs_str = format_params_markdown(inputs, "Inputs")
outputs_str = format_params_markdown(outputs, "Outputs")
inputs_description = inputs_str if inputs_str else "No specific inputs defined."
outputs_description = outputs_str if outputs_str else "Standard pipeline outputs."
io_specification_section = f"## Input/Output Specification\n\n{inputs_description}\n\n{outputs_description}"
trigger_inputs_section = ""
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
trigger_inputs_list = sorted([t for t in blocks.trigger_inputs if t is not None])
if trigger_inputs_list:
trigger_inputs_str = ", ".join(f"`{t}`" for t in trigger_inputs_list)
trigger_inputs_section = f"""
trigger_inputs_section = ""
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
trigger_inputs_list = sorted([t for t in blocks.trigger_inputs if t is not None])
if trigger_inputs_list:
trigger_inputs_str = ", ".join(f"`{t}`" for t in trigger_inputs_list)
trigger_inputs_section = f"""
### Conditional Execution
This pipeline contains blocks that are selected at runtime based on inputs:
@@ -1217,18 +1178,7 @@ This pipeline contains blocks that are selected at runtime based on inputs:
if hasattr(blocks, "model_name") and blocks.model_name:
tags.append(blocks.model_name)
if has_workflows:
# Derive tags from workflow names
workflow_names = set(blocks._workflow_map.keys())
if any("inpainting" in wf for wf in workflow_names):
tags.append("inpainting")
if any("image2image" in wf for wf in workflow_names):
tags.append("image-to-image")
if any("controlnet" in wf for wf in workflow_names):
tags.append("controlnet")
if any("text2image" in wf for wf in workflow_names):
tags.append("text-to-image")
elif hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
triggers = blocks.trigger_inputs
if any(t in triggers for t in ["mask", "mask_image"]):
tags.append("inpainting")
@@ -1256,7 +1206,8 @@ This pipeline uses a {block_count}-block architecture that can be customized and
"blocks_description": blocks_description,
"components_description": components_description,
"configs_section": configs_section,
"io_specification_section": io_specification_section,
"inputs_description": inputs_description,
"outputs_description": outputs_description,
"trigger_inputs_section": trigger_inputs_section,
"tags": tags,
}

View File

@@ -17,6 +17,9 @@ from typing import Callable, Dict, List, Optional, Union
import numpy as np
import PIL.Image
import torch
import torchvision
import torchvision.transforms
import torchvision.transforms.functional
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
@@ -51,13 +54,11 @@ else:
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def _maybe_pad_or_trim_video(video: torch.Tensor, num_frames: int):
def _maybe_pad_video(video: torch.Tensor, num_frames: int):
n_pad_frames = num_frames - video.shape[2]
if n_pad_frames > 0:
last_frame = video[:, :, -1:, :, :]
video = torch.cat((video, last_frame.repeat(1, 1, n_pad_frames, 1, 1)), dim=2)
elif num_frames < video.shape[2]:
video = video[:, :, :num_frames, :, :]
return video
@@ -133,8 +134,8 @@ EXAMPLE_DOC_STRING = """
>>> controls = [Image.fromarray(x.numpy()) for x in controls.permute(1, 2, 3, 0)]
>>> export_to_video(controls, "edge_controlled_video_edge.mp4", fps=30)
>>> # Transfer inference with controls.
>>> video = pipe(
... video=input_video[:num_frames],
... controls=controls,
... controls_conditioning_scale=1.0,
... prompt=prompt,
@@ -148,7 +149,7 @@ EXAMPLE_DOC_STRING = """
class Cosmos2_5_TransferPipeline(DiffusionPipeline):
r"""
Pipeline for Cosmos Transfer2.5, supporting auto-regressive inference.
Pipeline for Cosmos Transfer2.5 base model.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
@@ -165,14 +166,12 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
controlnet ([`CosmosControlNetModel`]):
ControlNet used to condition generation on control inputs.
"""
model_cpu_offload_seq = "text_encoder->transformer->controlnet->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
# We mark safety_checker as optional here to get around some test failures, but it is not really optional
_optional_components = ["safety_checker"]
_optional_components = ["safety_checker", "controlnet"]
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
@@ -182,8 +181,8 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
transformer: CosmosTransformer3DModel,
vae: AutoencoderKLWan,
scheduler: UniPCMultistepScheduler,
controlnet: CosmosControlNetModel,
safety_checker: Optional[CosmosSafetyChecker] = None,
controlnet: Optional[CosmosControlNetModel],
safety_checker: CosmosSafetyChecker = None,
):
super().__init__()
@@ -385,11 +384,10 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
num_frames_in: int = 93,
num_frames_out: int = 93,
do_classifier_free_guidance: bool = True,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
num_cond_latent_frames: int = 0,
dtype: torch.dtype | None = None,
device: torch.device | None = None,
generator: torch.Generator | list[torch.Generator] | None = None,
latents: torch.Tensor | None = None,
) -> torch.Tensor:
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
@@ -404,14 +402,10 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
W = width // self.vae_scale_factor_spatial
shape = (B, C, T, H, W)
if latents is not None:
if latents.shape[1:] != shape[1:]:
raise ValueError(f"Unexpected `latents` shape, got {latents.shape}, expected {shape}.")
latents = latents.to(device=device, dtype=dtype)
else:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
if num_frames_in == 0:
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
cond_mask = torch.zeros((B, 1, T, H, W), dtype=latents.dtype, device=latents.device)
cond_indicator = torch.zeros((B, 1, T, 1, 1), dtype=latents.dtype, device=latents.device)
@@ -441,12 +435,16 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
latents_std = self.latents_std.to(device=device, dtype=dtype)
cond_latents = (cond_latents - latents_mean) / latents_std
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device=device, dtype=dtype)
padding_shape = (B, 1, T, H, W)
ones_padding = latents.new_ones(padding_shape)
zeros_padding = latents.new_zeros(padding_shape)
cond_indicator = latents.new_zeros(B, 1, latents.size(2), 1, 1)
cond_indicator[:, :, 0:num_cond_latent_frames, :, :] = 1.0
cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
return (
@@ -456,7 +454,34 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
cond_indicator,
)
# Modified from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs
def _encode_controls(
self,
controls: Optional[torch.Tensor],
height: int,
width: int,
num_frames: int,
dtype: torch.dtype,
device: torch.device,
generator: torch.Generator | list[torch.Generator] | None,
) -> Optional[torch.Tensor]:
if controls is None:
return None
control_video = self.video_processor.preprocess_video(controls, height, width)
control_video = _maybe_pad_video(control_video, num_frames)
control_video = control_video.to(device=device, dtype=self.vae.dtype)
control_latents = [
retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator=generator) for vid in control_video
]
control_latents = torch.cat(control_latents, dim=0).to(dtype)
latents_mean = self.latents_mean.to(device=device, dtype=dtype)
latents_std = self.latents_std.to(device=device, dtype=dtype)
control_latents = (control_latents - latents_mean) / latents_std
return control_latents
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs
def check_inputs(
self,
prompt,
@@ -464,25 +489,9 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
width,
prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
num_ar_conditional_frames=None,
num_ar_latent_conditional_frames=None,
num_frames_per_chunk=None,
num_frames=None,
conditional_frame_timestep=0.1,
):
if width <= 0 or height <= 0 or height % 16 != 0 or width % 16 != 0:
raise ValueError(
f"`height` and `width` have to be divisible by 16 (& positive) but are {height} and {width}."
)
if num_frames is not None and num_frames <= 0:
raise ValueError(f"`num_frames` has to be a positive integer when provided but is {num_frames}.")
if conditional_frame_timestep < 0 or conditional_frame_timestep > 1:
raise ValueError(
"`conditional_frame_timestep` has to be a float in the [0, 1] interval but is "
f"{conditional_frame_timestep}."
)
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -503,46 +512,6 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if num_ar_latent_conditional_frames is not None and num_ar_conditional_frames is not None:
raise ValueError(
"Provide only one of `num_ar_conditional_frames` or `num_ar_latent_conditional_frames`, not both."
)
if num_ar_latent_conditional_frames is None and num_ar_conditional_frames is None:
raise ValueError("Provide either `num_ar_conditional_frames` or `num_ar_latent_conditional_frames`.")
if num_ar_latent_conditional_frames is not None and num_ar_latent_conditional_frames < 0:
raise ValueError("`num_ar_latent_conditional_frames` must be >= 0.")
if num_ar_conditional_frames is not None and num_ar_conditional_frames < 0:
raise ValueError("`num_ar_conditional_frames` must be >= 0.")
if num_ar_latent_conditional_frames is not None:
num_ar_conditional_frames = max(
0, (num_ar_latent_conditional_frames - 1) * self.vae_scale_factor_temporal + 1
)
min_chunk_len = self.vae_scale_factor_temporal + 1
if num_frames_per_chunk < min_chunk_len:
logger.warning(f"{num_frames_per_chunk=} must be larger than {min_chunk_len=}, setting to min_chunk_len")
num_frames_per_chunk = min_chunk_len
max_frames_by_rope = None
if getattr(self.transformer.config, "max_size", None) is not None:
max_frames_by_rope = max(
size // patch
for size, patch in zip(self.transformer.config.max_size, self.transformer.config.patch_size)
)
if num_frames_per_chunk > max_frames_by_rope:
raise ValueError(
f"{num_frames_per_chunk=} is too large for RoPE setting ({max_frames_by_rope=}). "
"Please reduce `num_frames_per_chunk`."
)
if num_ar_conditional_frames >= num_frames_per_chunk:
raise ValueError(
f"{num_ar_conditional_frames=} must be smaller than {num_frames_per_chunk=} for chunked generation."
)
return num_frames_per_chunk
@property
def guidance_scale(self):
return self._guidance_scale
@@ -567,22 +536,23 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
controls: PipelineImageInput | List[PipelineImageInput],
controls_conditioning_scale: Union[float, List[float]] = 1.0,
image: PipelineImageInput | None = None,
video: List[PipelineImageInput] | None = None,
prompt: Union[str, List[str]] | None = None,
negative_prompt: Union[str, List[str]] = DEFAULT_NEGATIVE_PROMPT,
height: int = 704,
width: Optional[int] = None,
num_frames: Optional[int] = None,
num_frames_per_chunk: int = 93,
width: int | None = None,
num_frames: int = 93,
num_inference_steps: int = 36,
guidance_scale: float = 3.0,
num_videos_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
num_videos_per_prompt: Optional[int] = 1,
generator: torch.Generator | list[torch.Generator] | None = None,
latents: torch.Tensor | None = None,
controls: Optional[PipelineImageInput | List[PipelineImageInput]] = None,
controls_conditioning_scale: float | list[float] = 1.0,
prompt_embeds: torch.Tensor | None = None,
negative_prompt_embeds: torch.Tensor | None = None,
output_type: str = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
@@ -590,26 +560,24 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
conditional_frame_timestep: float = 0.1,
num_ar_conditional_frames: Optional[int] = 1,
num_ar_latent_conditional_frames: Optional[int] = None,
):
r"""
`controls` drive the conditioning through ControlNet. Controls are assumed to be pre-processed, e.g. edge maps
are pre-computed.
The call function to the pipeline for generation. Supports three modes:
Setting `num_frames` will restrict the total number of frames output, if not provided or assigned to None
(default) then the number of output frames will match the input `controls`.
- **Text2World**: `image=None`, `video=None`, `prompt` provided. Generates a world clip.
- **Image2World**: `image` provided, `video=None`, `prompt` provided. Conditions on a single frame.
- **Video2World**: `video` provided, `image=None`, `prompt` provided. Conditions on an input clip.
Auto-regressive inference is supported and thus a sliding window of `num_frames_per_chunk` frames are used per
denoising loop. In addition, when auto-regressive inference is performed, the previous
`num_ar_latent_conditional_frames` or `num_ar_conditional_frames` are used to condition the following denoising
inference loops.
Set `num_frames=93` (default) to produce a world video, or `num_frames=1` to produce a single image frame (the
above in "*2Image mode").
Outputs follow `output_type` (e.g., `"pil"` returns a list of `num_frames` PIL images per prompt).
Args:
controls (`PipelineImageInput`, `List[PipelineImageInput]`):
Control image or video input used by the ControlNet.
controls_conditioning_scale (`float` or `List[float]`, *optional*, defaults to `1.0`):
The scale factor(s) for the ControlNet outputs. A single float is broadcast to all control blocks.
image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, *optional*):
Optional single image for Image2World conditioning. Must be `None` when `video` is provided.
video (`List[PIL.Image.Image]`, `np.ndarray`, `torch.Tensor`, *optional*):
Optional input video for Video2World conditioning. Must be `None` when `image` is provided.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide generation. Required unless `prompt_embeds` is supplied.
height (`int`, defaults to `704`):
@@ -617,10 +585,9 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
width (`int`, *optional*):
The width in pixels of the generated image. If not provided, this will be determined based on the
aspect ratio of the input and the provided height.
num_frames (`int`, *optional*):
Number of output frames. Defaults to `None` to output the same number of frames as the input
`controls`.
num_inference_steps (`int`, defaults to `36`):
num_frames (`int`, defaults to `93`):
Number of output frames. Use `93` for world (video) generation; set to `1` to return a single frame.
num_inference_steps (`int`, defaults to `35`):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to `3.0`):
@@ -634,9 +601,13 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs. Can be used to
tweak the same generation with different prompts. If not provided, a latents tensor is generated by
sampling using the supplied random `generator`.
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
controls (`PipelineImageInput`, `List[PipelineImageInput]`, *optional*):
Control image or video input used by the ControlNet. If `None`, ControlNet is skipped.
controls_conditioning_scale (`float` or `List[float]`, *optional*, defaults to `1.0`):
The scale factor(s) for the ControlNet outputs. A single float is broadcast to all control blocks.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -659,18 +630,7 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
max_sequence_length (`int`, defaults to `512`):
The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If
the prompt is shorter than this length, it will be padded.
num_ar_conditional_frames (`int`, *optional*, defaults to `1`):
Number of frames to condition on subsequent inference loops in auto-regressive inference, i.e. for the
second chunk and onwards. Only used if `num_ar_latent_conditional_frames` is `None`.
This is only used when auto-regressive inference is performed, i.e. when the number of frames in
controls is > num_frames_per_chunk
num_ar_latent_conditional_frames (`int`, *optional*):
Number of latent frames to condition on subsequent inference loops in auto-regressive inference, i.e.
for the second chunk and onwards. Only used if `num_ar_conditional_frames` is `None`.
This is only used when auto-regressive inference is performed, i.e. when the number of frames in
controls is > num_frames_per_chunk
Examples:
Returns:
@@ -690,40 +650,21 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
if width is None:
frame = controls[0] if isinstance(controls, list) else controls
if isinstance(frame, list):
frame = frame[0]
if isinstance(frame, (torch.Tensor, np.ndarray)):
if frame.ndim == 5:
frame = frame[0, 0]
elif frame.ndim == 4:
frame = frame[0]
frame = image or video[0] if image or video else None
if frame is None and controls is not None:
frame = controls[0] if isinstance(controls, list) else controls
if isinstance(frame, (torch.Tensor, np.ndarray)) and len(frame.shape) == 4:
frame = controls[0]
if isinstance(frame, PIL.Image.Image):
if frame is None:
width = int((height + 16) * (1280 / 720))
elif isinstance(frame, PIL.Image.Image):
width = int((height + 16) * (frame.width / frame.height))
else:
if frame.ndim != 3:
raise ValueError("`controls` must contain 3D frames in CHW format.")
width = int((height + 16) * (frame.shape[2] / frame.shape[1])) # NOTE: assuming C H W
num_frames_per_chunk = self.check_inputs(
prompt,
height,
width,
prompt_embeds,
callback_on_step_end_tensor_inputs,
num_ar_conditional_frames,
num_ar_latent_conditional_frames,
num_frames_per_chunk,
num_frames,
conditional_frame_timestep,
)
if num_ar_latent_conditional_frames is not None:
num_cond_latent_frames = num_ar_latent_conditional_frames
num_ar_conditional_frames = max(0, (num_cond_latent_frames - 1) * self.vae_scale_factor_temporal + 1)
else:
num_cond_latent_frames = max(0, (num_ar_conditional_frames - 1) // self.vae_scale_factor_temporal + 1)
# Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)
self._guidance_scale = guidance_scale
self._current_timestep = None
@@ -768,137 +709,102 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
vae_dtype = self.vae.dtype
transformer_dtype = self.transformer.dtype
if getattr(self.transformer.config, "img_context_dim_in", None):
img_context = torch.zeros(
batch_size,
self.transformer.config.img_context_num_tokens,
self.transformer.config.img_context_dim_in,
device=prompt_embeds.device,
img_context = torch.zeros(
batch_size,
self.transformer.config.img_context_num_tokens,
self.transformer.config.img_context_dim_in,
device=prompt_embeds.device,
dtype=transformer_dtype,
)
encoder_hidden_states = (prompt_embeds, img_context)
neg_encoder_hidden_states = (negative_prompt_embeds, img_context)
num_frames_in = None
if image is not None:
if batch_size != 1:
raise ValueError(f"batch_size must be 1 for image input (given {batch_size})")
image = torchvision.transforms.functional.to_tensor(image).unsqueeze(0)
video = torch.cat([image, torch.zeros_like(image).repeat(num_frames - 1, 1, 1, 1)], dim=0)
video = video.unsqueeze(0)
num_frames_in = 1
elif video is None:
video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=torch.uint8)
num_frames_in = 0
else:
num_frames_in = len(video)
if batch_size != 1:
raise ValueError(f"batch_size must be 1 for video input (given {batch_size})")
assert video is not None
video = self.video_processor.preprocess_video(video, height, width)
# pad with last frame (for video2world)
num_frames_out = num_frames
video = _maybe_pad_video(video, num_frames_out)
assert num_frames_in <= num_frames_out, f"expected ({num_frames_in=}) <= ({num_frames_out=})"
video = video.to(device=device, dtype=vae_dtype)
num_channels_latents = self.transformer.config.in_channels - 1
latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents(
video=video,
batch_size=batch_size * num_videos_per_prompt,
num_channels_latents=num_channels_latents,
height=height,
width=width,
num_frames_in=num_frames_in,
num_frames_out=num_frames,
do_classifier_free_guidance=self.do_classifier_free_guidance,
dtype=torch.float32,
device=device,
generator=generator,
latents=latents,
)
cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep
cond_mask = cond_mask.to(transformer_dtype)
controls_latents = None
if controls is not None:
controls_latents = self._encode_controls(
controls,
height=height,
width=width,
num_frames=num_frames,
dtype=transformer_dtype,
device=device,
generator=generator,
)
if num_videos_per_prompt > 1:
img_context = img_context.repeat_interleave(num_videos_per_prompt, dim=0)
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
encoder_hidden_states = (prompt_embeds, img_context)
neg_encoder_hidden_states = (negative_prompt_embeds, img_context)
else:
encoder_hidden_states = prompt_embeds
neg_encoder_hidden_states = negative_prompt_embeds
# Denoising loop
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
self._num_timesteps = len(timesteps)
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
control_video = self.video_processor.preprocess_video(controls, height, width)
if control_video.shape[0] != batch_size:
if control_video.shape[0] == 1:
control_video = control_video.repeat(batch_size, 1, 1, 1, 1)
else:
raise ValueError(
f"Expected controls batch size {batch_size} to match prompt batch size, but got {control_video.shape[0]}."
gt_velocity = (latents - cond_latent) * cond_mask
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t.cpu().item()
# NOTE: assumes sigma(t) \in [0, 1]
sigma_t = (
torch.tensor(self.scheduler.sigmas[i].item())
.unsqueeze(0)
.to(device=device, dtype=transformer_dtype)
)
num_frames_out = control_video.shape[2]
if num_frames is not None:
num_frames_out = min(num_frames_out, num_frames)
control_video = _maybe_pad_or_trim_video(control_video, num_frames_out)
# chunk information
num_latent_frames_per_chunk = (num_frames_per_chunk - 1) // self.vae_scale_factor_temporal + 1
chunk_stride = num_frames_per_chunk - num_ar_conditional_frames
chunk_idxs = [
(start_idx, min(start_idx + num_frames_per_chunk, num_frames_out))
for start_idx in range(0, num_frames_out - num_ar_conditional_frames, chunk_stride)
]
video_chunks = []
latents_mean = self.latents_mean.to(dtype=vae_dtype, device=device)
latents_std = self.latents_std.to(dtype=vae_dtype, device=device)
def decode_latents(latents):
latents = latents * latents_std + latents_mean
video = self.vae.decode(latents.to(dtype=self.vae.dtype, device=device), return_dict=False)[0]
return video
latents_arg = latents
initial_num_cond_latent_frames = 0
latent_chunks = []
num_chunks = len(chunk_idxs)
total_steps = num_inference_steps * num_chunks
with self.progress_bar(total=total_steps) as progress_bar:
for chunk_idx, (start_idx, end_idx) in enumerate(chunk_idxs):
if chunk_idx == 0:
prev_output = torch.zeros((batch_size, num_frames_per_chunk, 3, height, width), dtype=vae_dtype)
prev_output = self.video_processor.preprocess_video(prev_output, height, width)
else:
prev_output = video_chunks[-1].clone()
if num_ar_conditional_frames > 0:
prev_output[:, :, :num_ar_conditional_frames] = prev_output[:, :, -num_ar_conditional_frames:]
prev_output[:, :, num_ar_conditional_frames:] = -1 # -1 == 0 in processed video space
else:
prev_output.fill_(-1)
chunk_video = prev_output.to(device=device, dtype=vae_dtype)
chunk_video = _maybe_pad_or_trim_video(chunk_video, num_frames_per_chunk)
latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents(
video=chunk_video,
batch_size=batch_size * num_videos_per_prompt,
num_channels_latents=self.transformer.config.in_channels - 1,
height=height,
width=width,
num_frames_in=chunk_video.shape[2],
num_frames_out=num_frames_per_chunk,
do_classifier_free_guidance=self.do_classifier_free_guidance,
dtype=torch.float32,
device=device,
generator=generator,
num_cond_latent_frames=initial_num_cond_latent_frames
if chunk_idx == 0
else num_cond_latent_frames,
latents=latents_arg,
)
cond_mask = cond_mask.to(transformer_dtype)
cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
chunk_control_video = control_video[:, :, start_idx:end_idx, ...].to(
device=device, dtype=self.vae.dtype
)
chunk_control_video = _maybe_pad_or_trim_video(chunk_control_video, num_frames_per_chunk)
if isinstance(generator, list):
controls_latents = [
retrieve_latents(self.vae.encode(chunk_control_video[i].unsqueeze(0)), generator=generator[i])
for i in range(chunk_control_video.shape[0])
]
else:
controls_latents = [
retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator=generator)
for vid in chunk_control_video
]
controls_latents = torch.cat(controls_latents, dim=0).to(transformer_dtype)
controls_latents = (controls_latents - latents_mean) / latents_std
# Denoising loop
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
self._num_timesteps = len(timesteps)
gt_velocity = (latents - cond_latent) * cond_mask
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t.cpu().item()
# NOTE: assumes sigma(t) \in [0, 1]
sigma_t = (
torch.tensor(self.scheduler.sigmas[i].item())
.unsqueeze(0)
.to(device=device, dtype=transformer_dtype)
)
in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents
in_latents = in_latents.to(transformer_dtype)
in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t
in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents
in_latents = in_latents.to(transformer_dtype)
in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t
control_blocks = None
if controls_latents is not None and self.controlnet is not None:
control_output = self.controlnet(
controls_latents=controls_latents,
latents=in_latents,
@@ -911,18 +817,20 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
)
control_blocks = control_output[0]
noise_pred = self.transformer(
hidden_states=in_latents,
timestep=in_timestep,
encoder_hidden_states=encoder_hidden_states,
block_controlnet_hidden_states=control_blocks,
condition_mask=cond_mask,
padding_mask=padding_mask,
return_dict=False,
)[0]
noise_pred = gt_velocity + noise_pred * (1 - cond_mask)
noise_pred = self.transformer(
hidden_states=in_latents,
timestep=in_timestep,
encoder_hidden_states=encoder_hidden_states,
block_controlnet_hidden_states=control_blocks,
condition_mask=cond_mask,
padding_mask=padding_mask,
return_dict=False,
)[0]
noise_pred = gt_velocity + noise_pred * (1 - cond_mask)
if self.do_classifier_free_guidance:
if self.do_classifier_free_guidance:
control_blocks = None
if controls_latents is not None and self.controlnet is not None:
control_output = self.controlnet(
controls_latents=controls_latents,
latents=in_latents,
@@ -935,50 +843,46 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
)
control_blocks = control_output[0]
noise_pred_neg = self.transformer(
hidden_states=in_latents,
timestep=in_timestep,
encoder_hidden_states=neg_encoder_hidden_states, # NOTE: negative prompt
block_controlnet_hidden_states=control_blocks,
condition_mask=cond_mask,
padding_mask=padding_mask,
return_dict=False,
)[0]
# NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only
noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask)
noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg)
noise_pred_neg = self.transformer(
hidden_states=in_latents,
timestep=in_timestep,
encoder_hidden_states=neg_encoder_hidden_states, # NOTE: negative prompt
block_controlnet_hidden_states=control_blocks,
condition_mask=cond_mask,
padding_mask=padding_mask,
return_dict=False,
)[0]
# NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only
noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask)
noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg)
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
# call the callback, if provided
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
if i == total_steps - 1 or ((i + 1) % self.scheduler.order == 0):
progress_bar.update()
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
video_chunks.append(decode_latents(latents).detach().cpu())
latent_chunks.append(latents.detach().cpu())
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if not output_type == "latent":
video_chunks = [
chunk[:, :, num_ar_conditional_frames:, ...] if chunk_idx != 0 else chunk
for chunk_idx, chunk in enumerate(video_chunks)
]
video = torch.cat(video_chunks, dim=2)
video = video[:, :, :num_frames_out, ...]
latents_mean = self.latents_mean.to(latents.device, latents.dtype)
latents_std = self.latents_std.to(latents.device, latents.dtype)
latents = latents * latents_std + latents_mean
video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
video = self._match_num_frames(video, num_frames)
assert self.safety_checker is not None
self.safety_checker.to(device)
@@ -995,13 +899,7 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
video = self.video_processor.postprocess_video(video, output_type=output_type)
else:
latent_T = (num_frames_out - 1) // self.vae_scale_factor_temporal + 1
latent_chunks = [
chunk[:, :, num_cond_latent_frames:, ...] if chunk_idx != 0 else chunk
for chunk_idx, chunk in enumerate(latent_chunks)
]
video = torch.cat(latent_chunks, dim=2)
video = video[:, :, :latent_T, ...]
video = latents
# Offload all models
self.maybe_free_model_hooks()
@@ -1010,3 +908,19 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
return (video,)
return CosmosPipelineOutput(frames=video)
def _match_num_frames(self, video: torch.Tensor, target_num_frames: int) -> torch.Tensor:
if target_num_frames <= 0 or video.shape[2] == target_num_frames:
return video
frames_per_latent = max(self.vae_scale_factor_temporal, 1)
video = torch.repeat_interleave(video, repeats=frames_per_latent, dim=2)
current_frames = video.shape[2]
if current_frames < target_num_frames:
pad = video[:, :, -1:, :, :].repeat(1, 1, target_num_frames - current_frames, 1, 1)
video = torch.cat([video, pad], dim=2)
elif current_frames > target_num_frames:
video = video[:, :, :target_num_frames]
return video

View File

@@ -699,13 +699,9 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
mask_shape = (batch_size, 1, num_frames, height, width)
if latents is not None:
conditioning_mask = latents.new_zeros(mask_shape)
conditioning_mask[:, :, 0] = 1.0
if latents.ndim == 5:
# conditioning_mask needs to the same shape as latents in two stages generation.
batch_size, _, num_frames, height, width = latents.shape
mask_shape = (batch_size, 1, num_frames, height, width)
conditioning_mask = latents.new_zeros(mask_shape)
conditioning_mask[:, :, 0] = 1.0
latents = self._normalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
@@ -714,9 +710,6 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
latents = self._pack_latents(
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
)
else:
conditioning_mask = latents.new_zeros(mask_shape)
conditioning_mask[:, :, 0] = 1.0
conditioning_mask = self._pack_latents(
conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
).squeeze(-1)

View File

@@ -276,7 +276,7 @@ class ZImagePipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMix
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 0
return self._guidance_scale > 1
@property
def joint_attention_kwargs(self):

View File

@@ -107,7 +107,6 @@ def load_or_create_model_card(
widget: list[dict] | None = None,
inference: bool | None = None,
is_modular: bool = False,
update_model_card: bool = False,
) -> ModelCard:
"""
Loads or creates a model card.
@@ -134,9 +133,6 @@ def load_or_create_model_card(
`load_or_create_model_card` from a training script.
is_modular: (`bool`, optional): Boolean flag to denote if the model card is for a modular pipeline.
When True, uses model_description as-is without additional template formatting.
update_model_card: (`bool`, optional): When True, regenerates the model card content even if one
already exists on the remote repo. Existing card metadata (tags, license, etc.) is preserved. Only
supported for modular pipelines (i.e., `is_modular=True`).
"""
if not is_jinja_available():
raise ValueError(
@@ -145,17 +141,9 @@ def load_or_create_model_card(
" To install it, please run `pip install Jinja2`."
)
if update_model_card and not is_modular:
raise ValueError("`update_model_card=True` is only supported for modular pipelines (`is_modular=True`).")
try:
# Check if the model card is present on the remote repo
model_card = ModelCard.load(repo_id_or_path, token=token)
# For modular pipelines, regenerate card content when requested (preserve existing metadata)
if update_model_card and is_modular and model_description is not None:
existing_data = model_card.data
model_card = ModelCard(model_description)
model_card.data = existing_data
except (EntryNotFoundError, RepositoryNotFoundError):
# Otherwise create a model card from template
if from_training:

View File

@@ -131,26 +131,6 @@ class CosmosControlNetModelTests(ModelTesterMixin, unittest.TestCase):
self.assertIsInstance(output[0], list)
self.assertEqual(len(output[0]), init_dict["n_controlnet_blocks"])
def test_condition_mask_changes_output(self):
"""Test that condition mask affects control outputs."""
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
inputs_no_mask = dict(inputs_dict)
inputs_no_mask["condition_mask"] = torch.zeros_like(inputs_dict["condition_mask"])
with torch.no_grad():
output_no_mask = model(**inputs_no_mask)
output_with_mask = model(**inputs_dict)
self.assertEqual(len(output_no_mask.control_block_samples), len(output_with_mask.control_block_samples))
for no_mask_tensor, with_mask_tensor in zip(
output_no_mask.control_block_samples, output_with_mask.control_block_samples
):
self.assertFalse(torch.allclose(no_mask_tensor, with_mask_tensor))
def test_conditioning_scale_single(self):
"""Test that a single conditioning scale is broadcast to all blocks."""
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

View File

@@ -454,7 +454,8 @@ class TestModularModelCardContent:
"blocks_description",
"components_description",
"configs_section",
"io_specification_section",
"inputs_description",
"outputs_description",
"trigger_inputs_section",
"tags",
]
@@ -551,19 +552,18 @@ class TestModularModelCardContent:
blocks = self.create_mock_blocks(inputs=inputs)
content = generate_modular_model_card_content(blocks)
io_section = content["io_specification_section"]
assert "**Inputs:**" in io_section
assert "prompt" in io_section
assert "num_steps" in io_section
assert "*optional*" in io_section
assert "defaults to `50`" in io_section
assert "**Required:**" in content["inputs_description"]
assert "**Optional:**" in content["inputs_description"]
assert "prompt" in content["inputs_description"]
assert "num_steps" in content["inputs_description"]
assert "default: `50`" in content["inputs_description"]
def test_inputs_description_empty(self):
"""Test handling of pipelines without specific inputs."""
blocks = self.create_mock_blocks(inputs=[])
content = generate_modular_model_card_content(blocks)
assert "No specific inputs defined" in content["io_specification_section"]
assert "No specific inputs defined" in content["inputs_description"]
def test_outputs_description_formatting(self):
"""Test that outputs are correctly formatted."""
@@ -573,16 +573,15 @@ class TestModularModelCardContent:
blocks = self.create_mock_blocks(outputs=outputs)
content = generate_modular_model_card_content(blocks)
io_section = content["io_specification_section"]
assert "images" in io_section
assert "Generated images" in io_section
assert "images" in content["outputs_description"]
assert "Generated images" in content["outputs_description"]
def test_outputs_description_empty(self):
"""Test handling of pipelines without specific outputs."""
blocks = self.create_mock_blocks(outputs=[])
content = generate_modular_model_card_content(blocks)
assert "Standard pipeline outputs" in content["io_specification_section"]
assert "Standard pipeline outputs" in content["outputs_description"]
def test_trigger_inputs_section_with_triggers(self):
"""Test that trigger inputs section is generated when present."""

View File

@@ -55,7 +55,7 @@ class Cosmos2_5_TransferWrapper(Cosmos2_5_TransferPipeline):
class Cosmos2_5_TransferPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = Cosmos2_5_TransferWrapper
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"controls"})
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = frozenset(
@@ -176,19 +176,15 @@ class Cosmos2_5_TransferPipelineFastTests(PipelineTesterMixin, unittest.TestCase
else:
generator = torch.Generator(device=device).manual_seed(seed)
controls_generator = torch.Generator(device="cpu").manual_seed(seed)
inputs = {
"prompt": "dance monkey",
"negative_prompt": "bad quality",
"controls": [torch.randn(3, 32, 32, generator=controls_generator) for _ in range(5)],
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 3.0,
"height": 32,
"width": 32,
"num_frames": 3,
"num_frames_per_chunk": 16,
"max_sequence_length": 16,
"output_type": "pt",
}
@@ -216,56 +212,6 @@ class Cosmos2_5_TransferPipelineFastTests(PipelineTesterMixin, unittest.TestCase
self.assertEqual(generated_video.shape, (3, 3, 32, 32))
self.assertTrue(torch.isfinite(generated_video).all())
def test_inference_autoregressive_multi_chunk(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["num_frames"] = 5
inputs["num_frames_per_chunk"] = 3
inputs["num_ar_conditional_frames"] = 1
video = pipe(**inputs).frames
generated_video = video[0]
self.assertEqual(generated_video.shape, (5, 3, 32, 32))
self.assertTrue(torch.isfinite(generated_video).all())
def test_inference_autoregressive_multi_chunk_no_condition_frames(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["num_frames"] = 5
inputs["num_frames_per_chunk"] = 3
inputs["num_ar_conditional_frames"] = 0
video = pipe(**inputs).frames
generated_video = video[0]
self.assertEqual(generated_video.shape, (5, 3, 32, 32))
self.assertTrue(torch.isfinite(generated_video).all())
def test_num_frames_per_chunk_above_rope_raises(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["num_frames_per_chunk"] = 17
with self.assertRaisesRegex(ValueError, "too large for RoPE setting"):
pipe(**inputs)
def test_inference_with_controls(self):
"""Test inference with control inputs (ControlNet)."""
device = "cpu"
@@ -276,13 +222,13 @@ class Cosmos2_5_TransferPipelineFastTests(PipelineTesterMixin, unittest.TestCase
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["controls"] = [torch.randn(3, 32, 32) for _ in range(5)] # list of 5 frames (C, H, W)
# Add control video input - should be a video tensor
inputs["controls"] = [torch.randn(3, 3, 32, 32)] # num_frames, channels, height, width
inputs["controls_conditioning_scale"] = 1.0
inputs["num_frames"] = None
video = pipe(**inputs).frames
generated_video = video[0]
self.assertEqual(generated_video.shape, (5, 3, 32, 32))
self.assertEqual(generated_video.shape, (3, 3, 32, 32))
self.assertTrue(torch.isfinite(generated_video).all())
def test_callback_inputs(self):

View File

@@ -24,8 +24,7 @@ from diffusers import (
LTX2ImageToVideoPipeline,
LTX2VideoTransformer3DModel,
)
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplePipeline, LTX2TextConnectors
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2 import LTX2TextConnectors
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
from ...testing_utils import enable_full_determinism
@@ -175,15 +174,6 @@ class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
return components
def get_dummy_upsample_component(self, in_channels=4, mid_channels=32, num_blocks_per_stage=1):
upsampler = LTX2LatentUpsamplerModel(
in_channels=in_channels,
mid_channels=mid_channels,
num_blocks_per_stage=num_blocks_per_stage,
)
return upsampler
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
@@ -297,60 +287,5 @@ class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
def test_two_stages_inference_with_upsampler(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["output_type"] = "latent"
first_stage_output = pipe(**inputs)
video_latent = first_stage_output.frames
audio_latent = first_stage_output.audio
self.assertEqual(video_latent.shape, (1, 4, 3, 16, 16))
self.assertEqual(audio_latent.shape, (1, 2, 5, 2))
self.assertEqual(audio_latent.shape[1], components["vocoder"].config.out_channels)
upsampler = self.get_dummy_upsample_component(in_channels=video_latent.shape[1])
upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=upsampler)
upscaled_video_latent = upsample_pipe(latents=video_latent, output_type="latent", return_dict=False)[0]
self.assertEqual(upscaled_video_latent.shape, (1, 4, 3, 32, 32))
inputs["latents"] = upscaled_video_latent
inputs["audio_latents"] = audio_latent
inputs["output_type"] = "pt"
second_stage_output = pipe(**inputs)
video = second_stage_output.frames
audio = second_stage_output.audio
self.assertEqual(video.shape, (1, 5, 3, 64, 64))
self.assertEqual(audio.shape[0], 1)
self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels)
# fmt: off
expected_video_slice = torch.tensor(
[
0.4497, 0.6757, 0.4219, 0.7686, 0.4525, 0.6483, 0.3969, 0.7404, 0.3541, 0.3039, 0.4592, 0.3521, 0.3665, 0.2785, 0.3336, 0.3079
]
)
expected_audio_slice = torch.tensor(
[
0.0271, 0.0492, 0.1249, 0.1126, 0.1661, 0.1060, 0.1717, 0.0944, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
]
)
# fmt: on
video = video.flatten()
audio = audio.flatten()
generated_video_slice = torch.cat([video[:8], video[-8:]])
generated_audio_slice = torch.cat([audio[:8], audio[-8:]])
assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2)