Compare commits


5 Commits

Author SHA1 Message Date
Sayak Paul
ed77a246c9 [modular] add tests for robust model loading. (#13120)
* add tests for robust model loading.

* apply review feedback.
2026-02-12 10:04:29 +05:30
Miguel Martin
a1816166a5 Cosmos Transfer2.5 inference pipeline: general/{seg, depth, blur, edge} (#13066)
* initial conversion script

* cosmos control net block

* CosmosAttention

* base model conversion

* wip

* pipeline updates

* convert controlnet

* pipeline: working without controls

* wip

* debugging

* Almost working

* temp

* control working

* cleanup + detail on neg_encoder_hidden_states

* convert edge

* pos emb for control latents

* convert all chkpts

* resolve TODOs

* remove prints

* Docs

* add siglip image reference encoder

* Add unit tests

* controlnet: add duplicate layers

* Additional tests

* skip less

* skip less

* remove image_ref

* minor

* docs

* remove skipped test in transfer

* Don't crash process

* formatting

* revert some changes

* remove skipped test

* make style

* Address comment + fix example

* CosmosAttnProcessor2_0 revert + CosmosAttnProcessor2_5 changes

* make style

* make fix-copies
2026-02-11 18:33:09 -10:00
David El Malih
06a0f98e6e docs: improve docstring scheduling_flow_match_euler_discrete.py (#13127)
Improve docstring scheduling flow match euler discrete
2026-02-11 16:39:55 -08:00
Jared Wen
d32483913a [Fix]Allow prompt and prior_token_ids to be provided simultaneously in GlmImagePipeline (#13092)
* allow loose input

Signed-off-by: JaredforReal <w13431838023@gmail.com>

* add tests

Signed-off-by: JaredforReal <w13431838023@gmail.com>

* format test_glm_image

Signed-off-by: JaredforReal <w13431838023@gmail.com>

---------

Signed-off-by: JaredforReal <w13431838023@gmail.com>
2026-02-11 08:29:36 -10:00
David El Malih
64e2adf8f5 docs: improve docstring scheduling_edm_dpmsolver_multistep.py (#13122)
Improve docstring scheduling edm dpmsolver multistep
2026-02-11 08:59:33 -08:00
21 changed files with 2725 additions and 96 deletions

View File

@@ -78,12 +78,67 @@ python scripts/convert_cosmos_to_diffusers.py \
--save_pipeline
```
# Cosmos 2.5 Transfer
Download checkpoint
```bash
hf download nvidia/Cosmos-Transfer2.5-2B
```
Convert checkpoint
```bash
# depth
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/depth/626e6618-bfcd-4d9a-a077-1409e2ce353f_ema_bf16.pt
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/depth \
--save_pipeline
# edge
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/edge/61f5694b-0ad5-4ecd-8ad7-c8545627d125_ema_bf16.pt
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/edge/pipeline \
--save_pipeline
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/edge/models
# blur
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/blur/ba2f44f2-c726-4fe7-949f-597069d9b91c_ema_bf16.pt
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/blur \
--save_pipeline
# seg
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/seg/5136ef49-6d8d-42e8-8abf-7dac722a304a_ema_bf16.pt
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/seg \
--save_pipeline
```
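Load the converted models (a minimal sketch; the local paths are assumptions based on the `edge` `--output_path` values above and this branch installed from source):
```python
# Minimal smoke test for the converted `edge` models; the paths are assumptions
# matching the --output_path values above, adjust them as needed.
import torch

from diffusers import CosmosControlNetModel, CosmosTransformer3DModel

transformer = CosmosTransformer3DModel.from_pretrained(
    "converted/transfer/2b/general/edge/models/transformer", torch_dtype=torch.bfloat16
)
controlnet = CosmosControlNetModel.from_pretrained(
    "converted/transfer/2b/general/edge/models/controlnet", torch_dtype=torch.bfloat16
)
print(transformer.config.num_layers, controlnet.config.n_controlnet_blocks)  # 28 and 4 for the 2B config
```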
"""
import argparse
import pathlib
import sys
from typing import Any, Dict
from typing import Any, Dict, Optional
import torch
from accelerate import init_empty_weights
@@ -95,6 +150,7 @@ from diffusers import (
AutoencoderKLWan,
Cosmos2TextToImagePipeline,
Cosmos2VideoToWorldPipeline,
CosmosControlNetModel,
CosmosTextToWorldPipeline,
CosmosTransformer3DModel,
CosmosVideoToWorldPipeline,
@@ -103,6 +159,7 @@ from diffusers import (
UniPCMultistepScheduler,
)
from diffusers.pipelines.cosmos.pipeline_cosmos2_5_predict import Cosmos2_5_PredictBasePipeline
from diffusers.pipelines.cosmos.pipeline_cosmos2_5_transfer import Cosmos2_5_TransferPipeline
def remove_keys_(key: str, state_dict: Dict[str, Any]):
@@ -356,8 +413,62 @@ TRANSFORMER_CONFIGS = {
"crossattn_proj_in_channels": 100352,
"encoder_hidden_states_channels": 1024,
},
"Cosmos-2.5-Transfer-General-2B": {
"in_channels": 16 + 1,
"out_channels": 16,
"num_attention_heads": 16,
"attention_head_dim": 128,
"num_layers": 28,
"mlp_ratio": 4.0,
"text_embed_dim": 1024,
"adaln_lora_dim": 256,
"max_size": (128, 240, 240),
"patch_size": (1, 2, 2),
"rope_scale": (1.0, 3.0, 3.0),
"concat_padding_mask": True,
"extra_pos_embed_type": None,
"use_crossattn_projection": True,
"crossattn_proj_in_channels": 100352,
"encoder_hidden_states_channels": 1024,
"controlnet_block_every_n": 7,
"img_context_dim_in": 1152,
"img_context_dim_out": 2048,
"img_context_num_tokens": 256,
},
}
CONTROLNET_CONFIGS = {
"Cosmos-2.5-Transfer-General-2B": {
"n_controlnet_blocks": 4,
"model_channels": 2048,
"in_channels": 130,
"latent_channels": 18, # (16 latent + 1 condition_mask) + 1 padding_mask = 18
"num_attention_heads": 16,
"attention_head_dim": 128,
"mlp_ratio": 4.0,
"text_embed_dim": 1024,
"adaln_lora_dim": 256,
"patch_size": (1, 2, 2),
"max_size": (128, 240, 240),
"rope_scale": (1.0, 3.0, 3.0),
"extra_pos_embed_type": None,
"img_context_dim_in": 1152,
"img_context_dim_out": 2048,
"use_crossattn_projection": True,
"crossattn_proj_in_channels": 100352,
"encoder_hidden_states_channels": 1024,
},
}
CONTROLNET_KEYS_RENAME_DICT = {
**TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0,
"blocks": "blocks",
"control_embedder.proj.1": "patch_embed.proj",
}
CONTROLNET_SPECIAL_KEYS_REMAP = {**TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0}
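For illustration, these rename dicts are applied by plain substring replacement over each checkpoint key after the `net.` prefix is stripped; a standalone sketch with a hypothetical key name:
```python
# Sketch of how CONTROLNET_KEYS_RENAME_DICT is applied; the sample key is a hypothetical example.
rename_dict = {"control_embedder.proj.1": "patch_embed.proj"}

key = "net.control_embedder.proj.1.weight"
new_key = key.removeprefix("net.")
for old, new in rename_dict.items():
    new_key = new_key.replace(old, new)
print(key, "->", new_key)  # net.control_embedder.proj.1.weight -> patch_embed.proj.weight
```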
VAE_KEYS_RENAME_DICT = {
"down.0": "down_blocks.0",
"down.1": "down_blocks.1",
@@ -447,9 +558,12 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
return state_dict
def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: bool = True):
def convert_transformer(
transformer_type: str,
state_dict: Optional[Dict[str, Any]] = None,
weights_only: bool = True,
):
PREFIX_KEY = "net."
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=weights_only))
if "Cosmos-1.0" in transformer_type:
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0
@@ -467,23 +581,29 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
config = TRANSFORMER_CONFIGS[transformer_type]
transformer = CosmosTransformer3DModel(**config)
for key in list(original_state_dict.keys()):
old2new = {}
new2old = {}
for key in list(state_dict.keys()):
new_key = key[:]
if new_key.startswith(PREFIX_KEY):
new_key = new_key.removeprefix(PREFIX_KEY)
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
print(key, "->", new_key, flush=True)
update_state_dict_(original_state_dict, key, new_key)
assert new_key not in new2old, f"new key {new_key} already mapped"
assert key not in old2new, f"old key {key} already mapped"
old2new[key] = new_key
new2old[new_key] = key
update_state_dict_(state_dict, key, new_key)
for key in list(original_state_dict.keys()):
for key in list(state_dict.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
handler_fn_inplace(key, state_dict)
expected_keys = set(transformer.state_dict().keys())
mapped_keys = set(original_state_dict.keys())
mapped_keys = set(state_dict.keys())
missing_keys = expected_keys - mapped_keys
unexpected_keys = mapped_keys - expected_keys
if missing_keys:
@@ -497,10 +617,86 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
print(k)
sys.exit(2)
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
transformer.load_state_dict(state_dict, strict=True, assign=True)
return transformer
def convert_controlnet(
transformer_type: str,
control_state_dict: Dict[str, Any],
base_state_dict: Dict[str, Any],
weights_only: bool = True,
):
"""
Convert controlnet weights.
Args:
transformer_type: The type of transformer/controlnet
control_state_dict: State dict containing controlnet-specific weights
base_state_dict: State dict containing base transformer weights (for shared modules)
weights_only: Whether to use weights_only loading
"""
if transformer_type not in CONTROLNET_CONFIGS:
raise AssertionError(f"{transformer_type} does not define a ControlNet config")
PREFIX_KEY = "net."
# Process control-specific keys
for key in list(control_state_dict.keys()):
new_key = key[:]
if new_key.startswith(PREFIX_KEY):
new_key = new_key.removeprefix(PREFIX_KEY)
for replace_key, rename_key in CONTROLNET_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_(control_state_dict, key, new_key)
for key in list(control_state_dict.keys()):
for special_key, handler_fn_inplace in CONTROLNET_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, control_state_dict)
# Copy shared weights from base transformer to controlnet
# These are the duplicated modules: patch_embed_base, time_embed, learnable_pos_embed, img_context_proj, crossattn_proj
shared_module_mappings = {
# transformer key prefix -> controlnet key prefix
"patch_embed.": "patch_embed_base.",
"time_embed.": "time_embed.",
"learnable_pos_embed.": "learnable_pos_embed.",
"img_context_proj.": "img_context_proj.",
"crossattn_proj.": "crossattn_proj.",
}
for key in list(base_state_dict.keys()):
for transformer_prefix, controlnet_prefix in shared_module_mappings.items():
if key.startswith(transformer_prefix):
controlnet_key = controlnet_prefix + key[len(transformer_prefix) :]
control_state_dict[controlnet_key] = base_state_dict[key].clone()
print(f"Copied shared weight: {key} -> {controlnet_key}", flush=True)
break
cfg = CONTROLNET_CONFIGS[transformer_type]
controlnet = CosmosControlNetModel(**cfg)
expected_keys = set(controlnet.state_dict().keys())
mapped_keys = set(control_state_dict.keys())
missing_keys = expected_keys - mapped_keys
unexpected_keys = mapped_keys - expected_keys
if missing_keys:
print(f"WARNING: missing controlnet keys ({len(missing_keys)}):", file=sys.stderr, flush=True)
for k in sorted(missing_keys):
print(k, file=sys.stderr)
sys.exit(3)
if unexpected_keys:
print(f"WARNING: unexpected controlnet keys ({len(unexpected_keys)}):", file=sys.stderr, flush=True)
for k in sorted(unexpected_keys):
print(k, file=sys.stderr)
sys.exit(4)
controlnet.load_state_dict(control_state_dict, strict=True, assign=True)
return controlnet
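For illustration, the shared-weight copy above maps transformer prefixes onto their duplicated ControlNet counterparts; a standalone sketch with a hypothetical key and a string stand-in for the tensor:
```python
# Sketch of the shared-weight duplication (patch_embed -> patch_embed_base, etc.);
# the key and the "tensor" value are hypothetical stand-ins.
shared_module_mappings = {"patch_embed.": "patch_embed_base."}

base_state_dict = {"patch_embed.proj.weight": "tensor"}
control_state_dict = {}
for key, value in base_state_dict.items():
    for transformer_prefix, controlnet_prefix in shared_module_mappings.items():
        if key.startswith(transformer_prefix):
            control_state_dict[controlnet_prefix + key[len(transformer_prefix):]] = value
            break
print(control_state_dict)  # {'patch_embed_base.proj.weight': 'tensor'}
```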
def convert_vae(vae_type: str):
model_name = VAE_CONFIGS[vae_type]["name"]
snapshot_directory = snapshot_download(model_name, repo_type="model")
@@ -586,7 +782,7 @@ def save_pipeline_cosmos_2_0(args, transformer, vae):
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
def save_pipeline_cosmos2_5(args, transformer, vae):
def save_pipeline_cosmos2_5_predict(args, transformer, vae):
text_encoder_path = args.text_encoder_path or "nvidia/Cosmos-Reason1-7B"
tokenizer_path = args.tokenizer_path or "Qwen/Qwen2.5-VL-7B-Instruct"
@@ -614,6 +810,35 @@ def save_pipeline_cosmos2_5(args, transformer, vae):
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
def save_pipeline_cosmos2_5_transfer(args, transformer, controlnet, vae):
text_encoder_path = args.text_encoder_path or "nvidia/Cosmos-Reason1-7B"
tokenizer_path = args.tokenizer_path or "Qwen/Qwen2.5-VL-7B-Instruct"
text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
text_encoder_path, torch_dtype="auto", device_map="cpu"
)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
scheduler = UniPCMultistepScheduler(
use_karras_sigmas=True,
use_flow_sigmas=True,
prediction_type="flow_prediction",
sigma_max=200.0,
sigma_min=0.01,
)
pipe = Cosmos2_5_TransferPipeline(
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
controlnet=controlnet,
vae=vae,
scheduler=scheduler,
safety_checker=lambda *args, **kwargs: None,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
@@ -642,18 +867,61 @@ if __name__ == "__main__":
args = get_args()
transformer = None
controlnet = None
dtype = DTYPE_MAPPING[args.dtype]
if args.save_pipeline:
assert args.transformer_ckpt_path is not None
assert args.vae_type is not None
raw_state_dict = None
if args.transformer_ckpt_path is not None:
weights_only = "Cosmos-1.0" in args.transformer_type
transformer = convert_transformer(args.transformer_type, args.transformer_ckpt_path, weights_only)
transformer = transformer.to(dtype=dtype)
if not args.save_pipeline:
transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
raw_state_dict = get_state_dict(
torch.load(args.transformer_ckpt_path, map_location="cpu", weights_only=weights_only)
)
if raw_state_dict is not None:
if "Transfer" in args.transformer_type:
base_state_dict = {}
control_state_dict = {}
for k, v in raw_state_dict.items():
plain_key = k.removeprefix("net.") if k.startswith("net.") else k
if "control" in plain_key.lower():
control_state_dict[k] = v
else:
base_state_dict[k] = v
assert len(base_state_dict.keys() & control_state_dict.keys()) == 0
# Convert transformer first to get the processed base state dict
transformer = convert_transformer(
args.transformer_type, state_dict=base_state_dict, weights_only=weights_only
)
transformer = transformer.to(dtype=dtype)
# Get converted transformer state dict to copy shared weights to controlnet
converted_base_state_dict = transformer.state_dict()
# Convert controlnet with both control-specific and shared weights from transformer
controlnet = convert_controlnet(
args.transformer_type, control_state_dict, converted_base_state_dict, weights_only=weights_only
)
controlnet = controlnet.to(dtype=dtype)
if not args.save_pipeline:
transformer.save_pretrained(
pathlib.Path(args.output_path) / "transformer", safe_serialization=True, max_shard_size="5GB"
)
controlnet.save_pretrained(
pathlib.Path(args.output_path) / "controlnet", safe_serialization=True, max_shard_size="5GB"
)
else:
transformer = convert_transformer(
args.transformer_type, state_dict=raw_state_dict, weights_only=weights_only
)
transformer = transformer.to(dtype=dtype)
if not args.save_pipeline:
transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
if args.vae_type is not None:
if "Cosmos-1.0" in args.transformer_type:
@@ -667,6 +935,8 @@ if __name__ == "__main__":
if not args.save_pipeline:
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
else:
vae = None
if args.save_pipeline:
if "Cosmos-1.0" in args.transformer_type:
@@ -678,6 +948,11 @@ if __name__ == "__main__":
assert args.tokenizer_path is not None
save_pipeline_cosmos_2_0(args, transformer, vae)
elif "Cosmos-2.5" in args.transformer_type:
save_pipeline_cosmos2_5(args, transformer, vae)
if "Predict" in args.transformer_type:
save_pipeline_cosmos2_5_predict(args, transformer, vae)
elif "Transfer" in args.transformer_type:
save_pipeline_cosmos2_5_transfer(args, transformer, None, vae)
else:
raise AssertionError(f"{args.transformer_type} not supported")
else:
raise AssertionError(f"{args.transformer_type} not supported")
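For Transfer checkpoints, the split in the main block sends any key containing `control` (case-insensitive) to the ControlNet and everything else to the base transformer; a standalone sketch with hypothetical key names:
```python
# Sketch of the base/control key split above; key names and values are hypothetical examples.
raw_state_dict = {
    "net.blocks.0.attn1.to_q.weight": "base tensor",
    "net.control_embedder.proj.1.weight": "control tensor",
}

base_state_dict, control_state_dict = {}, {}
for k, v in raw_state_dict.items():
    plain_key = k.removeprefix("net.")
    (control_state_dict if "control" in plain_key.lower() else base_state_dict)[k] = v
print(sorted(base_state_dict), sorted(control_state_dict))
```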

View File

@@ -221,6 +221,7 @@ else:
"ControlNetModel",
"ControlNetUnionModel",
"ControlNetXSAdapter",
"CosmosControlNetModel",
"CosmosTransformer3DModel",
"DiTTransformer2DModel",
"EasyAnimateTransformer3DModel",
@@ -485,6 +486,7 @@ else:
"CogView4Pipeline",
"ConsisIDPipeline",
"Cosmos2_5_PredictBasePipeline",
"Cosmos2_5_TransferPipeline",
"Cosmos2TextToImagePipeline",
"Cosmos2VideoToWorldPipeline",
"CosmosTextToWorldPipeline",
@@ -992,6 +994,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ControlNetModel,
ControlNetUnionModel,
ControlNetXSAdapter,
CosmosControlNetModel,
CosmosTransformer3DModel,
DiTTransformer2DModel,
EasyAnimateTransformer3DModel,
@@ -1226,6 +1229,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogView4Pipeline,
ConsisIDPipeline,
Cosmos2_5_PredictBasePipeline,
Cosmos2_5_TransferPipeline,
Cosmos2TextToImagePipeline,
Cosmos2VideoToWorldPipeline,
CosmosTextToWorldPipeline,
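A quick import check for the symbols registered above (a sketch; assumes this branch is installed from source):
```python
# The new public classes should resolve through the lazy top-level namespace.
from diffusers import Cosmos2_5_TransferPipeline, CosmosControlNetModel

print(Cosmos2_5_TransferPipeline.__name__, CosmosControlNetModel.__name__)
```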

View File

@@ -54,6 +54,7 @@ if is_torch_available():
_import_structure["autoencoders.vq_model"] = ["VQModel"]
_import_structure["cache_utils"] = ["CacheMixin"]
_import_structure["controlnets.controlnet"] = ["ControlNetModel"]
_import_structure["controlnets.controlnet_cosmos"] = ["CosmosControlNetModel"]
_import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
_import_structure["controlnets.controlnet_hunyuan"] = [
"HunyuanDiT2DControlNetModel",
@@ -175,6 +176,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ControlNetModel,
ControlNetUnionModel,
ControlNetXSAdapter,
CosmosControlNetModel,
FluxControlNetModel,
FluxMultiControlNetModel,
HunyuanDiT2DControlNetModel,

View File

@@ -3,6 +3,7 @@ from ...utils import is_flax_available, is_torch_available
if is_torch_available():
from .controlnet import ControlNetModel, ControlNetOutput
from .controlnet_cosmos import CosmosControlNetModel
from .controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel
from .controlnet_hunyuan import (
HunyuanControlNetOutput,

View File

@@ -0,0 +1,312 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import BaseOutput, is_torchvision_available, logging
from ..modeling_utils import ModelMixin
from ..transformers.transformer_cosmos import (
CosmosEmbedding,
CosmosLearnablePositionalEmbed,
CosmosPatchEmbed,
CosmosRotaryPosEmbed,
CosmosTransformerBlock,
)
if is_torchvision_available():
from torchvision import transforms
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class CosmosControlNetOutput(BaseOutput):
"""
Output of [`CosmosControlNetModel`].
Args:
control_block_samples (`list[torch.Tensor]`):
List of control block activations to be injected into transformer blocks.
"""
control_block_samples: List[torch.Tensor]
class CosmosControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
ControlNet for Cosmos Transfer2.5.
This model duplicates the shared embedding modules from the transformer (patch_embed, time_embed,
learnable_pos_embed, img_context_proj) to enable proper CPU offloading. The forward() method computes everything
internally from raw inputs.
"""
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["patch_embed", "patch_embed_base", "time_embed"]
_no_split_modules = ["CosmosTransformerBlock"]
_keep_in_fp32_modules = ["learnable_pos_embed"]
@register_to_config
def __init__(
self,
n_controlnet_blocks: int = 4,
in_channels: int = 130,
latent_channels: int = 18, # base latent channels (latents + condition_mask) + padding_mask
model_channels: int = 2048,
num_attention_heads: int = 32,
attention_head_dim: int = 128,
mlp_ratio: float = 4.0,
text_embed_dim: int = 1024,
adaln_lora_dim: int = 256,
patch_size: Tuple[int, int, int] = (1, 2, 2),
max_size: Tuple[int, int, int] = (128, 240, 240),
rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
extra_pos_embed_type: Optional[str] = None,
img_context_dim_in: Optional[int] = None,
img_context_dim_out: int = 2048,
use_crossattn_projection: bool = False,
crossattn_proj_in_channels: int = 1024,
encoder_hidden_states_channels: int = 1024,
):
super().__init__()
self.patch_embed = CosmosPatchEmbed(in_channels, model_channels, patch_size, bias=False)
self.patch_embed_base = CosmosPatchEmbed(latent_channels, model_channels, patch_size, bias=False)
self.time_embed = CosmosEmbedding(model_channels, model_channels)
self.learnable_pos_embed = None
if extra_pos_embed_type == "learnable":
self.learnable_pos_embed = CosmosLearnablePositionalEmbed(
hidden_size=model_channels,
max_size=max_size,
patch_size=patch_size,
)
self.img_context_proj = None
if img_context_dim_in is not None and img_context_dim_in > 0:
self.img_context_proj = nn.Sequential(
nn.Linear(img_context_dim_in, img_context_dim_out, bias=True),
nn.GELU(),
)
# Cross-attention projection for text embeddings (same as transformer)
self.crossattn_proj = None
if use_crossattn_projection:
self.crossattn_proj = nn.Sequential(
nn.Linear(crossattn_proj_in_channels, encoder_hidden_states_channels, bias=True),
nn.GELU(),
)
# RoPE for both control and base latents
self.rope = CosmosRotaryPosEmbed(
hidden_size=attention_head_dim, max_size=max_size, patch_size=patch_size, rope_scale=rope_scale
)
self.control_blocks = nn.ModuleList(
[
CosmosTransformerBlock(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
cross_attention_dim=text_embed_dim,
mlp_ratio=mlp_ratio,
adaln_lora_dim=adaln_lora_dim,
qk_norm="rms_norm",
out_bias=False,
img_context=img_context_dim_in is not None and img_context_dim_in > 0,
before_proj=(block_idx == 0),
after_proj=True,
)
for block_idx in range(n_controlnet_blocks)
]
)
self.gradient_checkpointing = False
def _expand_conditioning_scale(self, conditioning_scale: Union[float, List[float]]) -> List[float]:
if isinstance(conditioning_scale, list):
scales = conditioning_scale
else:
scales = [conditioning_scale] * len(self.control_blocks)
if len(scales) < len(self.control_blocks):
logger.warning(
"Received %d control scales, but control network defines %d blocks. "
"Scales will be trimmed or repeated to match.",
len(scales),
len(self.control_blocks),
)
scales = (scales * len(self.control_blocks))[: len(self.control_blocks)]
return scales
def forward(
self,
controls_latents: torch.Tensor,
latents: torch.Tensor,
timestep: torch.Tensor,
encoder_hidden_states: Union[Optional[torch.Tensor], Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
condition_mask: torch.Tensor,
conditioning_scale: Union[float, List[float]] = 1.0,
padding_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
fps: Optional[int] = None,
return_dict: bool = True,
) -> Union[CosmosControlNetOutput, Tuple[List[torch.Tensor]]]:
"""
Forward pass for the ControlNet.
Args:
controls_latents: Control signal latents [B, C, T, H, W]
latents: Base latents from the noising process [B, C, T, H, W]
timestep: Diffusion timestep tensor
encoder_hidden_states: Tuple of (text_context, img_context) or text_context
condition_mask: Conditioning mask [B, 1, T, H, W]
conditioning_scale: Scale factor(s) for control outputs
padding_mask: Padding mask [B, 1, H, W] or None
attention_mask: Optional attention mask or None
fps: Frames per second for RoPE or None
return_dict: Whether to return a CosmosControlNetOutput or a tuple
Returns:
CosmosControlNetOutput or tuple of control tensors
"""
B, C, T, H, W = controls_latents.shape
# 1. Prepare control latents
control_hidden_states = controls_latents
vace_in_channels = self.config.in_channels - 1
if control_hidden_states.shape[1] < vace_in_channels - 1:
pad_C = vace_in_channels - 1 - control_hidden_states.shape[1]
control_hidden_states = torch.cat(
[
control_hidden_states,
torch.zeros(
(B, pad_C, T, H, W), dtype=control_hidden_states.dtype, device=control_hidden_states.device
),
],
dim=1,
)
control_hidden_states = torch.cat([control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1)
padding_mask_resized = transforms.functional.resize(
padding_mask, list(control_hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
)
control_hidden_states = torch.cat(
[control_hidden_states, padding_mask_resized.unsqueeze(2).repeat(B, 1, T, 1, 1)], dim=1
)
# 2. Prepare base latents (same processing as transformer.forward)
base_hidden_states = latents
if condition_mask is not None:
base_hidden_states = torch.cat([base_hidden_states, condition_mask], dim=1)
base_padding_mask = transforms.functional.resize(
padding_mask, list(base_hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
)
base_hidden_states = torch.cat(
[base_hidden_states, base_padding_mask.unsqueeze(2).repeat(B, 1, T, 1, 1)], dim=1
)
# 3. Generate positional embeddings (shared for both)
image_rotary_emb = self.rope(control_hidden_states, fps=fps)
extra_pos_emb = self.learnable_pos_embed(control_hidden_states) if self.learnable_pos_embed else None
# 4. Patchify control latents
control_hidden_states = self.patch_embed(control_hidden_states)
control_hidden_states = control_hidden_states.flatten(1, 3)
# 5. Patchify base latents
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = T // p_t
post_patch_height = H // p_h
post_patch_width = W // p_w
base_hidden_states = self.patch_embed_base(base_hidden_states)
base_hidden_states = base_hidden_states.flatten(1, 3)
# 6. Time embeddings
if timestep.ndim == 1:
temb, embedded_timestep = self.time_embed(base_hidden_states, timestep)
elif timestep.ndim == 5:
batch_size, _, num_frames, _, _ = latents.shape
assert timestep.shape == (batch_size, 1, num_frames, 1, 1), (
f"Expected timestep to have shape [B, 1, T, 1, 1], but got {timestep.shape}"
)
timestep_flat = timestep.flatten()
temb, embedded_timestep = self.time_embed(base_hidden_states, timestep_flat)
temb, embedded_timestep = (
x.view(batch_size, post_patch_num_frames, 1, 1, -1)
.expand(-1, -1, post_patch_height, post_patch_width, -1)
.flatten(1, 3)
for x in (temb, embedded_timestep)
)
else:
raise ValueError(f"Expected timestep to have shape [B, 1, T, 1, 1] or [T], but got {timestep.shape}")
# 7. Process encoder hidden states
if isinstance(encoder_hidden_states, tuple):
text_context, img_context = encoder_hidden_states
else:
text_context = encoder_hidden_states
img_context = None
# Apply cross-attention projection to text context
if self.crossattn_proj is not None:
text_context = self.crossattn_proj(text_context)
# Apply cross-attention projection to image context (if provided)
if img_context is not None and self.img_context_proj is not None:
img_context = self.img_context_proj(img_context)
# Combine text and image context into a single tuple
if self.config.img_context_dim_in is not None and self.config.img_context_dim_in > 0:
processed_encoder_hidden_states = (text_context, img_context)
else:
processed_encoder_hidden_states = text_context
# 8. Prepare attention mask
if attention_mask is not None:
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, S]
# 9. Run control blocks
scales = self._expand_conditioning_scale(conditioning_scale)
result = []
for block_idx, (block, scale) in enumerate(zip(self.control_blocks, scales)):
if torch.is_grad_enabled() and self.gradient_checkpointing:
control_hidden_states, control_proj = self._gradient_checkpointing_func(
block,
control_hidden_states,
processed_encoder_hidden_states,
embedded_timestep,
temb,
image_rotary_emb,
extra_pos_emb,
attention_mask,
None, # controlnet_residual
base_hidden_states,
block_idx,
)
else:
control_hidden_states, control_proj = block(
hidden_states=control_hidden_states,
encoder_hidden_states=processed_encoder_hidden_states,
embedded_timestep=embedded_timestep,
temb=temb,
image_rotary_emb=image_rotary_emb,
extra_pos_emb=extra_pos_emb,
attention_mask=attention_mask,
controlnet_residual=None,
latents=base_hidden_states,
block_idx=block_idx,
)
result.append(control_proj * scale)
if not return_dict:
return (result,)
return CosmosControlNetOutput(control_block_samples=result)
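For reference, the scale expansion used above behaves like this standalone sketch (a re-implementation for illustration only, not a call into the model):
```python
# Standalone re-implementation of the conditioning-scale expansion, for illustration.
def expand_scales(conditioning_scale, n_blocks):
    scales = conditioning_scale if isinstance(conditioning_scale, list) else [conditioning_scale] * n_blocks
    if len(scales) < n_blocks:
        # a short list is repeated, then trimmed, so every control block gets a scale
        scales = (scales * n_blocks)[:n_blocks]
    return scales


print(expand_scales(1.0, 4))  # [1.0, 1.0, 1.0, 1.0]
print(expand_scales([0.5, 1.0], 4))  # [0.5, 1.0, 0.5, 1.0]
```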

View File

@@ -12,17 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import is_torchvision_available
from ..attention import FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..attention_processor import Attention
from ..embeddings import Timesteps
from ..modeling_outputs import Transformer2DModelOutput
@@ -152,7 +152,7 @@ class CosmosAdaLayerNormZero(nn.Module):
class CosmosAttnProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
raise ImportError("CosmosAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
def __call__(
@@ -191,7 +191,6 @@ class CosmosAttnProcessor2_0:
query_idx = torch.tensor(query.size(3), device=query.device)
key_idx = torch.tensor(key.size(3), device=key.device)
value_idx = torch.tensor(value.size(3), device=value.device)
else:
query_idx = query.size(3)
key_idx = key.size(3)
@@ -200,18 +199,148 @@ class CosmosAttnProcessor2_0:
value = value.repeat_interleave(query_idx // value_idx, dim=3)
# 5. Attention
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
hidden_states = dispatch_attention_fn(
query.transpose(1, 2),
key.transpose(1, 2),
value.transpose(1, 2),
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3).type_as(query)
# 6. Output projection
hidden_states = hidden_states.flatten(2, 3).type_as(query)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class CosmosAttnProcessor2_5:
def __init__(self):
if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
raise ImportError("CosmosAttnProcessor2_5 requires PyTorch 2.0. Please upgrade PyTorch to 2.0 or newer.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
attention_mask: Optional[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
image_rotary_emb=None,
) -> torch.Tensor:
if not isinstance(encoder_hidden_states, tuple):
raise ValueError("Expected encoder_hidden_states as (text_context, img_context) tuple.")
text_context, img_context = encoder_hidden_states if encoder_hidden_states else (None, None)
text_mask, img_mask = attention_mask if attention_mask else (None, None)
if text_context is None:
text_context = hidden_states
query = attn.to_q(hidden_states)
key = attn.to_k(text_context)
value = attn.to_v(text_context)
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
query = attn.norm_q(query)
key = attn.norm_k(key)
if image_rotary_emb is not None:
from ..embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
if torch.onnx.is_in_onnx_export():
query_idx = torch.tensor(query.size(3), device=query.device)
key_idx = torch.tensor(key.size(3), device=key.device)
value_idx = torch.tensor(value.size(3), device=value.device)
else:
query_idx = query.size(3)
key_idx = key.size(3)
value_idx = value.size(3)
key = key.repeat_interleave(query_idx // key_idx, dim=3)
value = value.repeat_interleave(query_idx // value_idx, dim=3)
attn_out = dispatch_attention_fn(
query.transpose(1, 2),
key.transpose(1, 2),
value.transpose(1, 2),
attn_mask=text_mask,
dropout_p=0.0,
is_causal=False,
)
attn_out = attn_out.flatten(2, 3).type_as(query)
if img_context is not None:
q_img = attn.q_img(hidden_states)
k_img = attn.k_img(img_context)
v_img = attn.v_img(img_context)
batch_size = hidden_states.shape[0]
dim_head = attn.out_dim // attn.heads
q_img = q_img.view(batch_size, -1, attn.heads, dim_head).transpose(1, 2)
k_img = k_img.view(batch_size, -1, attn.heads, dim_head).transpose(1, 2)
v_img = v_img.view(batch_size, -1, attn.heads, dim_head).transpose(1, 2)
q_img = attn.q_img_norm(q_img)
k_img = attn.k_img_norm(k_img)
q_img_idx = q_img.size(3)
k_img_idx = k_img.size(3)
v_img_idx = v_img.size(3)
k_img = k_img.repeat_interleave(q_img_idx // k_img_idx, dim=3)
v_img = v_img.repeat_interleave(q_img_idx // v_img_idx, dim=3)
img_out = dispatch_attention_fn(
q_img.transpose(1, 2),
k_img.transpose(1, 2),
v_img.transpose(1, 2),
attn_mask=img_mask,
dropout_p=0.0,
is_causal=False,
)
img_out = img_out.flatten(2, 3).type_as(q_img)
hidden_states = attn_out + img_out
else:
hidden_states = attn_out
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class CosmosAttention(Attention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# add parameters for image q/k/v
inner_dim = self.heads * self.to_q.out_features // self.heads
self.q_img = nn.Linear(self.query_dim, inner_dim, bias=False)
self.k_img = nn.Linear(self.query_dim, inner_dim, bias=False)
self.v_img = nn.Linear(self.query_dim, inner_dim, bias=False)
self.q_img_norm = RMSNorm(self.to_q.out_features // self.heads, eps=1e-6, elementwise_affine=True)
self.k_img_norm = RMSNorm(self.to_k.out_features // self.heads, eps=1e-6, elementwise_affine=True)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
attention_mask: Optional[torch.Tensor] = None,
**cross_attention_kwargs,
) -> torch.Tensor:
return super().forward(
hidden_states=hidden_states,
# NOTE: type-hint in base class can be ignored
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
class CosmosTransformerBlock(nn.Module):
def __init__(
self,
@@ -222,12 +351,16 @@ class CosmosTransformerBlock(nn.Module):
adaln_lora_dim: int = 256,
qk_norm: str = "rms_norm",
out_bias: bool = False,
img_context: bool = False,
before_proj: bool = False,
after_proj: bool = False,
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
self.norm1 = CosmosAdaLayerNormZero(in_features=hidden_size, hidden_features=adaln_lora_dim)
self.img_context = img_context
self.attn1 = Attention(
query_dim=hidden_size,
cross_attention_dim=None,
@@ -240,30 +373,58 @@ class CosmosTransformerBlock(nn.Module):
)
self.norm2 = CosmosAdaLayerNormZero(in_features=hidden_size, hidden_features=adaln_lora_dim)
self.attn2 = Attention(
query_dim=hidden_size,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
qk_norm=qk_norm,
elementwise_affine=True,
out_bias=out_bias,
processor=CosmosAttnProcessor2_0(),
)
if img_context:
self.attn2 = CosmosAttention(
query_dim=hidden_size,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
qk_norm=qk_norm,
elementwise_affine=True,
out_bias=out_bias,
processor=CosmosAttnProcessor2_5(),
)
else:
self.attn2 = Attention(
query_dim=hidden_size,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
qk_norm=qk_norm,
elementwise_affine=True,
out_bias=out_bias,
processor=CosmosAttnProcessor2_0(),
)
self.norm3 = CosmosAdaLayerNormZero(in_features=hidden_size, hidden_features=adaln_lora_dim)
self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu", bias=out_bias)
# NOTE: zero conv for CosmosControlNet
self.before_proj = None
self.after_proj = None
if before_proj:
self.before_proj = nn.Linear(hidden_size, hidden_size)
if after_proj:
self.after_proj = nn.Linear(hidden_size, hidden_size)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states: Union[
Optional[torch.Tensor], Optional[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]]
],
embedded_timestep: torch.Tensor,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
extra_pos_emb: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
controlnet_residual: Optional[torch.Tensor] = None,
latents: Optional[torch.Tensor] = None,
block_idx: Optional[int] = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if self.before_proj is not None:
hidden_states = self.before_proj(hidden_states) + latents
if extra_pos_emb is not None:
hidden_states = hidden_states + extra_pos_emb
@@ -284,6 +445,16 @@ class CosmosTransformerBlock(nn.Module):
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + gate * ff_output
if controlnet_residual is not None:
assert self.after_proj is None
# NOTE: this is assumed to be scaled by the controlnet
hidden_states += controlnet_residual
if self.after_proj is not None:
assert controlnet_residual is None
hs_proj = self.after_proj(hidden_states)
return hidden_states, hs_proj
return hidden_states
@@ -416,6 +587,17 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
Whether to concatenate the padding mask to the input latent tensors.
extra_pos_embed_type (`str`, *optional*, defaults to `learnable`):
The type of extra positional embeddings to use. Can be one of `None` or `learnable`.
controlnet_block_every_n (`int`, *optional*):
Interval between transformer blocks that receive control residuals (for example, `7` injects the residuals
into blocks 0, 7, 14, and so on). Required for Cosmos Transfer2.5.
img_context_dim_in (`int`, *optional*):
The dimension of the input image context feature vector, i.e. it is the D in [B, N, D].
img_context_num_tokens (`int`):
The number of tokens in the image context feature vector, i.e. it is the N in [B, N, D]. If
`img_context_dim_in` is not provided, then this parameter is ignored.
img_context_dim_out (`int`):
The output dimension of the image context projection layer. If `img_context_dim_in` is not provided, then
this parameter is ignored.
"""
_supports_gradient_checkpointing = True
@@ -442,6 +624,10 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
use_crossattn_projection: bool = False,
crossattn_proj_in_channels: int = 1024,
encoder_hidden_states_channels: int = 1024,
controlnet_block_every_n: Optional[int] = None,
img_context_dim_in: Optional[int] = None,
img_context_num_tokens: int = 256,
img_context_dim_out: int = 2048,
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
@@ -477,6 +663,7 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
adaln_lora_dim=adaln_lora_dim,
qk_norm="rms_norm",
out_bias=False,
img_context=self.config.img_context_dim_in is not None and self.config.img_context_dim_in > 0,
)
for _ in range(num_layers)
]
@@ -496,17 +683,24 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.gradient_checkpointing = False
if self.config.img_context_dim_in:
self.img_context_proj = nn.Sequential(
nn.Linear(self.config.img_context_dim_in, self.config.img_context_dim_out, bias=True),
nn.GELU(),
)
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
block_controlnet_hidden_states: Optional[List[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
fps: Optional[int] = None,
condition_mask: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> torch.Tensor:
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
# 1. Concatenate padding mask if needed & prepare attention mask
@@ -514,11 +708,11 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
hidden_states = torch.cat([hidden_states, condition_mask], dim=1)
if self.config.concat_padding_mask:
padding_mask = transforms.functional.resize(
padding_mask_resized = transforms.functional.resize(
padding_mask, list(hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
)
hidden_states = torch.cat(
[hidden_states, padding_mask.unsqueeze(2).repeat(batch_size, 1, num_frames, 1, 1)], dim=1
[hidden_states, padding_mask_resized.unsqueeze(2).repeat(batch_size, 1, num_frames, 1, 1)], dim=1
)
if attention_mask is not None:
@@ -554,36 +748,59 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
for x in (temb, embedded_timestep)
) # [BT, C] -> [B, T, 1, 1, C] -> [B, T, H, W, C] -> [B, THW, C]
else:
assert False
raise ValueError(f"Expected timestep to have shape [B, 1, T, 1, 1] or [T], but got {timestep.shape}")
# 5. Process encoder hidden states
text_context, img_context = (
encoder_hidden_states if isinstance(encoder_hidden_states, tuple) else (encoder_hidden_states, None)
)
if self.config.use_crossattn_projection:
encoder_hidden_states = self.crossattn_proj(encoder_hidden_states)
text_context = self.crossattn_proj(text_context)
# 5. Transformer blocks
for block in self.transformer_blocks:
if img_context is not None and self.config.img_context_dim_in:
img_context = self.img_context_proj(img_context)
processed_encoder_hidden_states = (
(text_context, img_context) if isinstance(encoder_hidden_states, tuple) else text_context
)
# 6. Build controlnet block index map
controlnet_block_index_map = {}
if block_controlnet_hidden_states is not None:
n_blocks = len(self.transformer_blocks)
controlnet_block_index_map = {
block_idx: block_controlnet_hidden_states[idx]
for idx, block_idx in list(enumerate(range(0, n_blocks, self.config.controlnet_block_every_n)))
}
# 7. Transformer blocks
for block_idx, block in enumerate(self.transformer_blocks):
controlnet_residual = controlnet_block_index_map.get(block_idx)
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
processed_encoder_hidden_states,
embedded_timestep,
temb,
image_rotary_emb,
extra_pos_emb,
attention_mask,
controlnet_residual,
)
else:
hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
embedded_timestep=embedded_timestep,
temb=temb,
image_rotary_emb=image_rotary_emb,
extra_pos_emb=extra_pos_emb,
attention_mask=attention_mask,
hidden_states,
processed_encoder_hidden_states,
embedded_timestep,
temb,
image_rotary_emb,
extra_pos_emb,
attention_mask,
controlnet_residual,
)
# 6. Output norm & projection & unpatchify
# 8. Output norm & projection & unpatchify
hidden_states = self.norm_out(hidden_states, embedded_timestep, temb)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.unflatten(2, (p_h, p_w, p_t, -1))
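Worked example of the index map built in step 6 above: with the Transfer2.5-2B configuration (`num_layers=28`, `controlnet_block_every_n=7`), the four control residuals land on blocks 0, 7, 14, and 21. A standalone sketch:
```python
# Standalone sketch of the controlnet_block_index_map construction for the 2B Transfer config.
n_blocks, every_n = 28, 7  # transformer num_layers, controlnet_block_every_n
control_samples = ["ctrl_0", "ctrl_1", "ctrl_2", "ctrl_3"]  # stand-ins for the 4 ControlNet outputs

index_map = {block_idx: control_samples[i] for i, block_idx in enumerate(range(0, n_blocks, every_n))}
print(index_map)  # {0: 'ctrl_0', 7: 'ctrl_1', 14: 'ctrl_2', 21: 'ctrl_3'}
```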

View File

@@ -167,6 +167,7 @@ else:
_import_structure["consisid"] = ["ConsisIDPipeline"]
_import_structure["cosmos"] = [
"Cosmos2_5_PredictBasePipeline",
"Cosmos2_5_TransferPipeline",
"Cosmos2TextToImagePipeline",
"CosmosTextToWorldPipeline",
"CosmosVideoToWorldPipeline",
@@ -631,6 +632,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
)
from .cosmos import (
Cosmos2_5_PredictBasePipeline,
Cosmos2_5_TransferPipeline,
Cosmos2TextToImagePipeline,
Cosmos2VideoToWorldPipeline,
CosmosTextToWorldPipeline,

View File

@@ -25,6 +25,9 @@ else:
_import_structure["pipeline_cosmos2_5_predict"] = [
"Cosmos2_5_PredictBasePipeline",
]
_import_structure["pipeline_cosmos2_5_transfer"] = [
"Cosmos2_5_TransferPipeline",
]
_import_structure["pipeline_cosmos2_text2image"] = ["Cosmos2TextToImagePipeline"]
_import_structure["pipeline_cosmos2_video2world"] = ["Cosmos2VideoToWorldPipeline"]
_import_structure["pipeline_cosmos_text2world"] = ["CosmosTextToWorldPipeline"]
@@ -41,6 +44,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_cosmos2_5_predict import (
Cosmos2_5_PredictBasePipeline,
)
from .pipeline_cosmos2_5_transfer import Cosmos2_5_TransferPipeline
from .pipeline_cosmos2_text2image import Cosmos2TextToImagePipeline
from .pipeline_cosmos2_video2world import Cosmos2VideoToWorldPipeline
from .pipeline_cosmos_text2world import CosmosTextToWorldPipeline

View File

@@ -0,0 +1,923 @@
# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict, List, Optional, Union
import numpy as np
import PIL.Image
import torch
import torchvision
import torchvision.transforms
import torchvision.transforms.functional
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput
from ...models import AutoencoderKLWan, CosmosControlNetModel, CosmosTransformer3DModel
from ...schedulers import UniPCMultistepScheduler
from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import CosmosPipelineOutput
if is_cosmos_guardrail_available():
from cosmos_guardrail import CosmosSafetyChecker
else:
class CosmosSafetyChecker:
def __init__(self, *args, **kwargs):
raise ImportError(
"`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`."
)
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def _maybe_pad_video(video: torch.Tensor, num_frames: int):
n_pad_frames = num_frames - video.shape[2]
if n_pad_frames > 0:
last_frame = video[:, :, -1:, :, :]
video = torch.cat((video, last_frame.repeat(1, 1, n_pad_frames, 1, 1)), dim=2)
return video
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
DEFAULT_NEGATIVE_PROMPT = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import cv2
>>> import numpy as np
>>> import torch
>>> from PIL import Image
>>> from diffusers import Cosmos2_5_TransferPipeline, AutoModel
>>> from diffusers.utils import export_to_video, load_video
>>> model_id = "nvidia/Cosmos-Transfer2.5-2B"
>>> # Load a Transfer2.5 controlnet variant (edge, depth, seg, or blur)
>>> controlnet = AutoModel.from_pretrained(model_id, revision="diffusers/controlnet/general/edge")
>>> pipe = Cosmos2_5_TransferPipeline.from_pretrained(
... model_id, controlnet=controlnet, revision="diffusers/general", torch_dtype=torch.bfloat16
... )
>>> pipe = pipe.to("cuda")
>>> # Video2World with edge control: Generate video guided by edge maps extracted from input video.
>>> prompt = (
... "The video is a demonstration of robotic manipulation, likely in a laboratory or testing environment. It "
... "features two robotic arms interacting with a piece of blue fabric. The setting is a room with a beige "
... "couch in the background, providing a neutral backdrop for the robotic activity. The robotic arms are "
... "positioned on either side of the fabric, which is placed on a yellow cushion. The left robotic arm is "
... "white with a black gripper, while the right arm is black with a more complex, articulated gripper. At the "
... "beginning, the fabric is laid out on the cushion. The left robotic arm approaches the fabric, its gripper "
... "opening and closing as it positions itself. The right arm remains stationary initially, poised to assist. "
... "As the video progresses, the left arm grips the fabric, lifting it slightly off the cushion. The right arm "
... "then moves in, its gripper adjusting to grasp the opposite side of the fabric. Both arms work in "
... "coordination, lifting and holding the fabric between them. The fabric is manipulated with precision, "
... "showcasing the dexterity and control of the robotic arms. The camera remains static throughout, focusing "
... "on the interaction between the robotic arms and the fabric, allowing viewers to observe the detailed "
... "movements and coordination involved in the task."
... )
>>> negative_prompt = (
... "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
... "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
... "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky "
... "movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
... "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
... "Overall, the video is of poor quality."
... )
>>> input_video = load_video(
... "https://github.com/nvidia-cosmos/cosmos-transfer2.5/raw/refs/heads/main/assets/robot_example/robot_input.mp4"
... )
>>> num_frames = 93
>>> # Extract edge maps from the input video using Canny edge detection
>>> edge_maps = [
... cv2.Canny(cv2.cvtColor(np.array(frame.convert("RGB")), cv2.COLOR_RGB2BGR), 100, 200)
... for frame in input_video[:num_frames]
... ]
>>> edge_maps = np.stack(edge_maps)[None] # (T, H, W) -> (1, T, H, W)
>>> controls = torch.from_numpy(edge_maps).expand(3, -1, -1, -1) # (1, T, H, W) -> (3, T, H, W)
>>> controls = [Image.fromarray(x.numpy()) for x in controls.permute(1, 2, 3, 0)]
>>> export_to_video(controls, "edge_controlled_video_edge.mp4", fps=30)
>>> video = pipe(
... video=input_video[:num_frames],
... controls=controls,
... controls_conditioning_scale=1.0,
... prompt=prompt,
... negative_prompt=negative_prompt,
... num_frames=num_frames,
... ).frames[0]
>>> export_to_video(video, "edge_controlled_video.mp4", fps=30)
```
"""
class Cosmos2_5_TransferPipeline(DiffusionPipeline):
r"""
Pipeline for Cosmos Transfer2.5 base model.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Args:
text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
Frozen text-encoder. Cosmos Transfer2.5 uses the [Qwen2.5
VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) encoder.
tokenizer (`AutoTokenizer`):
Tokenizer associated with the Qwen2.5 VL encoder.
transformer ([`CosmosTransformer3DModel`]):
Conditional Transformer to denoise the encoded image latents.
scheduler ([`UniPCMultistepScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
"""
model_cpu_offload_seq = "text_encoder->transformer->controlnet->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
# We mark safety_checker as optional here to get around some test failures, but it is not really optional
_optional_components = ["safety_checker", "controlnet"]
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
self,
text_encoder: Qwen2_5_VLForConditionalGeneration,
tokenizer: AutoTokenizer,
transformer: CosmosTransformer3DModel,
vae: AutoencoderKLWan,
scheduler: UniPCMultistepScheduler,
controlnet: Optional[CosmosControlNetModel],
safety_checker: CosmosSafetyChecker = None,
):
super().__init__()
if safety_checker is None:
safety_checker = CosmosSafetyChecker()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
controlnet=controlnet,
scheduler=scheduler,
safety_checker=safety_checker,
)
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
latents_mean = (
torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).float()
if getattr(self.vae.config, "latents_mean", None) is not None
else None
)
latents_std = (
torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).float()
if getattr(self.vae.config, "latents_std", None) is not None
else None
)
self.latents_mean = latents_mean
self.latents_std = latents_std
if self.latents_mean is None or self.latents_std is None:
raise ValueError("VAE configuration must define both `latents_mean` and `latents_std`.")
def _get_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
max_sequence_length: int = 512,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
input_ids_batch = []
for sample_idx in range(len(prompt)):
conversations = [
{
"role": "system",
"content": [
{
"type": "text",
"text": "You are a helpful assistant who will provide prompts to an image generator.",
}
],
},
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt[sample_idx],
}
],
},
]
input_ids = self.tokenizer.apply_chat_template(
conversations,
tokenize=True,
add_generation_prompt=False,
add_vision_id=False,
max_length=max_sequence_length,
truncation=True,
padding="max_length",
)
input_ids = torch.LongTensor(input_ids)
input_ids_batch.append(input_ids)
input_ids_batch = torch.stack(input_ids_batch, dim=0)
outputs = self.text_encoder(
input_ids_batch.to(device),
output_hidden_states=True,
)
hidden_states = outputs.hidden_states
normalized_hidden_states = []
for layer_idx in range(1, len(hidden_states)):
normalized_state = (hidden_states[layer_idx] - hidden_states[layer_idx].mean(dim=-1, keepdim=True)) / (
hidden_states[layer_idx].std(dim=-1, keepdim=True) + 1e-8
)
normalized_hidden_states.append(normalized_state)
prompt_embeds = torch.cat(normalized_hidden_states, dim=-1)
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
return prompt_embeds
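For orientation: the loop above concatenates one normalized hidden state per decoder layer, so with the Qwen2.5-VL-7B text stack (28 layers of width 3584, which is my reading of the upstream config rather than a value taken from this diff) the prompt embedding ends up 100352-dimensional, matching `crossattn_proj_in_channels=100352` in the Transfer/Predict transformer configs. A quick arithmetic sketch:
```python
# Back-of-envelope check; layer count and hidden size are assumptions about the
# Qwen2.5-VL-7B text stack, not values read from this diff.
num_hidden_layers, hidden_size = 28, 3584
print(num_hidden_layers * hidden_size)  # 100352 == crossattn_proj_in_channels
```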
# Modified from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 512,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
Number of videos that should be generated per prompt.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = self._get_prompt_embeds(
prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = self._get_prompt_embeds(
prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = negative_prompt_embeds.shape
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
return prompt_embeds, negative_prompt_embeds
# Modified from diffusers.pipelines.cosmos.pipeline_cosmos2_video2world.Cosmos2VideoToWorldPipeline.prepare_latents and
# diffusers.pipelines.cosmos.pipeline_cosmos2_text2image.Cosmos2TextToImagePipeline.prepare_latents
def prepare_latents(
self,
video: Optional[torch.Tensor],
batch_size: int,
num_channels_latents: int = 16,
height: int = 704,
width: int = 1280,
num_frames_in: int = 93,
num_frames_out: int = 93,
do_classifier_free_guidance: bool = True,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
B = batch_size
C = num_channels_latents
T = (num_frames_out - 1) // self.vae_scale_factor_temporal + 1
H = height // self.vae_scale_factor_spatial
W = width // self.vae_scale_factor_spatial
shape = (B, C, T, H, W)
if num_frames_in == 0:
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
cond_mask = torch.zeros((B, 1, T, H, W), dtype=latents.dtype, device=latents.device)
cond_indicator = torch.zeros((B, 1, T, 1, 1), dtype=latents.dtype, device=latents.device)
cond_latents = torch.zeros_like(latents)
return (
latents,
cond_latents,
cond_mask,
cond_indicator,
)
else:
if video is None:
raise ValueError("`video` must be provided when `num_frames_in` is greater than 0.")
video = video.to(device=device, dtype=self.vae.dtype)
if isinstance(generator, list):
cond_latents = [
retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator=generator[i])
for i in range(batch_size)
]
else:
cond_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
cond_latents = torch.cat(cond_latents, dim=0).to(dtype)
latents_mean = self.latents_mean.to(device=device, dtype=dtype)
latents_std = self.latents_std.to(device=device, dtype=dtype)
cond_latents = (cond_latents - latents_mean) / latents_std
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device=device, dtype=dtype)
padding_shape = (B, 1, T, H, W)
ones_padding = latents.new_ones(padding_shape)
zeros_padding = latents.new_zeros(padding_shape)
cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
return (
latents,
cond_latents,
cond_mask,
cond_indicator,
)
def _encode_controls(
self,
controls: Optional[torch.Tensor],
height: int,
width: int,
num_frames: int,
dtype: torch.dtype,
device: torch.device,
generator: Optional[Union[torch.Generator, List[torch.Generator]]],
) -> Optional[torch.Tensor]:
if controls is None:
return None
control_video = self.video_processor.preprocess_video(controls, height, width)
control_video = _maybe_pad_video(control_video, num_frames)
control_video = control_video.to(device=device, dtype=self.vae.dtype)
control_latents = [
retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator=generator) for vid in control_video
]
control_latents = torch.cat(control_latents, dim=0).to(dtype)
latents_mean = self.latents_mean.to(device=device, dtype=dtype)
latents_std = self.latents_std.to(device=device, dtype=dtype)
control_latents = (control_latents - latents_mean) / latents_std
return control_latents
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs
def check_inputs(
self,
prompt,
height,
width,
prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@property
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
image: PipelineImageInput | None = None,
video: List[PipelineImageInput] | None = None,
prompt: Union[str, List[str]] | None = None,
negative_prompt: Union[str, List[str]] = DEFAULT_NEGATIVE_PROMPT,
height: int = 704,
width: Optional[int] = None,
num_frames: int = 93,
num_inference_steps: int = 36,
guidance_scale: float = 3.0,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
controls: Optional[PipelineImageInput | List[PipelineImageInput]] = None,
controls_conditioning_scale: Union[float, List[float]] = 1.0,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
conditional_frame_timestep: float = 0.1,
):
r"""
The call function to the pipeline for generation. Supports three modes:
- **Text2World**: `image=None`, `video=None`, `prompt` provided. Generates a world clip.
- **Image2World**: `image` provided, `video=None`, `prompt` provided. Conditions on a single frame.
- **Video2World**: `video` provided, `image=None`, `prompt` provided. Conditions on an input clip.
Set `num_frames=93` (default) to produce a world video, or `num_frames=1` to produce a single image frame
(turning the modes above into their "*2Image" counterparts).
Outputs follow `output_type` (e.g., `"pil"` returns a list of `num_frames` PIL images per prompt).
Args:
image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, *optional*):
Optional single image for Image2World conditioning. Must be `None` when `video` is provided.
video (`List[PIL.Image.Image]`, `np.ndarray`, `torch.Tensor`, *optional*):
Optional input video for Video2World conditioning. Must be `None` when `image` is provided.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide generation. Required unless `prompt_embeds` is supplied.
height (`int`, defaults to `704`):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image. If not provided, this will be determined based on the
aspect ratio of the input and the provided height.
num_frames (`int`, defaults to `93`):
Number of output frames. Use `93` for world (video) generation; set to `1` to return a single frame.
num_inference_steps (`int`, defaults to `36`):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to `3.0`):
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
controls (`PipelineImageInput`, `List[PipelineImageInput]`, *optional*):
Control image or video input used by the ControlNet. If `None`, ControlNet is skipped.
controls_conditioning_scale (`float` or `List[float]`, *optional*, defaults to `1.0`):
The scale factor(s) for the ControlNet outputs. A single float is broadcast to all control blocks.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple.
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
each denoising step during inference with the following arguments: `callback_on_step_end(self:
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, defaults to `512`):
The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If
the prompt is shorter than this length, it will be padded.
conditional_frame_timestep (`float`, defaults to `0.1`):
The (sigma) timestep assigned to the conditioning frames when calling the transformer during denoising.
Examples:
Returns:
[`~CosmosPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where
the first element is a list with the generated videos.
"""
if self.safety_checker is None:
raise ValueError(
f"You have disabled the safety checker for {self.__class__}. This is in violation of the "
"[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). "
f"Please ensure that you are compliant with the license agreement."
)
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
if width is None:
frame = (image or video[0]) if (image or video) else None
if frame is None and controls is not None:
frame = controls[0] if isinstance(controls, list) else controls
if isinstance(frame, (torch.Tensor, np.ndarray)) and len(frame.shape) == 4:
frame = controls[0]
if frame is None:
width = int((height + 16) * (1280 / 720))
elif isinstance(frame, PIL.Image.Image):
width = int((height + 16) * (frame.width / frame.height))
else:
width = int((height + 16) * (frame.shape[2] / frame.shape[1])) # NOTE: assuming C H W
# Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)
self._guidance_scale = guidance_scale
self._current_timestep = None
self._interrupt = False
device = self._execution_device
if self.safety_checker is not None:
self.safety_checker.to(device)
if prompt is not None:
prompt_list = [prompt] if isinstance(prompt, str) else prompt
for p in prompt_list:
if not self.safety_checker.check_text_safety(p):
raise ValueError(
f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the "
f"prompt abides by the NVIDIA Open Model License Agreement."
)
# Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# Encode input prompt
(
prompt_embeds,
negative_prompt_embeds,
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
device=device,
max_sequence_length=max_sequence_length,
)
vae_dtype = self.vae.dtype
transformer_dtype = self.transformer.dtype
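# The transformer and controlnet consume a (text_context, img_context) tuple as encoder hidden states;
# since no reference image is used here, an all-zeros image context of the configured token count and
# dimension serves as a placeholder.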
img_context = torch.zeros(
batch_size,
self.transformer.config.img_context_num_tokens,
self.transformer.config.img_context_dim_in,
device=prompt_embeds.device,
dtype=transformer_dtype,
)
encoder_hidden_states = (prompt_embeds, img_context)
neg_encoder_hidden_states = (negative_prompt_embeds, img_context)
num_frames_in = None
if image is not None:
if batch_size != 1:
raise ValueError(f"batch_size must be 1 for image input (given {batch_size})")
image = torchvision.transforms.functional.to_tensor(image).unsqueeze(0)
video = torch.cat([image, torch.zeros_like(image).repeat(num_frames - 1, 1, 1, 1)], dim=0)
video = video.unsqueeze(0)
num_frames_in = 1
elif video is None:
video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=torch.uint8)
num_frames_in = 0
else:
num_frames_in = len(video)
if batch_size != 1:
raise ValueError(f"batch_size must be 1 for video input (given {batch_size})")
assert video is not None
video = self.video_processor.preprocess_video(video, height, width)
# pad with last frame (for video2world)
num_frames_out = num_frames
video = _maybe_pad_video(video, num_frames_out)
assert num_frames_in <= num_frames_out, f"expected ({num_frames_in=}) <= ({num_frames_out=})"
video = video.to(device=device, dtype=vae_dtype)
num_channels_latents = self.transformer.config.in_channels - 1
latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents(
video=video,
batch_size=batch_size * num_videos_per_prompt,
num_channels_latents=num_channels_latents,
height=height,
width=width,
num_frames_in=num_frames_in,
num_frames_out=num_frames,
do_classifier_free_guidance=self.do_classifier_free_guidance,
dtype=torch.float32,
device=device,
generator=generator,
latents=latents,
)
cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep
cond_mask = cond_mask.to(transformer_dtype)
controls_latents = None
if controls is not None:
controls_latents = self._encode_controls(
controls,
height=height,
width=width,
num_frames=num_frames,
dtype=transformer_dtype,
device=device,
generator=generator,
)
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
# Denoising loop
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
self._num_timesteps = len(timesteps)
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
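# Ground-truth flow-matching velocity (initial noise minus clean conditioning latents), restricted to the
# conditioning frames; it overwrites the model's prediction on those frames inside the denoising loop.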
gt_velocity = (latents - cond_latent) * cond_mask
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t.cpu().item()
# NOTE: assumes sigma(t) \in [0, 1]
sigma_t = (
torch.tensor(self.scheduler.sigmas[i].item())
.unsqueeze(0)
.to(device=device, dtype=transformer_dtype)
)
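# Conditioning frames keep the clean conditioning latents and the fixed `conditional_frame_timestep`;
# all other frames use the current noisy latents and the current sigma.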
in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents
in_latents = in_latents.to(transformer_dtype)
in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t
control_blocks = None
if controls_latents is not None and self.controlnet is not None:
control_output = self.controlnet(
controls_latents=controls_latents,
latents=in_latents,
timestep=in_timestep,
encoder_hidden_states=encoder_hidden_states,
condition_mask=cond_mask,
conditioning_scale=controls_conditioning_scale,
padding_mask=padding_mask,
return_dict=False,
)
control_blocks = control_output[0]
noise_pred = self.transformer(
hidden_states=in_latents,
timestep=in_timestep,
encoder_hidden_states=encoder_hidden_states,
block_controlnet_hidden_states=control_blocks,
condition_mask=cond_mask,
padding_mask=padding_mask,
return_dict=False,
)[0]
noise_pred = gt_velocity + noise_pred * (1 - cond_mask)
if self.do_classifier_free_guidance:
control_blocks = None
if controls_latents is not None and self.controlnet is not None:
control_output = self.controlnet(
controls_latents=controls_latents,
latents=in_latents,
timestep=in_timestep,
encoder_hidden_states=neg_encoder_hidden_states, # NOTE: negative prompt
condition_mask=cond_mask,
conditioning_scale=controls_conditioning_scale,
padding_mask=padding_mask,
return_dict=False,
)
control_blocks = control_output[0]
noise_pred_neg = self.transformer(
hidden_states=in_latents,
timestep=in_timestep,
encoder_hidden_states=neg_encoder_hidden_states, # NOTE: negative prompt
block_controlnet_hidden_states=control_blocks,
condition_mask=cond_mask,
padding_mask=padding_mask,
return_dict=False,
)[0]
# NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only
noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask)
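# Classifier-free guidance with the conditional branch as the base: (1 + s) * cond - s * uncond.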
noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg)
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if not output_type == "latent":
latents_mean = self.latents_mean.to(latents.device, latents.dtype)
latents_std = self.latents_std.to(latents.device, latents.dtype)
latents = latents * latents_std + latents_mean
video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
video = self._match_num_frames(video, num_frames)
assert self.safety_checker is not None
self.safety_checker.to(device)
video = self.video_processor.postprocess_video(video, output_type="np")
video = (video * 255).astype(np.uint8)
video_batch = []
for vid in video:
vid = self.safety_checker.check_video_safety(vid)
if vid is None:
video_batch.append(np.zeros_like(video[0]))
else:
video_batch.append(vid)
video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1
video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
video = self.video_processor.postprocess_video(video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return CosmosPipelineOutput(frames=video)
def _match_num_frames(self, video: torch.Tensor, target_num_frames: int) -> torch.Tensor:
if target_num_frames <= 0 or video.shape[2] == target_num_frames:
return video
frames_per_latent = max(self.vae_scale_factor_temporal, 1)
video = torch.repeat_interleave(video, repeats=frames_per_latent, dim=2)
current_frames = video.shape[2]
if current_frames < target_num_frames:
pad = video[:, :, -1:, :, :].repeat(1, 1, target_num_frames - current_frames, 1, 1)
video = torch.cat([video, pad], dim=2)
elif current_frames > target_num_frames:
video = video[:, :, :target_num_frames]
return video
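A minimal usage sketch for the pipeline above, assuming a locally converted `Cosmos2_5_TransferPipeline` checkpoint; the checkpoint path, control clip, prompt, and fps below are placeholders rather than values from this PR.

```python
import torch

from diffusers import Cosmos2_5_TransferPipeline
from diffusers.utils import export_to_video, load_video

# Placeholder path to a converted pipeline directory.
pipe = Cosmos2_5_TransferPipeline.from_pretrained("path/to/converted/pipeline", torch_dtype=torch.bfloat16)
pipe.to("cuda")

control_video = load_video("edge_control.mp4")  # list of PIL frames fed to the ControlNet
video = pipe(
    prompt="a robot arm assembling a gearbox on a factory line",
    controls=control_video,
    controls_conditioning_scale=1.0,  # a single float is broadcast to all control blocks
    num_frames=93,
    num_inference_steps=36,
    guidance_scale=3.0,
).frames[0]
export_to_video(video, "output.mp4", fps=16)
```

Omitting `controls` skips the ControlNet entirely, and passing `image=` or `video=` switches the same call into the Image2World or Video2World conditioning path described in the docstring.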

View File

@@ -658,12 +658,7 @@ class GlmImagePipeline(DiffusionPipeline):
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt is not None and prior_token_ids is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prior_token_ids`: {prior_token_ids}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prior_token_ids is None:
if prompt is None and prior_token_ids is None:
raise ValueError(
"Provide either `prompt` or `prior_token_ids`. Cannot leave both `prompt` and `prior_token_ids` undefined."
)
@@ -694,8 +689,8 @@ class GlmImagePipeline(DiffusionPipeline):
"for i2i mode, as the images are needed for VAE encoding to build the KV cache."
)
if prior_token_ids is not None and prompt_embeds is None:
raise ValueError("`prompt_embeds` must also be provided with `prior_token_ids`.")
if prior_token_ids is not None and prompt_embeds is None and prompt is None:
raise ValueError("`prompt_embeds` or `prompt` must also be provided with `prior_token_ids`.")
@property
def guidance_scale(self):
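To make the relaxed validation above easier to scan, here is a condensed restatement as a standalone helper; `validate_glm_image_inputs` is a hypothetical name, not part of the PR.

```python
from typing import Optional

import torch


def validate_glm_image_inputs(
    prompt: Optional[str],
    prior_token_ids: Optional[torch.Tensor],
    prompt_embeds: Optional[torch.Tensor],
) -> None:
    # At least one of `prompt` / `prior_token_ids` is required; passing both is now allowed.
    if prompt is None and prior_token_ids is None:
        raise ValueError("Provide either `prompt` or `prior_token_ids`.")
    # `prior_token_ids` alone is not enough: a `prompt` or precomputed `prompt_embeds` must accompany it.
    if prior_token_ids is not None and prompt_embeds is None and prompt is None:
        raise ValueError("`prompt_embeds` or `prompt` must also be provided with `prior_token_ids`.")
```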

View File

@@ -15,7 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm
import math
from typing import List, Optional, Tuple, Union
from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -51,13 +51,15 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
schedule was incorporated in this model: https://huggingface.co/stabilityai/cosxl.
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
solver_order (`int`, defaults to 2):
The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
sampling, and `solver_order=3` for unconditional sampling.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
Video](https://huggingface.co/papers/2210.02303) paper).
rho (`float`, *optional*, defaults to 7.0):
The rho parameter in the Karras sigma schedule. This was set to 7.0 in the EDM paper [1].
solver_order (`int`, defaults to 2):
The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
sampling, and `solver_order=3` for unconditional sampling.
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
@@ -94,19 +96,19 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
sigma_min: float = 0.002,
sigma_max: float = 80.0,
sigma_data: float = 0.5,
sigma_schedule: str = "karras",
sigma_schedule: Literal["karras", "exponential"] = "karras",
num_train_timesteps: int = 1000,
prediction_type: str = "epsilon",
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
rho: float = 7.0,
solver_order: int = 2,
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint",
algorithm_type: Literal["dpmsolver++", "sde-dpmsolver++"] = "dpmsolver++",
solver_type: Literal["midpoint", "heun"] = "midpoint",
lower_order_final: bool = True,
euler_at_final: bool = False,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
final_sigmas_type: Optional[Literal["zero", "sigma_min"]] = "zero", # "zero", "sigma_min"
):
# settings for DPM-Solver
if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"]:
@@ -145,19 +147,19 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
def init_noise_sigma(self):
def init_noise_sigma(self) -> float:
# standard deviation of the initial noise distribution
return (self.config.sigma_max**2 + 1) ** 0.5
@property
def step_index(self):
def step_index(self) -> int:
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
@property
def begin_index(self):
def begin_index(self) -> int:
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
@@ -274,7 +276,11 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.is_scale_input_called = True
return sample
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
def set_timesteps(
self,
num_inference_steps: int = None,
device: Optional[Union[str, torch.device]] = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -460,13 +466,12 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def _sigma_to_alpha_sigma_t(self, sigma):
alpha_t = torch.tensor(1) # Inputs are pre-scaled before going into unet, so alpha_t = 1
sigma_t = sigma
return alpha_t, sigma_t
def convert_model_output(
self,
model_output: torch.Tensor,
sample: torch.Tensor = None,
sample: torch.Tensor,
) -> torch.Tensor:
"""
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
@@ -497,7 +502,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def dpm_solver_first_order_update(
self,
model_output: torch.Tensor,
sample: torch.Tensor = None,
sample: torch.Tensor,
noise: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
@@ -508,6 +513,8 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
The direct output from the learned diffusion model.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
noise (`torch.Tensor`, *optional*):
The noise tensor to add to the original samples.
Returns:
`torch.Tensor`:
@@ -538,7 +545,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def multistep_dpm_solver_second_order_update(
self,
model_output_list: List[torch.Tensor],
sample: torch.Tensor = None,
sample: torch.Tensor,
noise: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
@@ -549,6 +556,8 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
The direct outputs from learned diffusion model at current and latter timesteps.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
noise (`torch.Tensor`, *optional*):
The noise tensor to add to the original samples.
Returns:
`torch.Tensor`:
@@ -609,7 +618,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def multistep_dpm_solver_third_order_update(
self,
model_output_list: List[torch.Tensor],
sample: torch.Tensor = None,
sample: torch.Tensor,
) -> torch.Tensor:
"""
One step for the third-order multistep DPMSolver.
@@ -698,7 +707,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return step_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def _init_step_index(self, timestep):
def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
"""
Initialize the step_index counter for the scheduler.
@@ -719,7 +728,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
model_output: torch.Tensor,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
generator=None,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
@@ -860,5 +869,5 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
return c_in
def __len__(self):
def __len__(self) -> int:
return self.config.num_train_timesteps

View File

@@ -14,7 +14,7 @@
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -102,12 +102,21 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
time_shift_type: str = "exponential",
time_shift_type: Literal["exponential", "linear"] = "exponential",
stochastic_sampling: bool = False,
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
if (
sum(
[
self.config.use_beta_sigmas,
self.config.use_exponential_sigmas,
self.config.use_karras_sigmas,
]
)
> 1
):
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
@@ -166,6 +175,13 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self._begin_index = begin_index
def set_shift(self, shift: float):
"""
Sets the shift value for the scheduler.
Args:
shift (`float`):
The shift value to be set.
"""
self._shift = shift
def scale_noise(
@@ -218,10 +234,25 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
return sample
def _sigma_to_t(self, sigma):
def _sigma_to_t(self, sigma) -> float:
return sigma * self.config.num_train_timesteps
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
def time_shift(self, mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
"""
Apply time shifting to the sigmas.
Args:
mu (`float`):
The mu parameter for the time shift.
sigma (`float`):
The sigma parameter for the time shift.
t (`torch.Tensor`):
The input timesteps.
Returns:
`torch.Tensor`:
The time-shifted timesteps.
"""
if self.config.time_shift_type == "exponential":
return self._time_shift_exponential(mu, sigma, t)
elif self.config.time_shift_type == "linear":
@@ -302,7 +333,9 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
if sigmas is None:
if timesteps is None:
timesteps = np.linspace(
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
self._sigma_to_t(self.sigma_max),
self._sigma_to_t(self.sigma_min),
num_inference_steps,
)
sigmas = timesteps / self.config.num_train_timesteps
else:
@@ -350,7 +383,24 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self._step_index = None
self._begin_index = None
def index_for_timestep(self, timestep, schedule_timesteps=None):
def index_for_timestep(
self,
timestep: Union[float, torch.FloatTensor],
schedule_timesteps: Optional[torch.FloatTensor] = None,
) -> int:
"""
Get the index for the given timestep.
Args:
timestep (`float` or `torch.FloatTensor`):
The timestep to find the index for.
schedule_timesteps (`torch.FloatTensor`, *optional*):
The schedule timesteps to validate against. If `None`, the scheduler's timesteps are used.
Returns:
`int`:
The index of the timestep.
"""
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -364,7 +414,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
return indices[pos].item()
def _init_step_index(self, timestep):
def _init_step_index(self, timestep: Union[float, torch.FloatTensor]) -> None:
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -405,7 +455,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
A random number generator.
per_token_timesteps (`torch.Tensor`, *optional*):
The timesteps for each token in the sample.
return_dict (`bool`):
return_dict (`bool`, defaults to `True`):
Whether or not to return a
[`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple.
@@ -474,7 +524,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
"""
Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
Models](https://huggingface.co/papers/2206.00364).
@@ -595,11 +645,11 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
)
return sigmas
def _time_shift_exponential(self, mu, sigma, t):
def _time_shift_exponential(self, mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def _time_shift_linear(self, mu, sigma, t):
def _time_shift_linear(self, mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
return mu / (mu + (1 / t - 1) ** sigma)
def __len__(self):
def __len__(self) -> int:
return self.config.num_train_timesteps
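The `time_shift` docstrings added above describe the two closed-form shifts shown at the bottom of this file; a standalone sketch with illustrative `mu` and `sigma` values (not taken from the PR) shows how they warp a normalized timestep grid.

```python
import math

import torch


def time_shift_exponential(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
    # exp(mu) / (exp(mu) + (1 / t - 1) ** sigma), as in FlowMatchEulerDiscreteScheduler._time_shift_exponential
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def time_shift_linear(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
    # mu / (mu + (1 / t - 1) ** sigma), as in FlowMatchEulerDiscreteScheduler._time_shift_linear
    return mu / (mu + (1 / t - 1) ** sigma)


t = torch.linspace(1.0, 0.1, steps=5)  # normalized timesteps in (0, 1]
print(time_shift_exponential(mu=1.0, sigma=1.0, t=t))  # larger mu shifts values toward 1
print(time_shift_linear(mu=1.0, sigma=1.0, t=t))
```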

View File

@@ -482,6 +482,21 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
"""
Apply time shifting to the sigmas.
Args:
mu (`float`):
The mu parameter for the time shift.
sigma (`float`):
The sigma parameter for the time shift.
t (`torch.Tensor`):
The input timesteps.
Returns:
`torch.Tensor`:
The time-shifted timesteps.
"""
if self.config.time_shift_type == "exponential":
return self._time_shift_exponential(mu, sigma, t)
elif self.config.time_shift_type == "linear":

View File

@@ -896,6 +896,21 @@ class ControlNetXSAdapter(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class CosmosControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class CosmosTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -977,6 +977,21 @@ class Cosmos2_5_PredictBasePipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class Cosmos2_5_TransferPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Cosmos2TextToImagePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

View File

@@ -0,0 +1,255 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import CosmosControlNetModel
from diffusers.models.controlnets.controlnet_cosmos import CosmosControlNetOutput
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class CosmosControlNetModelTests(ModelTesterMixin, unittest.TestCase):
model_class = CosmosControlNetModel
main_input_name = "controls_latents"
uses_custom_attn_processor = True
@property
def dummy_input(self):
batch_size = 1
num_channels = 16
num_frames = 1
height = 16
width = 16
text_embed_dim = 32
sequence_length = 12
img_context_dim_in = 32
img_context_num_tokens = 4
# Raw latents (not patchified) - the controlnet computes embeddings internally
controls_latents = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
latents = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.tensor([0.5]).to(torch_device) # Diffusion timestep
condition_mask = torch.ones(batch_size, 1, num_frames, height, width).to(torch_device)
padding_mask = torch.zeros(batch_size, 1, height, width).to(torch_device)
# Text embeddings
text_context = torch.randn((batch_size, sequence_length, text_embed_dim)).to(torch_device)
# Image context for Cosmos 2.5
img_context = torch.randn((batch_size, img_context_num_tokens, img_context_dim_in)).to(torch_device)
encoder_hidden_states = (text_context, img_context)
return {
"controls_latents": controls_latents,
"latents": latents,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"condition_mask": condition_mask,
"conditioning_scale": 1.0,
"padding_mask": padding_mask,
}
@property
def input_shape(self):
return (16, 1, 16, 16)
@property
def output_shape(self):
# Output is tuple of n_controlnet_blocks tensors, each with shape (batch, num_patches, model_channels)
# After stacking by normalize_output: (n_blocks, batch, num_patches, model_channels)
# For test config: n_blocks=2, num_patches=64 (1*8*8), model_channels=32
# output_shape is used as (batch_size,) + output_shape, so: (2, 64, 32)
return (2, 64, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"n_controlnet_blocks": 2,
"in_channels": 16 + 1 + 1, # control_latent_channels + condition_mask + padding_mask
"latent_channels": 16 + 1 + 1, # base_latent_channels (16) + condition_mask (1) + padding_mask (1) = 18
"model_channels": 32,
"num_attention_heads": 2,
"attention_head_dim": 16,
"mlp_ratio": 2,
"text_embed_dim": 32,
"adaln_lora_dim": 4,
"patch_size": (1, 2, 2),
"max_size": (4, 32, 32),
"rope_scale": (2.0, 1.0, 1.0),
"extra_pos_embed_type": None,
"img_context_dim_in": 32,
"img_context_dim_out": 32,
"use_crossattn_projection": False, # Test doesn't need this projection
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_output_format(self):
"""Test that the model outputs CosmosControlNetOutput with correct structure."""
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
self.assertIsInstance(output, CosmosControlNetOutput)
self.assertIsInstance(output.control_block_samples, list)
self.assertEqual(len(output.control_block_samples), init_dict["n_controlnet_blocks"])
for tensor in output.control_block_samples:
self.assertIsInstance(tensor, torch.Tensor)
def test_output_list_format(self):
"""Test that return_dict=False returns a tuple containing a list."""
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict, return_dict=False)
self.assertIsInstance(output, tuple)
self.assertEqual(len(output), 1)
self.assertIsInstance(output[0], list)
self.assertEqual(len(output[0]), init_dict["n_controlnet_blocks"])
def test_conditioning_scale_single(self):
"""Test that a single conditioning scale is broadcast to all blocks."""
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
inputs_dict["conditioning_scale"] = 0.5
with torch.no_grad():
output = model(**inputs_dict)
self.assertEqual(len(output.control_block_samples), init_dict["n_controlnet_blocks"])
def test_conditioning_scale_list(self):
"""Test that a list of conditioning scales is applied per block."""
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
# Provide a scale for each block
inputs_dict["conditioning_scale"] = [0.5, 1.0]
with torch.no_grad():
output = model(**inputs_dict)
self.assertEqual(len(output.control_block_samples), init_dict["n_controlnet_blocks"])
def test_forward_with_none_img_context(self):
"""Test forward pass when img_context is None."""
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
# Set encoder_hidden_states to (text_context, None)
text_context = inputs_dict["encoder_hidden_states"][0]
inputs_dict["encoder_hidden_states"] = (text_context, None)
with torch.no_grad():
output = model(**inputs_dict)
self.assertIsInstance(output, CosmosControlNetOutput)
self.assertEqual(len(output.control_block_samples), init_dict["n_controlnet_blocks"])
def test_forward_without_img_context_proj(self):
"""Test forward pass when img_context_proj is not configured."""
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
# Disable img_context_proj
init_dict["img_context_dim_in"] = None
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
# When img_context is disabled, pass only text context (not a tuple)
text_context = inputs_dict["encoder_hidden_states"][0]
inputs_dict["encoder_hidden_states"] = text_context
with torch.no_grad():
output = model(**inputs_dict)
self.assertIsInstance(output, CosmosControlNetOutput)
self.assertEqual(len(output.control_block_samples), init_dict["n_controlnet_blocks"])
def test_gradient_checkpointing_is_applied(self):
expected_set = {"CosmosControlNetModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
# Note: test_set_attn_processor_for_determinism already handles uses_custom_attn_processor=True
# so no explicit skip needed for it
# Note: test_forward_signature and test_set_default_attn_processor don't exist in base class
# Skip tests that don't apply to this architecture
@unittest.skip("CosmosControlNetModel doesn't use norm groups.")
def test_forward_with_norm_groups(self):
pass
# Skip tests that expect .sample attribute - ControlNets don't have this
@unittest.skip("ControlNet output doesn't have .sample attribute")
def test_effective_gradient_checkpointing(self):
pass
# Skip tests that compute MSE loss against single tensor output
@unittest.skip("ControlNet outputs list of control blocks, not single tensor for MSE loss")
def test_ema_training(self):
pass
@unittest.skip("ControlNet outputs list of control blocks, not single tensor for MSE loss")
def test_training(self):
pass
# Skip tests where output shape comparison doesn't apply to ControlNets
@unittest.skip("ControlNet output shape doesn't match input shape by design")
def test_output(self):
pass
# Skip outputs_equivalence - dict/list comparison logic not compatible (recursive_check expects dict.values())
@unittest.skip("ControlNet output structure not compatible with recursive dict check")
def test_outputs_equivalence(self):
pass
# Skip model parallelism - base test uses torch.allclose(base_output[0], new_output[0]) which fails
# because output[0] is the list of control_block_samples, not a tensor
@unittest.skip("test_model_parallelism uses torch.allclose on output[0] which is a list, not a tensor")
def test_model_parallelism(self):
pass
# Skip layerwise casting tests - these have two issues:
# 1. _inference and _memory: dtype compatibility issues with learnable_pos_embed and float8/bfloat16
# 2. _training: same as test_training - mse_loss expects tensor, not list
@unittest.skip("Layerwise casting has dtype issues with learnable_pos_embed")
def test_layerwise_casting_inference(self):
pass
@unittest.skip("Layerwise casting has dtype issues with learnable_pos_embed")
def test_layerwise_casting_memory(self):
pass
@unittest.skip("test_layerwise_casting_training computes mse_loss on list output")
def test_layerwise_casting_training(self):
pass

View File

@@ -6,7 +6,7 @@ import pytest
import torch
import diffusers
from diffusers import ComponentsManager, ModularPipeline, ModularPipelineBlocks
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.modular_pipelines.modular_pipeline_utils import (
ComponentSpec,
@@ -598,3 +598,68 @@ class TestModularModelCardContent:
content = generate_modular_model_card_content(blocks)
assert "5-block architecture" in content["model_description"]
class TestAutoModelLoadIdTagging:
def test_automodel_tags_load_id(self):
model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe", subfolder="unet")
assert hasattr(model, "_diffusers_load_id"), "Model should have _diffusers_load_id attribute"
assert model._diffusers_load_id != "null", "_diffusers_load_id should not be 'null'"
# Verify load_id contains the expected fields
load_id = model._diffusers_load_id
assert "hf-internal-testing/tiny-stable-diffusion-xl-pipe" in load_id
assert "unet" in load_id
def test_automodel_update_components(self):
pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
pipe.load_components(torch_dtype=torch.float32)
auto_model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe", subfolder="unet")
pipe.update_components(unet=auto_model)
assert pipe.unet is auto_model
assert "unet" in pipe._component_specs
spec = pipe._component_specs["unet"]
assert spec.pretrained_model_name_or_path == "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
assert spec.subfolder == "unet"
class TestLoadComponentsSkipBehavior:
def test_load_components_skips_already_loaded(self):
pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
pipe.load_components(torch_dtype=torch.float32)
original_unet = pipe.unet
pipe.load_components()
# Verify that the unet is the same object (not reloaded)
assert pipe.unet is original_unet, "load_components should skip already loaded components"
def test_load_components_selective_loading(self):
pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
pipe.load_components(names="unet", torch_dtype=torch.float32)
# Verify only requested component was loaded.
assert hasattr(pipe, "unet")
assert pipe.unet is not None
assert getattr(pipe, "vae", None) is None
def test_load_components_skips_invalid_pretrained_path(self):
pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
pipe._component_specs["test_component"] = ComponentSpec(
name="test_component",
type_hint=torch.nn.Module,
pretrained_model_name_or_path=None,
default_creation_method="from_pretrained",
)
pipe.load_components(torch_dtype=torch.float32)
# Verify test_component was not loaded
assert not hasattr(pipe, "test_component") or pipe.test_component is None

View File

@@ -0,0 +1,386 @@
# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import os
import tempfile
import unittest
import numpy as np
import torch
from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
from diffusers import (
AutoencoderKLWan,
Cosmos2_5_TransferPipeline,
CosmosControlNetModel,
CosmosTransformer3DModel,
UniPCMultistepScheduler,
)
from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
from .cosmos_guardrail import DummyCosmosSafetyChecker
enable_full_determinism()
class Cosmos2_5_TransferWrapper(Cosmos2_5_TransferPipeline):
@staticmethod
def from_pretrained(*args, **kwargs):
if "safety_checker" not in kwargs or kwargs["safety_checker"] is None:
safety_checker = DummyCosmosSafetyChecker()
device_map = kwargs.get("device_map", "cpu")
torch_dtype = kwargs.get("torch_dtype")
if device_map is not None or torch_dtype is not None:
safety_checker = safety_checker.to(device_map, dtype=torch_dtype)
kwargs["safety_checker"] = safety_checker
return Cosmos2_5_TransferPipeline.from_pretrained(*args, **kwargs)
class Cosmos2_5_TransferPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = Cosmos2_5_TransferWrapper
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
supports_dduf = False
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)
# Transformer with img_context support for Transfer2.5
transformer = CosmosTransformer3DModel(
in_channels=16 + 1,
out_channels=16,
num_attention_heads=2,
attention_head_dim=16,
num_layers=2,
mlp_ratio=2,
text_embed_dim=32,
adaln_lora_dim=4,
max_size=(4, 32, 32),
patch_size=(1, 2, 2),
rope_scale=(2.0, 1.0, 1.0),
concat_padding_mask=True,
extra_pos_embed_type="learnable",
controlnet_block_every_n=1,
img_context_dim_in=32,
img_context_num_tokens=4,
img_context_dim_out=32,
)
torch.manual_seed(0)
controlnet = CosmosControlNetModel(
n_controlnet_blocks=2,
in_channels=16 + 1 + 1, # control latent channels + condition_mask + padding_mask
latent_channels=16 + 1 + 1, # base latent channels (16) + condition_mask (1) + padding_mask (1) = 18
model_channels=32,
num_attention_heads=2,
attention_head_dim=16,
mlp_ratio=2,
text_embed_dim=32,
adaln_lora_dim=4,
patch_size=(1, 2, 2),
max_size=(4, 32, 32),
rope_scale=(2.0, 1.0, 1.0),
extra_pos_embed_type="learnable", # Match transformer's config
img_context_dim_in=32,
img_context_dim_out=32,
use_crossattn_projection=False, # Test doesn't need this projection
)
torch.manual_seed(0)
vae = AutoencoderKLWan(
base_dim=3,
z_dim=16,
dim_mult=[1, 1, 1, 1],
num_res_blocks=1,
temperal_downsample=[False, True, True],
)
torch.manual_seed(0)
scheduler = UniPCMultistepScheduler()
torch.manual_seed(0)
config = Qwen2_5_VLConfig(
text_config={
"hidden_size": 16,
"intermediate_size": 16,
"num_hidden_layers": 2,
"num_attention_heads": 2,
"num_key_value_heads": 2,
"rope_scaling": {
"mrope_section": [1, 1, 2],
"rope_type": "default",
"type": "default",
},
"rope_theta": 1000000.0,
},
vision_config={
"depth": 2,
"hidden_size": 16,
"intermediate_size": 16,
"num_heads": 2,
"out_hidden_size": 16,
},
hidden_size=16,
vocab_size=152064,
vision_end_token_id=151653,
vision_start_token_id=151652,
vision_token_id=151654,
)
text_encoder = Qwen2_5_VLForConditionalGeneration(config)
tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
components = {
"transformer": transformer,
"controlnet": controlnet,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"safety_checker": DummyCosmosSafetyChecker(),
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "dance monkey",
"negative_prompt": "bad quality",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 3.0,
"height": 32,
"width": 32,
"num_frames": 3,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
def test_components_function(self):
init_components = self.get_dummy_components()
init_components = {k: v for k, v in init_components.items() if not isinstance(v, (str, int, float))}
pipe = self.pipeline_class(**init_components)
self.assertTrue(hasattr(pipe, "components"))
self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
self.assertEqual(generated_video.shape, (3, 3, 32, 32))
self.assertTrue(torch.isfinite(generated_video).all())
def test_inference_with_controls(self):
"""Test inference with control inputs (ControlNet)."""
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
# Add control video input - should be a video tensor
inputs["controls"] = [torch.randn(3, 3, 32, 32)] # num_frames, channels, height, width
inputs["controls_conditioning_scale"] = 1.0
video = pipe(**inputs).frames
generated_video = video[0]
self.assertEqual(generated_video.shape, (3, 3, 32, 32))
self.assertTrue(torch.isfinite(generated_video).all())
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
has_callback_step_end = "callback_on_step_end" in sig.parameters
if not (has_callback_tensor_inputs and has_callback_step_end):
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
self.assertTrue(
hasattr(pipe, "_callback_tensor_inputs"),
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
)
def callback_inputs_subset(pipe, i, t, callback_kwargs):
for tensor_name in callback_kwargs.keys():
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
def callback_inputs_all(pipe, i, t, callback_kwargs):
for tensor_name in pipe._callback_tensor_inputs:
assert tensor_name in callback_kwargs
for tensor_name in callback_kwargs.keys():
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
inputs = self.get_dummy_inputs(torch_device)
inputs["callback_on_step_end"] = callback_inputs_subset
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
_ = pipe(**inputs)[0]
inputs["callback_on_step_end"] = callback_inputs_all
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
_ = pipe(**inputs)[0]
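# Finally, zero the latents at the last step via the callback and verify the pipeline
# consumes the modified tensor (the output stays bounded).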
def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
is_last = i == (pipe.num_timesteps - 1)
if is_last:
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
return callback_kwargs
inputs["callback_on_step_end"] = callback_inputs_change_tensor
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
output = pipe(**inputs)[0]
assert output.abs().sum() < 1e10
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=1e-2)
def test_attention_slicing_forward_pass(
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
):
if not getattr(self, "test_attention_slicing", True):
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
output_without_slicing = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=1)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing1 = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=2)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing2 = pipe(**inputs)[0]
if test_max_difference:
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
self.assertLess(
max(max_diff1, max_diff2),
expected_max_diff,
"Attention slicing should not affect the inference results",
)
def test_serialization_with_variants(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
model_components = [
component_name
for component_name, component in pipe.components.items()
if isinstance(component, torch.nn.Module)
]
# Remove components that aren't saved as standard diffusers models
if "safety_checker" in model_components:
model_components.remove("safety_checker")
variant = "fp16"
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False)
with open(f"{tmpdir}/model_index.json", "r") as f:
config = json.load(f)
for subfolder in os.listdir(tmpdir):
if not os.path.isfile(subfolder) and subfolder in model_components:
folder_path = os.path.join(tmpdir, subfolder)
is_folder = os.path.isdir(folder_path) and subfolder in config
assert is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path))
def test_torch_dtype_dict(self):
components = self.get_dummy_components()
if not components:
self.skipTest("No dummy components defined.")
pipe = self.pipeline_class(**components)
specified_key = next(iter(components.keys()))
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
pipe.save_pretrained(tmpdirname, safe_serialization=False)
torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16}
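# The explicitly listed component should load in bfloat16; all other torch modules fall
# back to the "default" entry (float16).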
loaded_pipe = self.pipeline_class.from_pretrained(
tmpdirname, safety_checker=DummyCosmosSafetyChecker(), torch_dtype=torch_dtype_dict
)
for name, component in loaded_pipe.components.items():
# Skip components that are not loaded from disk or have special handling
if name == "safety_checker":
continue
if isinstance(component, torch.nn.Module) and hasattr(component, "dtype"):
expected_dtype = torch_dtype_dict.get(name, torch_dtype_dict.get("default", torch.float32))
self.assertEqual(
component.dtype,
expected_dtype,
f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
)
def test_save_load_optional_components(self, expected_max_difference=1e-4):
self.pipeline_class._optional_components.remove("safety_checker")
super().test_save_load_optional_components(expected_max_difference=expected_max_difference)
self.pipeline_class._optional_components.append("safety_checker")
@unittest.skip(
"The pipeline should not be runnable without a safety checker. The test creates a pipeline without passing in "
"a safety checker, which makes the pipeline default to the actual Cosmos Guardrail. The Cosmos Guardrail is "
"too large and slow to run on CI."
)
def test_encode_prompt_works_in_isolation(self):
pass

View File

@@ -281,6 +281,86 @@ class GlmImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
# Should return 4 images (2 prompts × 2 images per prompt)
self.assertEqual(len(images), 4)
def test_prompt_with_prior_token_ids(self):
"""Test that prompt and prior_token_ids can be provided together.
When both are given, the AR generation step is skipped (prior_token_ids is used
directly) and prompt is used to generate prompt_embeds via the glyph encoder.
"""
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
height, width = 32, 32
# Step 1: Run with prompt only to get prior_token_ids from AR model
generator = torch.Generator(device=device).manual_seed(0)
prior_token_ids, _, _ = pipe.generate_prior_tokens(
prompt="A photo of a cat",
height=height,
width=width,
device=torch.device(device),
generator=torch.Generator(device=device).manual_seed(0),
)
# Step 2: Run with both prompt and prior_token_ids — should not raise
generator = torch.Generator(device=device).manual_seed(0)
inputs_both = {
"prompt": "A photo of a cat",
"prior_token_ids": prior_token_ids,
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 1.5,
"height": height,
"width": width,
"max_sequence_length": 16,
"output_type": "pt",
}
images = pipe(**inputs_both).images
self.assertEqual(len(images), 1)
self.assertEqual(images[0].shape, (3, 32, 32))
def test_check_inputs_rejects_invalid_combinations(self):
"""Test that check_inputs correctly rejects invalid input combinations."""
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
height, width = 32, 32
# Neither prompt nor prior_token_ids → error
with self.assertRaises(ValueError):
pipe.check_inputs(
prompt=None,
height=height,
width=width,
callback_on_step_end_tensor_inputs=None,
prompt_embeds=torch.randn(1, 16, 32),
)
# prior_token_ids alone without prompt or prompt_embeds → error
with self.assertRaises(ValueError):
pipe.check_inputs(
prompt=None,
height=height,
width=width,
callback_on_step_end_tensor_inputs=None,
prior_token_ids=torch.randint(0, 100, (1, 64)),
)
# prompt + prompt_embeds together → error
with self.assertRaises(ValueError):
pipe.check_inputs(
prompt="A cat",
height=height,
width=width,
callback_on_step_end_tensor_inputs=None,
prompt_embeds=torch.randn(1, 16, 32),
)
@unittest.skip("Needs to be revisited.")
def test_encode_prompt_works_in_isolation(self):
pass

View File

@@ -2406,7 +2406,11 @@ class PipelineTesterMixin:
if name not in [exclude_module_name] and isinstance(component, torch.nn.Module):
# `component.device` prints the `onload_device` type. We should probably override the
# `device` property in `ModelMixin`.
# previously: component_device = next(component.parameters())[0].device
# Skip modules with no parameters (e.g., dummy safety checkers with only buffers)
params = list(component.parameters())
if not params:
continue
component_device = params[0].device
self.assertTrue(torch.device(component_device).type == torch.device(offload_device).type)
@require_torch_accelerator