Compare commits

...

5 Commits

Author SHA1 Message Date
Sayak Paul
7f1334783f Merge branch 'main' into attention-backend-tests 2025-12-22 09:40:11 +05:30
MatrixTeam-AI
262ce19bff Feature: Add Mambo-G Guidance as Guider (#12862)
* Feature: Add Mambo-G Guidance to Qwen-Image Pipeline

* change to guider implementation

* fix copied code residual

* Update src/diffusers/guiders/magnitude_aware_guidance.py

* Apply style fixes

---------

Co-authored-by: Pscgylotti <pscgylotti@github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-12-19 13:10:40 -10:00
YiYi Xu
f7753b1bc8 more update in modular (#12560)
* move node registry to mellon

* up

* fix

* modular pipeline update: filter out none for input_names, fix default blocks for pipe.init() and allow user pass additional kwargs_type in a dict

* qwen modular refactor, unpack before decode

* update mellon node config, adding* to required_inputs and required_model_inputs

* modularpipeline.from_pretrained: error out if no config found

* add a component_names property to modular blocks to be consistent!

* flux image_encoder -> vae_encoder

* controlnet_bundle

* refactor MellonNodeConfig MellonPipelineConfig

* refactor & simplify mellon utils

* vae_image_encoder -> vae_encoder

* mellon config save keep key order

* style + copies

* add kwargs input for zimage
2025-12-18 19:25:20 -10:00
Miguel Martin
b5309683cb Cosmos Predict2.5 Base: inference pipeline, scheduler & chkpt conversion (#12852)
* cosmos predict2.5 base: convert chkpt & pipeline
- New scheduler: scheduling_flow_unipc_multistep.py
- Changes to TransformerCosmos for text embeddings via crossattn_proj

* scheduler cleanup

* simplify inference pipeline

* cleanup scheduler + tests

* Basic tests for flow unipc

* working b2b inference

* Rename everything

* Tests for pipeline present, but not working (predict2 also not working)

* docstring update

* wrapper pipelines + make style

* remove unnecessary files

* UniPCMultistep: support use_karras_sigmas=True and use_flow_sigmas=True

* use UniPCMultistepScheduler + fix tests for pipeline

* Remove FlowUniPCMultistepScheduler

* UniPCMultistepScheduler for use_flow_sigmas=True & use_karras_sigmas=True

* num_inference_steps=36 due to bug in scheduler used by predict2.5

* Address comments

* make style + make fix-copies

* fix tests + remove references to old pipelines

* address comments

* add revision in from_pretrained call

* fix tests
2025-12-19 05:38:18 +05:30
sayakpaul
cc5eaf0429 enhance attention backend tests 2025-12-12 13:25:34 +05:30
25 changed files with 2172 additions and 902 deletions

View File

@@ -70,6 +70,12 @@ output.save("output.png")
- all
- __call__
## Cosmos2_5_PredictBasePipeline
[[autodoc]] Cosmos2_5_PredictBasePipeline
- all
- __call__
## CosmosPipelineOutput
[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput

View File

@@ -1,11 +1,55 @@
"""
# Cosmos 2 Predict
Download checkpoint
```bash
hf download nvidia/Cosmos-Predict2-2B-Text2Image
```
convert checkpoint
```bash
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2-2B-Text2Image/snapshots/acdb5fde992a73ef0355f287977d002cbfd127e0/model.pt
python scripts/convert_cosmos_to_diffusers.py \
--transformer_ckpt_path $transformer_ckpt_path \
--transformer_type Cosmos-2.0-Diffusion-2B-Text2Image \
--text_encoder_path google-t5/t5-11b \
--tokenizer_path google-t5/t5-11b \
--vae_type wan2.1 \
--output_path converted/cosmos-p2-t2i-2b \
--save_pipeline
```
# Cosmos 2.5 Predict
Download checkpoint
```bash
hf download nvidia/Cosmos-Predict2.5-2B
```
Convert checkpoint
```bash
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/pre-trained/d20b7120-df3e-4911-919d-db6e08bad31c_ema_bf16.pt
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Predict-Base-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/cosmos-p2.5-base-2b \
--save_pipeline
```
"""
import argparse
import pathlib
import sys
from typing import Any, Dict
import torch
from accelerate import init_empty_weights
from huggingface_hub import snapshot_download
from transformers import T5EncoderModel, T5TokenizerFast
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration, T5EncoderModel, T5TokenizerFast
from diffusers import (
AutoencoderKLCosmos,
@@ -17,7 +61,9 @@ from diffusers import (
CosmosVideoToWorldPipeline,
EDMEulerScheduler,
FlowMatchEulerDiscreteScheduler,
UniPCMultistepScheduler,
)
from diffusers.pipelines.cosmos.pipeline_cosmos2_5_predict import Cosmos2_5_PredictBasePipeline
def remove_keys_(key: str, state_dict: Dict[str, Any]):
@@ -233,6 +279,25 @@ TRANSFORMER_CONFIGS = {
"concat_padding_mask": True,
"extra_pos_embed_type": None,
},
"Cosmos-2.5-Predict-Base-2B": {
"in_channels": 16 + 1,
"out_channels": 16,
"num_attention_heads": 16,
"attention_head_dim": 128,
"num_layers": 28,
"mlp_ratio": 4.0,
"text_embed_dim": 1024,
"adaln_lora_dim": 256,
"max_size": (128, 240, 240),
"patch_size": (1, 2, 2),
"rope_scale": (1.0, 3.0, 3.0),
"concat_padding_mask": True,
# NOTE: source config has pos_emb_learnable: 'True' - but params are missing
"extra_pos_embed_type": None,
"use_crossattn_projection": True,
"crossattn_proj_in_channels": 100352,
"encoder_hidden_states_channels": 1024,
},
}
VAE_KEYS_RENAME_DICT = {
@@ -334,6 +399,9 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
elif "Cosmos-2.0" in transformer_type:
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
elif "Cosmos-2.5" in transformer_type:
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
else:
assert False
@@ -347,6 +415,7 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
new_key = new_key.removeprefix(PREFIX_KEY)
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
print(key, "->", new_key, flush=True)
update_state_dict_(original_state_dict, key, new_key)
for key in list(original_state_dict.keys()):
@@ -355,6 +424,21 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
continue
handler_fn_inplace(key, original_state_dict)
expected_keys = set(transformer.state_dict().keys())
mapped_keys = set(original_state_dict.keys())
missing_keys = expected_keys - mapped_keys
unexpected_keys = mapped_keys - expected_keys
if missing_keys:
print(f"ERROR: missing keys ({len(missing_keys)} from state_dict:", flush=True, file=sys.stderr)
for k in missing_keys:
print(k)
sys.exit(1)
if unexpected_keys:
print(f"ERROR: unexpected keys ({len(unexpected_keys)}) from state_dict:", flush=True, file=sys.stderr)
for k in unexpected_keys:
print(k)
sys.exit(2)
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
return transformer
@@ -444,6 +528,34 @@ def save_pipeline_cosmos_2_0(args, transformer, vae):
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
def save_pipeline_cosmos2_5(args, transformer, vae):
    """Assemble a Cosmos 2.5 Predict base pipeline from converted parts and save it.

    Loads the text encoder and tokenizer (falling back to the official defaults
    when the CLI paths are not given), builds the flow-matching UniPC scheduler
    with the Cosmos 2.5 sigma range, wires everything into
    ``Cosmos2_5_PredictBasePipeline``, and serializes it to ``args.output_path``.
    """
    # Default to the official Cosmos-Reason1 text encoder and the Qwen2.5-VL
    # tokenizer when the user did not override them on the command line.
    text_encoder_path = args.text_encoder_path or "nvidia/Cosmos-Reason1-7B"
    tokenizer_path = args.tokenizer_path or "Qwen/Qwen2.5-VL-7B-Instruct"

    text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        text_encoder_path, torch_dtype="auto", device_map="cpu"
    )
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    # Flow-prediction UniPC scheduler; sigma range matches the Cosmos 2.5
    # training configuration (see PR discussion for num_inference_steps caveat).
    scheduler = UniPCMultistepScheduler(
        use_karras_sigmas=True,
        use_flow_sigmas=True,
        prediction_type="flow_prediction",
        sigma_max=200.0,
        sigma_min=0.01,
    )

    pipe = Cosmos2_5_PredictBasePipeline(
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        transformer=transformer,
        vae=vae,
        scheduler=scheduler,
        # No-op safety checker: conversion output ships without one.
        safety_checker=lambda *args, **kwargs: None,
    )
    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
@@ -451,10 +563,10 @@ def get_args():
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
)
parser.add_argument(
"--vae_type", type=str, default=None, choices=["none", *list(VAE_CONFIGS.keys())], help="Type of VAE"
"--vae_type", type=str, default="wan2.1", choices=["wan2.1", *list(VAE_CONFIGS.keys())], help="Type of VAE"
)
parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
parser.add_argument("--text_encoder_path", type=str, default=None)
parser.add_argument("--tokenizer_path", type=str, default=None)
parser.add_argument("--save_pipeline", action="store_true")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
@@ -477,8 +589,6 @@ if __name__ == "__main__":
if args.save_pipeline:
assert args.transformer_ckpt_path is not None
assert args.vae_type is not None
assert args.text_encoder_path is not None
assert args.tokenizer_path is not None
if args.transformer_ckpt_path is not None:
weights_only = "Cosmos-1.0" in args.transformer_type
@@ -490,17 +600,26 @@ if __name__ == "__main__":
if args.vae_type is not None:
if "Cosmos-1.0" in args.transformer_type:
vae = convert_vae(args.vae_type)
else:
elif "Cosmos-2.0" in args.transformer_type or "Cosmos-2.5" in args.transformer_type:
vae = AutoencoderKLWan.from_pretrained(
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)
else:
raise AssertionError(f"{args.transformer_type} not supported")
if not args.save_pipeline:
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
if args.save_pipeline:
if "Cosmos-1.0" in args.transformer_type:
assert args.text_encoder_path is not None
assert args.tokenizer_path is not None
save_pipeline_cosmos_1_0(args, transformer, vae)
elif "Cosmos-2.0" in args.transformer_type:
assert args.text_encoder_path is not None
assert args.tokenizer_path is not None
save_pipeline_cosmos_2_0(args, transformer, vae)
elif "Cosmos-2.5" in args.transformer_type:
save_pipeline_cosmos2_5(args, transformer, vae)
else:
assert False
raise AssertionError(f"{args.transformer_type} not supported")

View File

@@ -463,6 +463,7 @@ else:
"CogView4ControlPipeline",
"CogView4Pipeline",
"ConsisIDPipeline",
"Cosmos2_5_PredictBasePipeline",
"Cosmos2TextToImagePipeline",
"Cosmos2VideoToWorldPipeline",
"CosmosTextToWorldPipeline",
@@ -1175,6 +1176,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogView4ControlPipeline,
CogView4Pipeline,
ConsisIDPipeline,
Cosmos2_5_PredictBasePipeline,
Cosmos2TextToImagePipeline,
Cosmos2VideoToWorldPipeline,
CosmosTextToWorldPipeline,

View File

@@ -25,6 +25,7 @@ if is_torch_available():
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
from .guider_utils import BaseGuidance
from .magnitude_aware_guidance import MagnitudeAwareGuidance
from .perturbed_attention_guidance import PerturbedAttentionGuidance
from .skip_layer_guidance import SkipLayerGuidance
from .smoothed_energy_guidance import SmoothedEnergyGuidance

View File

@@ -0,0 +1,159 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class MagnitudeAwareGuidance(BaseGuidance):
    """
    Magnitude-Aware Mitigation for Boosted Guidance (MAMBO-G): https://huggingface.co/papers/2508.03442

    Args:
        guidance_scale (`float`, defaults to `10.0`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
            deterioration of image quality.
        alpha (`float`, defaults to `8.0`):
            The alpha parameter for the magnitude-aware guidance. Higher values cause more aggressive suppression of
            guidance scale when the magnitude of the guidance update is large.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
        enabled (`bool`, defaults to `True`):
            Whether the guider is enabled at all; when `False` the conditional prediction is returned unchanged.
    """

    # Names of the model predictions this guider consumes, in batch order.
    _input_predictions = ["pred_cond", "pred_uncond"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 10.0,
        alpha: float = 8.0,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
        enabled: bool = True,
    ):
        super().__init__(start, stop, enabled)

        self.guidance_scale = guidance_scale
        self.alpha = alpha
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

    def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
        """Split the (cond, uncond) input tuples into one batch per needed condition."""
        # One batch (conditional only) when guidance is inactive, two otherwise.
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
            data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
            data_batches.append(data_batch)
        return data_batches

    def prepare_inputs_from_block_state(
        self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
    ) -> List["BlockState"]:
        """Same as `prepare_inputs`, but reading fields from a modular-pipeline BlockState."""
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
            data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
            data_batches.append(data_batch)
        return data_batches

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
        """Combine conditional/unconditional predictions with magnitude-aware CFG."""
        pred = None

        if not self._is_mambo_g_enabled():
            # Guidance inactive: pass the conditional prediction through.
            pred = pred_cond
        else:
            pred = mambo_guidance(
                pred_cond,
                pred_uncond,
                self.guidance_scale,
                self.alpha,
                self.use_original_formulation,
            )

        # Optional CFG rescale (Sec. 3.4 of https://huggingface.co/papers/2305.08891).
        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

    @property
    def is_conditional(self) -> bool:
        # The first prepared batch is the conditional one.
        return self._count_prepared == 1

    @property
    def num_conditions(self) -> int:
        # 1 (cond only) when guidance is inactive, 2 (cond + uncond) when active.
        num_conditions = 1
        if self._is_mambo_g_enabled():
            num_conditions += 1
        return num_conditions

    def _is_mambo_g_enabled(self) -> bool:
        """True when guidance should be applied at the current denoising step."""
        if not self._enabled:
            return False

        # Active only inside the [start, stop) fraction of the step schedule.
        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        # Skip when the scale is at its identity value (0.0 for the original
        # formulation, 1.0 for the diffusers-native one).
        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close
def mambo_guidance(
    pred_cond: torch.Tensor,
    pred_uncond: torch.Tensor,
    guidance_scale: float,
    alpha: float = 8.0,
    use_original_formulation: bool = False,
):
    """Apply Magnitude-Aware Mitigation for Boosted Guidance (MAMBO-G).

    The effective guidance scale is damped by ``exp(-alpha * r)`` where ``r`` is
    the per-sample ratio ``||pred_cond - pred_uncond|| / ||pred_uncond||``, so
    large guidance updates are suppressed to avoid over-saturation.
    """
    # Reduce over every non-batch dimension so each sample gets its own scale.
    reduce_dims = list(range(1, pred_cond.ndim))
    update = pred_cond - pred_uncond
    update_norm = torch.norm(update, dim=reduce_dims, keepdim=True)
    uncond_norm = torch.norm(pred_uncond, dim=reduce_dims, keepdim=True)
    damping = torch.exp(-alpha * (update_norm / uncond_norm))

    if use_original_formulation:
        # Original paper formulation: scale the update added to the conditional branch.
        effective_scale = guidance_scale * damping
        base = pred_cond
    else:
        # Diffusers-native CFG: uncond + (1 + (s - 1) * damping) * (cond - uncond).
        effective_scale = 1.0 + (guidance_scale - 1.0) * damping
        base = pred_uncond

    return base + effective_scale * update

View File

@@ -439,6 +439,9 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
concat_padding_mask: bool = True,
extra_pos_embed_type: Optional[str] = "learnable",
use_crossattn_projection: bool = False,
crossattn_proj_in_channels: int = 1024,
encoder_hidden_states_channels: int = 1024,
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
@@ -485,6 +488,12 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, bias=False
)
if self.config.use_crossattn_projection:
self.crossattn_proj = nn.Sequential(
nn.Linear(crossattn_proj_in_channels, encoder_hidden_states_channels, bias=True),
nn.GELU(),
)
self.gradient_checkpointing = False
def forward(
@@ -524,6 +533,7 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p_h
post_patch_width = width // p_w
hidden_states = self.patch_embed(hidden_states)
hidden_states = hidden_states.flatten(1, 3) # [B, T, H, W, C] -> [B, THW, C]
@@ -546,6 +556,9 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
else:
assert False
if self.config.use_crossattn_projection:
encoder_hidden_states = self.crossattn_proj(encoder_hidden_states)
# 5. Transformer blocks
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:

View File

@@ -360,7 +360,7 @@ class FluxKontextCoreDenoiseStep(SequentialPipelineBlocks):
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", FluxTextEncoderStep()),
("image_encoder", FluxAutoVaeEncoderStep()),
("vae_encoder", FluxAutoVaeEncoderStep()),
("denoise", FluxCoreDenoiseStep()),
("decode", FluxDecodeStep()),
]
@@ -369,7 +369,7 @@ AUTO_BLOCKS = InsertableDict(
AUTO_BLOCKS_KONTEXT = InsertableDict(
[
("text_encoder", FluxTextEncoderStep()),
("image_encoder", FluxKontextAutoVaeEncoderStep()),
("vae_encoder", FluxKontextAutoVaeEncoderStep()),
("denoise", FluxKontextCoreDenoiseStep()),
("decode", FluxDecodeStep()),
]

File diff suppressed because it is too large Load Diff

View File

@@ -501,15 +501,19 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
@property
def input_names(self) -> List[str]:
return [input_param.name for input_param in self.inputs]
return [input_param.name for input_param in self.inputs if input_param.name is not None]
@property
def intermediate_output_names(self) -> List[str]:
return [output_param.name for output_param in self.intermediate_outputs]
return [output_param.name for output_param in self.intermediate_outputs if output_param.name is not None]
@property
def output_names(self) -> List[str]:
return [output_param.name for output_param in self.outputs]
return [output_param.name for output_param in self.outputs if output_param.name is not None]
@property
def component_names(self) -> List[str]:
return [component.name for component in self.expected_components]
@property
def doc(self):
@@ -1525,10 +1529,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
if blocks is None:
if modular_config_dict is not None:
blocks_class_name = modular_config_dict.get("_blocks_class_name")
elif config_dict is not None:
blocks_class_name = self.get_default_blocks_name(config_dict)
else:
blocks_class_name = None
blocks_class_name = self.get_default_blocks_name(config_dict)
if blocks_class_name is not None:
diffusers_module = importlib.import_module("diffusers")
blocks_class = getattr(diffusers_module, blocks_class_name)
@@ -1625,7 +1627,10 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
return None, config_dict
except EnvironmentError as e:
logger.debug(f" model_index.json not found in the repo: {e}")
raise EnvironmentError(
f"Failed to load config from '{pretrained_model_name_or_path}'. "
f"Could not find or load 'modular_model_index.json' or 'model_index.json'."
) from e
return None, None
@@ -2550,7 +2555,11 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
kwargs_type = expected_input_param.kwargs_type
if name in passed_kwargs:
state.set(name, passed_kwargs.pop(name), kwargs_type)
elif name not in state.values:
elif kwargs_type is not None and kwargs_type in passed_kwargs:
kwargs_dict = passed_kwargs.pop(kwargs_type)
for k, v in kwargs_dict.items():
state.set(k, v, kwargs_type)
elif name is not None and name not in state.values:
state.set(name, default, kwargs_type)
# Warn about unexpected inputs

View File

@@ -30,6 +30,47 @@ from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
logger = logging.get_logger(__name__)
class QwenImageAfterDenoiseStep(ModularPipelineBlocks):
    """Post-denoise modular block that unpacks packed QwenImage latents.

    Converts the denoiser's packed ``(batch, sequence, channels)`` latents back
    into the image-shaped ``(batch, channels, 1, height, width)`` layout the
    decode step expects, using the pipeline's pachifier.
    """

    model_name = "qwenimage"

    @property
    def description(self) -> str:
        return "Step that unpack the latents from 3D tensor (batch_size, sequence_length, channels) into 5D tensor (batch_size, channels, 1, height, width)"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        # Only the pachifier is required; it is built from config (no weights to load).
        components = [
            ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
        ]
        return components

    @property
    def inputs(self) -> List[InputParam]:
        # height/width are required to reconstruct the spatial layout from the
        # flattened sequence dimension.
        return [
            InputParam(name="height", required=True),
            InputParam(name="width", required=True),
            InputParam(
                name="latents",
                required=True,
                type_hint=torch.Tensor,
                description="The latents to decode, can be generated in the denoise step",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        vae_scale_factor = components.vae_scale_factor
        block_state.latents = components.pachifier.unpack_latents(
            block_state.latents, block_state.height, block_state.width, vae_scale_factor=vae_scale_factor
        )

        self.set_block_state(state, block_state)
        # NOTE(review): returns (components, state) while annotated -> PipelineState;
        # this matches the other blocks in this file — confirm the annotation upstream.
        return components, state
class QwenImageDecoderStep(ModularPipelineBlocks):
model_name = "qwenimage"
@@ -41,7 +82,6 @@ class QwenImageDecoderStep(ModularPipelineBlocks):
def expected_components(self) -> List[ComponentSpec]:
components = [
ComponentSpec("vae", AutoencoderKLQwenImage),
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
]
return components
@@ -49,8 +89,6 @@ class QwenImageDecoderStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="height", required=True),
InputParam(name="width", required=True),
InputParam(
name="latents",
required=True,
@@ -74,10 +112,12 @@ class QwenImageDecoderStep(ModularPipelineBlocks):
block_state = self.get_block_state(state)
# YiYi Notes: remove support for output_type = "latents', we can just skip decode/encode step in modular
vae_scale_factor = components.vae_scale_factor
block_state.latents = components.pachifier.unpack_latents(
block_state.latents, block_state.height, block_state.width, vae_scale_factor=vae_scale_factor
)
if block_state.latents.ndim == 4:
block_state.latents = block_state.latents.unsqueeze(dim=1)
elif block_state.latents.ndim != 5:
raise ValueError(
f"expect latents to be a 4D or 5D tensor but got: {block_state.latents.shape}. Please make sure the latents are unpacked before decode step."
)
block_state.latents = block_state.latents.to(components.vae.dtype)
latents_mean = (

View File

@@ -26,7 +26,12 @@ from .before_denoise import (
QwenImageSetTimestepsStep,
QwenImageSetTimestepsWithStrengthStep,
)
from .decoders import QwenImageDecoderStep, QwenImageInpaintProcessImagesOutputStep, QwenImageProcessImagesOutputStep
from .decoders import (
QwenImageAfterDenoiseStep,
QwenImageDecoderStep,
QwenImageInpaintProcessImagesOutputStep,
QwenImageProcessImagesOutputStep,
)
from .denoise import (
QwenImageControlNetDenoiseStep,
QwenImageDenoiseStep,
@@ -92,6 +97,7 @@ TEXT2IMAGE_BLOCKS = InsertableDict(
("set_timesteps", QwenImageSetTimestepsStep()),
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
("denoise", QwenImageDenoiseStep()),
("after_denoise", QwenImageAfterDenoiseStep()),
("decode", QwenImageDecodeStep()),
]
)
@@ -205,6 +211,7 @@ INPAINT_BLOCKS = InsertableDict(
("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
("denoise", QwenImageInpaintDenoiseStep()),
("after_denoise", QwenImageAfterDenoiseStep()),
("decode", QwenImageInpaintDecodeStep()),
]
)
@@ -264,6 +271,7 @@ IMAGE2IMAGE_BLOCKS = InsertableDict(
("prepare_img2img_latents", QwenImagePrepareLatentsWithStrengthStep()),
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
("denoise", QwenImageDenoiseStep()),
("after_denoise", QwenImageAfterDenoiseStep()),
("decode", QwenImageDecodeStep()),
]
)
@@ -529,8 +537,16 @@ class QwenImageCoreDenoiseStep(SequentialPipelineBlocks):
QwenImageAutoBeforeDenoiseStep,
QwenImageOptionalControlNetBeforeDenoiseStep,
QwenImageAutoDenoiseStep,
QwenImageAfterDenoiseStep,
]
block_names = [
"input",
"controlnet_input",
"before_denoise",
"controlnet_before_denoise",
"denoise",
"after_denoise",
]
block_names = ["input", "controlnet_input", "before_denoise", "controlnet_before_denoise", "denoise"]
@property
def description(self):
@@ -653,6 +669,7 @@ EDIT_BLOCKS = InsertableDict(
("set_timesteps", QwenImageSetTimestepsStep()),
("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
("denoise", QwenImageEditDenoiseStep()),
("after_denoise", QwenImageAfterDenoiseStep()),
("decode", QwenImageDecodeStep()),
]
)
@@ -702,6 +719,7 @@ EDIT_INPAINT_BLOCKS = InsertableDict(
("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
("denoise", QwenImageEditInpaintDenoiseStep()),
("after_denoise", QwenImageAfterDenoiseStep()),
("decode", QwenImageInpaintDecodeStep()),
]
)
@@ -841,8 +859,9 @@ class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks):
QwenImageEditAutoInputStep,
QwenImageEditAutoBeforeDenoiseStep,
QwenImageEditAutoDenoiseStep,
QwenImageAfterDenoiseStep,
]
block_names = ["input", "before_denoise", "denoise"]
block_names = ["input", "before_denoise", "denoise", "after_denoise"]
@property
def description(self):
@@ -954,6 +973,7 @@ EDIT_PLUS_BLOCKS = InsertableDict(
("set_timesteps", QwenImageSetTimestepsStep()),
("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()),
("denoise", QwenImageEditDenoiseStep()),
("after_denoise", QwenImageAfterDenoiseStep()),
("decode", QwenImageDecodeStep()),
]
)
@@ -1037,8 +1057,9 @@ class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
QwenImageEditPlusAutoInputStep,
QwenImageEditPlusAutoBeforeDenoiseStep,
QwenImageEditAutoDenoiseStep,
QwenImageAfterDenoiseStep,
]
block_names = ["input", "before_denoise", "denoise"]
block_names = ["input", "before_denoise", "denoise", "after_denoise"]
@property
def description(self):

View File

@@ -1,95 +0,0 @@
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# mellon nodes
# Maps each Mellon node type for QwenImage to the modular-pipeline parameters it
# exposes: "inputs" are per-run values, "model_inputs" are loaded components,
# "outputs" are the values the node produces, and "block_names" (where present)
# names the pipeline blocks the node executes.
# NOTE(review): unlike the SDXL map, the vae_encoder/text_encoder/decoder
# entries here carry no "block_names" key — confirm that is intentional.
QwenImage_NODE_TYPES_PARAMS_MAP = {
    "controlnet": {
        "inputs": [
            "control_image",
            "controlnet_conditioning_scale",
            "control_guidance_start",
            "control_guidance_end",
            "height",
            "width",
        ],
        "model_inputs": [
            "controlnet",
            "vae",
        ],
        "outputs": [
            "controlnet_out",
        ],
        "block_names": ["controlnet_vae_encoder"],
    },
    "denoise": {
        "inputs": [
            "embeddings",
            "width",
            "height",
            "seed",
            "num_inference_steps",
            "guidance_scale",
            "image_latents",
            "strength",
            "controlnet",
        ],
        "model_inputs": [
            "unet",
            "guider",
            "scheduler",
        ],
        "outputs": [
            "latents",
            "latents_preview",
        ],
        "block_names": ["denoise"],
    },
    "vae_encoder": {
        "inputs": [
            "image",
            "width",
            "height",
        ],
        "model_inputs": [
            "vae",
        ],
        "outputs": [
            "image_latents",
        ],
    },
    "text_encoder": {
        "inputs": [
            "prompt",
            "negative_prompt",
        ],
        "model_inputs": [
            "text_encoders",
        ],
        "outputs": [
            "embeddings",
        ],
    },
    "decoder": {
        "inputs": [
            "latents",
        ],
        "model_inputs": [
            "vae",
        ],
        "outputs": [
            "images",
        ],
    },
}

View File

@@ -1,99 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Maps each Mellon node type for SDXL to the modular-pipeline parameters it
# exposes: "inputs" are per-run values, "model_inputs" are loaded components,
# "outputs" are the values the node produces, and "block_names" names the
# pipeline blocks the node executes ([None] for the controlnet adapter node).
SDXL_NODE_TYPES_PARAMS_MAP = {
    "controlnet": {
        "inputs": [
            "control_image",
            "controlnet_conditioning_scale",
            "control_guidance_start",
            "control_guidance_end",
            "height",
            "width",
        ],
        "model_inputs": [
            "controlnet",
        ],
        "outputs": [
            "controlnet_out",
        ],
        "block_names": [None],
    },
    "denoise": {
        "inputs": [
            "embeddings",
            "width",
            "height",
            "seed",
            "num_inference_steps",
            "guidance_scale",
            "image_latents",
            "strength",
            # custom adapters coming in as inputs
            "controlnet",
            # ip_adapter is optional and custom; include if available
            "ip_adapter",
        ],
        "model_inputs": [
            "unet",
            "guider",
            "scheduler",
        ],
        "outputs": [
            "latents",
            "latents_preview",
        ],
        "block_names": ["denoise"],
    },
    "vae_encoder": {
        "inputs": [
            "image",
            "width",
            "height",
        ],
        "model_inputs": [
            "vae",
        ],
        "outputs": [
            "image_latents",
        ],
        "block_names": ["vae_encoder"],
    },
    "text_encoder": {
        "inputs": [
            "prompt",
            "negative_prompt",
        ],
        "model_inputs": [
            "text_encoders",
        ],
        "outputs": [
            "embeddings",
        ],
        "block_names": ["text_encoder"],
    },
    "decoder": {
        "inputs": [
            "latents",
        ],
        "model_inputs": [
            "vae",
        ],
        "outputs": [
            "images",
        ],
        "block_names": ["decode"],
    },
}

View File

@@ -129,6 +129,10 @@ class ZImageLoopDenoiser(ModularPipelineBlocks):
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
kwargs_type="denoiser_input_fields",
description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
),
]
guider_input_names = []
uncond_guider_input_names = []

View File

@@ -119,7 +119,7 @@ class ZImageAutoDenoiseStep(AutoPipelineBlocks):
class ZImageAutoVaeImageEncoderStep(AutoPipelineBlocks):
block_classes = [ZImageVaeImageEncoderStep]
block_names = ["vae_image_encoder"]
block_names = ["vae_encoder"]
block_trigger_inputs = ["image"]
@property
@@ -137,7 +137,7 @@ class ZImageAutoBlocks(SequentialPipelineBlocks):
ZImageAutoDenoiseStep,
ZImageVaeDecoderStep,
]
block_names = ["text_encoder", "vae_image_encoder", "denoise", "decode"]
block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]
@property
def description(self) -> str:
@@ -162,7 +162,7 @@ TEXT2IMAGE_BLOCKS = InsertableDict(
IMAGE2IMAGE_BLOCKS = InsertableDict(
[
("text_encoder", ZImageTextEncoderStep),
("vae_image_encoder", ZImageVaeImageEncoderStep),
("vae_encoder", ZImageVaeImageEncoderStep),
("input", ZImageTextInputStep),
("additional_inputs", ZImageAdditionalInputsStep(image_latent_inputs=["image_latents"])),
("prepare_latents", ZImagePrepareLatentsStep),
@@ -178,7 +178,7 @@ IMAGE2IMAGE_BLOCKS = InsertableDict(
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", ZImageTextEncoderStep),
("vae_image_encoder", ZImageAutoVaeImageEncoderStep),
("vae_encoder", ZImageAutoVaeImageEncoderStep),
("denoise", ZImageAutoDenoiseStep),
("decode", ZImageVaeDecoderStep),
]

View File

@@ -165,6 +165,7 @@ else:
_import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"]
_import_structure["consisid"] = ["ConsisIDPipeline"]
_import_structure["cosmos"] = [
"Cosmos2_5_PredictBasePipeline",
"Cosmos2TextToImagePipeline",
"CosmosTextToWorldPipeline",
"CosmosVideoToWorldPipeline",
@@ -622,6 +623,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLControlNetXSPipeline,
)
from .cosmos import (
Cosmos2_5_PredictBasePipeline,
Cosmos2TextToImagePipeline,
Cosmos2VideoToWorldPipeline,
CosmosTextToWorldPipeline,

View File

@@ -22,6 +22,9 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_cosmos2_5_predict"] = [
"Cosmos2_5_PredictBasePipeline",
]
_import_structure["pipeline_cosmos2_text2image"] = ["Cosmos2TextToImagePipeline"]
_import_structure["pipeline_cosmos2_video2world"] = ["Cosmos2VideoToWorldPipeline"]
_import_structure["pipeline_cosmos_text2world"] = ["CosmosTextToWorldPipeline"]
@@ -35,6 +38,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_cosmos2_5_predict import (
Cosmos2_5_PredictBasePipeline,
)
from .pipeline_cosmos2_text2image import Cosmos2TextToImagePipeline
from .pipeline_cosmos2_video2world import Cosmos2VideoToWorldPipeline
from .pipeline_cosmos_text2world import CosmosTextToWorldPipeline

View File

@@ -0,0 +1,847 @@
# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict, List, Optional, Union
import numpy as np
import torch
import torchvision
import torchvision.transforms
import torchvision.transforms.functional
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput
from ...models import AutoencoderKLWan, CosmosTransformer3DModel
from ...schedulers import UniPCMultistepScheduler
from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import CosmosPipelineOutput
# `cosmos_guardrail` is an optional dependency. When absent, install a stub class that
# fails loudly at construction time, so the missing package surfaces as a clear error
# only when a safety checker is actually needed.
if is_cosmos_guardrail_available():
    from cosmos_guardrail import CosmosSafetyChecker
else:
    class CosmosSafetyChecker:
        # Placeholder that raises on instantiation when `cosmos_guardrail` is not installed.
        def __init__(self, *args, **kwargs):
            raise ImportError(
                "`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`."
            )
# torch_xla is optional; XLA_AVAILABLE gates the `xm.mark_step()` calls in the denoising loop.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm
    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Pull a latent tensor out of a VAE encoder output.

    Prefers the distribution interface (`latent_dist`): draws a sample when
    `sample_mode == "sample"` (optionally seeded by `generator`) or takes the
    mode when `sample_mode == "argmax"`. Falls back to a plain `latents`
    attribute; raises `AttributeError` if neither is available.
    """
    if hasattr(encoder_output, "latent_dist"):
        if sample_mode == "sample":
            return encoder_output.latent_dist.sample(generator)
        if sample_mode == "argmax":
            return encoder_output.latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import Cosmos2_5_PredictBasePipeline
>>> from diffusers.utils import export_to_video, load_image, load_video
>>> model_id = "nvidia/Cosmos-Predict2.5-2B"
>>> pipe = Cosmos2_5_PredictBasePipeline.from_pretrained(
... model_id, revision="diffusers/base/pre-trianed", torch_dtype=torch.bfloat16
... )
>>> pipe = pipe.to("cuda")
>>> # Common negative prompt reused across modes.
>>> negative_prompt = (
... "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
... "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
... "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky "
... "movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
... "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
... "Overall, the video is of poor quality."
... )
>>> # Text2World: generate a 93-frame world video from text only.
>>> prompt = (
... "As the red light shifts to green, the red bus at the intersection begins to move forward, its headlights "
... "cutting through the falling snow. The snowy tire tracks deepen as the vehicle inches ahead, casting fresh "
... "lines onto the slushy road. Around it, streetlights glow warmer, illuminating the drifting flakes and wet "
... "reflections on the asphalt. Other cars behind start to edge forward, their beams joining the scene. "
... "The stillness of the urban street transitions into motion as the quiet snowfall is punctuated by the slow "
... "advance of traffic through the frosty city corridor."
... )
>>> video = pipe(
... image=None,
... video=None,
... prompt=prompt,
... negative_prompt=negative_prompt,
... num_frames=93,
... generator=torch.Generator().manual_seed(1),
... ).frames[0]
>>> export_to_video(video, "text2world.mp4", fps=16)
>>> # Image2World: condition on a single image and generate a 93-frame world video.
>>> prompt = (
... "A high-definition video captures the precision of robotic welding in an industrial setting. "
... "The first frame showcases a robotic arm, equipped with a welding torch, positioned over a large metal structure. "
... "The welding process is in full swing, with bright sparks and intense light illuminating the scene, creating a vivid "
... "display of blue and white hues. A significant amount of smoke billows around the welding area, partially obscuring "
... "the view but emphasizing the heat and activity. The background reveals parts of the workshop environment, including a "
... "ventilation system and various pieces of machinery, indicating a busy and functional industrial workspace. As the video "
... "progresses, the robotic arm maintains its steady position, continuing the welding process and moving to its left. "
... "The welding torch consistently emits sparks and light, and the smoke continues to rise, diffusing slightly as it moves upward. "
... "The metal surface beneath the torch shows ongoing signs of heating and melting. The scene retains its industrial ambiance, with "
... "the welding sparks and smoke dominating the visual field, underscoring the ongoing nature of the welding operation."
... )
>>> image = load_image(
... "https://media.githubusercontent.com/media/nvidia-cosmos/cosmos-predict2.5/refs/heads/main/assets/base/robot_welding.jpg"
... )
>>> video = pipe(
... image=image,
... video=None,
... prompt=prompt,
... negative_prompt=negative_prompt,
... num_frames=93,
... generator=torch.Generator().manual_seed(1),
... ).frames[0]
>>> # export_to_video(video, "image2world.mp4", fps=16)
>>> # Video2World: condition on an input clip and predict a 93-frame world video.
>>> prompt = (
... "The video opens with an aerial view of a large-scale sand mining construction operation, showcasing extensive piles "
... "of brown sand meticulously arranged in parallel rows. A central water channel, fed by a water pipe, flows through the "
... "middle of these sand heaps, creating ripples and movement as it cascades down. The surrounding area features dense green "
... "vegetation on the left, contrasting with the sandy terrain, while a body of water is visible in the background on the right. "
... "As the video progresses, a piece of heavy machinery, likely a bulldozer, enters the frame from the right, moving slowly along "
... "the edge of the sand piles. This machinery's presence indicates ongoing construction work in the operation. The final frame "
... "captures the same scene, with the water continuing its flow and the bulldozer still in motion, maintaining the dynamic yet "
... "steady pace of the construction activity."
... )
>>> input_video = load_video(
... "https://github.com/nvidia-cosmos/cosmos-predict2.5/raw/refs/heads/main/assets/base/sand_mining.mp4"
... )
>>> video = pipe(
... image=None,
... video=input_video,
... prompt=prompt,
... negative_prompt=negative_prompt,
... num_frames=93,
... generator=torch.Generator().manual_seed(1),
... ).frames[0]
>>> export_to_video(video, "video2world.mp4", fps=16)
>>> # To produce an image instead of a world (video) clip, set num_frames=1 and
>>> # save the first frame: pipe(..., num_frames=1).frames[0][0].
```
"""
class Cosmos2_5_PredictBasePipeline(DiffusionPipeline):
    r"""
    Pipeline for [Cosmos Predict2.5](https://github.com/nvidia-cosmos/cosmos-predict2.5) base model.
    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).
    Args:
        text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
            Frozen text-encoder. Cosmos Predict2.5 uses the [Qwen2.5
            VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) encoder.
        tokenizer (`AutoTokenizer`):
            Tokenizer associated with the Qwen2.5 VL encoder.
        transformer ([`CosmosTransformer3DModel`]):
            Conditional Transformer to denoise the encoded image latents.
        scheduler ([`UniPCMultistepScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
        vae ([`AutoencoderKLWan`]):
            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
    """
    model_cpu_offload_seq = "text_encoder->transformer->vae"
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
    # We mark safety_checker as optional here to get around some test failures, but it is not really optional
    _optional_components = ["safety_checker"]
    _exclude_from_cpu_offload = ["safety_checker"]
    def __init__(
        self,
        text_encoder: Qwen2_5_VLForConditionalGeneration,
        tokenizer: AutoTokenizer,
        transformer: CosmosTransformer3DModel,
        vae: AutoencoderKLWan,
        scheduler: UniPCMultistepScheduler,
        safety_checker: CosmosSafetyChecker = None,
    ):
        super().__init__()
        # The guardrail is effectively mandatory (see the license check in `__call__`);
        # instantiating the stub raises ImportError when `cosmos_guardrail` is missing.
        if safety_checker is None:
            safety_checker = CosmosSafetyChecker()
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
            safety_checker=safety_checker,
        )
        # NOTE: `temperal_downsample` is the (upstream-misspelled) attribute name exposed by
        # the Wan VAE; the fallbacks (4 / 8) apply only when no VAE is registered.
        self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
        self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
        # Per-channel latent statistics used to normalize/denormalize VAE latents,
        # reshaped to broadcast over (B, C, T, H, W).
        latents_mean = (
            torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).float()
            if getattr(self.vae.config, "latents_mean", None) is not None
            else None
        )
        latents_std = (
            torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).float()
            if getattr(self.vae.config, "latents_std", None) is not None
            else None
        )
        self.latents_mean = latents_mean
        self.latents_std = latents_std
        if self.latents_mean is None or self.latents_std is None:
            raise ValueError("VAE configuration must define both `latents_mean` and `latents_std`.")
    def _get_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Encode prompts through the Qwen2.5-VL chat template and return text embeddings.

        Each prompt is wrapped in a fixed system/user conversation, tokenized to exactly
        `max_sequence_length` tokens (padded/truncated), and run through the text encoder
        with `output_hidden_states=True`. The hidden states of every layer (skipping the
        embedding layer at index 0) are layer-normalized (mean/std over the feature dim)
        and concatenated along the feature dimension.
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype
        prompt = [prompt] if isinstance(prompt, str) else prompt
        input_ids_batch = []
        for sample_idx in range(len(prompt)):
            conversations = [
                {
                    "role": "system",
                    "content": [
                        {
                            "type": "text",
                            "text": "You are a helpful assistant who will provide prompts to an image generator.",
                        }
                    ],
                },
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": prompt[sample_idx],
                        }
                    ],
                },
            ]
            input_ids = self.tokenizer.apply_chat_template(
                conversations,
                tokenize=True,
                add_generation_prompt=False,
                add_vision_id=False,
                max_length=max_sequence_length,
                truncation=True,
                padding="max_length",
            )
            input_ids = torch.LongTensor(input_ids)
            input_ids_batch.append(input_ids)
        input_ids_batch = torch.stack(input_ids_batch, dim=0)
        outputs = self.text_encoder(
            input_ids_batch.to(device),
            output_hidden_states=True,
        )
        hidden_states = outputs.hidden_states
        normalized_hidden_states = []
        # Start at 1: index 0 is the embedding-layer output, which is skipped.
        for layer_idx in range(1, len(hidden_states)):
            normalized_state = (hidden_states[layer_idx] - hidden_states[layer_idx].mean(dim=-1, keepdim=True)) / (
                hidden_states[layer_idx].std(dim=-1, keepdim=True) + 1e-8
            )
            normalized_hidden_states.append(normalized_state)
        # Concatenate all per-layer normalized states along the feature dimension.
        prompt_embeds = torch.cat(normalized_hidden_states, dim=-1)
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
        return prompt_embeds
    # Modified from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.
        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype
        """
        device = device or self._execution_device
        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]
        if prompt_embeds is None:
            prompt_embeds = self._get_prompt_embeds(
                prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
            )
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            negative_prompt_embeds = self._get_prompt_embeds(
                prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
            )
            # duplicate text embeddings for each generation per prompt, using mps friendly method
            _, seq_len, _ = negative_prompt_embeds.shape
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
        return prompt_embeds, negative_prompt_embeds
    # Modified from diffusers.pipelines.cosmos.pipeline_cosmos2_video2world.Cosmos2VideoToWorldPipeline.prepare_latents and
    # diffusers.pipelines.cosmos.pipeline_cosmos2_video2world.Cosmos2TextToImagePipeline.prepare_latents
    def prepare_latents(
        self,
        video: Optional[torch.Tensor],
        batch_size: int,
        num_channels_latents: int = 16,
        height: int = 704,
        width: int = 1280,
        num_frames_in: int = 93,
        num_frames_out: int = 93,
        do_classifier_free_guidance: bool = True,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Draw initial noise latents and build the conditioning tensors.

        Returns `(latents, cond_latents, cond_mask, cond_indicator)`:
        noise latents of shape (B, C, T, H, W) in latent space, the VAE-encoded
        (and mean/std-normalized) conditioning latents, a per-voxel mask marking
        conditioning positions, and a per-latent-frame indicator (B, 1, T, 1, 1).
        With `num_frames_in == 0` (pure text2world) all conditioning tensors are zeros.
        NOTE(review): `do_classifier_free_guidance` is accepted but unused here.
        """
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
        B = batch_size
        C = num_channels_latents
        # Latent-space temporal/spatial sizes derived from the VAE downsample factors.
        T = (num_frames_out - 1) // self.vae_scale_factor_temporal + 1
        H = height // self.vae_scale_factor_spatial
        W = width // self.vae_scale_factor_spatial
        shape = (B, C, T, H, W)
        if num_frames_in == 0:
            # Text2World: no conditioning frames — mask/indicator/latents are all zeros.
            if latents is None:
                latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            cond_mask = torch.zeros((B, 1, T, H, W), dtype=latents.dtype, device=latents.device)
            cond_indicator = torch.zeros((B, 1, T, 1, 1), dtype=latents.dtype, device=latents.device)
            cond_latents = torch.zeros_like(latents)
            return (
                latents,
                cond_latents,
                cond_mask,
                cond_indicator,
            )
        else:
            if video is None:
                raise ValueError("`video` must be provided when `num_frames_in` is greater than 0.")
            # Skip preprocessing when the caller already supplies a (B, 3, T, H, W) tensor.
            needs_preprocessing = not (isinstance(video, torch.Tensor) and video.ndim == 5 and video.shape[1] == 3)
            if needs_preprocessing:
                video = self.video_processor.preprocess_video(video, height, width)
            video = video.to(device=device, dtype=self.vae.dtype)
            if isinstance(generator, list):
                cond_latents = [
                    retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator=generator[i])
                    for i in range(batch_size)
                ]
            else:
                cond_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
            cond_latents = torch.cat(cond_latents, dim=0).to(dtype)
            # Normalize conditioning latents with the VAE's channel statistics.
            latents_mean = self.latents_mean.to(device=device, dtype=dtype)
            latents_std = self.latents_std.to(device=device, dtype=dtype)
            cond_latents = (cond_latents - latents_mean) / latents_std
            if latents is None:
                latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            else:
                latents = latents.to(device=device, dtype=dtype)
            padding_shape = (B, 1, T, H, W)
            ones_padding = latents.new_ones(padding_shape)
            zeros_padding = latents.new_zeros(padding_shape)
            # Mark the first `num_cond_latent_frames` latent frames as conditioning.
            num_cond_latent_frames = (num_frames_in - 1) // self.vae_scale_factor_temporal + 1
            cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
            cond_indicator[:, :, 0:num_cond_latent_frames] = 1.0
            cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
            return (
                latents,
                cond_latents,
                cond_mask,
                cond_indicator,
            )
    # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs
    def check_inputs(
        self,
        prompt,
        height,
        width,
        prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
    ):
        if height % 16 != 0 or width % 16 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    @property
    def guidance_scale(self):
        # Set once per `__call__`; read by `do_classifier_free_guidance`.
        return self._guidance_scale
    @property
    def do_classifier_free_guidance(self):
        # Guidance is enabled only for scales strictly greater than 1.
        return self._guidance_scale > 1.0
    @property
    def num_timesteps(self):
        return self._num_timesteps
    @property
    def current_timestep(self):
        return self._current_timestep
    @property
    def interrupt(self):
        return self._interrupt
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        image: PipelineImageInput | None = None,
        video: List[PipelineImageInput] | None = None,
        prompt: Union[str, List[str]] | None = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 704,
        width: int = 1280,
        num_frames: int = 93,
        num_inference_steps: int = 36,
        guidance_scale: float = 7.0,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        # NOTE(review): mutable default argument; never mutated here, but a None sentinel would be safer.
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
        conditional_frame_timestep: float = 0.1,
    ):
        r"""
        The call function to the pipeline for generation. Supports three modes:
        - **Text2World**: `image=None`, `video=None`, `prompt` provided. Generates a world clip.
        - **Image2World**: `image` provided, `video=None`, `prompt` provided. Conditions on a single frame.
        - **Video2World**: `video` provided, `image=None`, `prompt` provided. Conditions on an input clip.
        Set `num_frames=93` (default) to produce a world video, or `num_frames=1` to produce a single image frame (the
        above in "*2Image mode").
        Outputs follow `output_type` (e.g., `"pil"` returns a list of `num_frames` PIL images per prompt).
        Args:
            image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, *optional*):
                Optional single image for Image2World conditioning. Must be `None` when `video` is provided.
            video (`List[PIL.Image.Image]`, `np.ndarray`, `torch.Tensor`, *optional*):
                Optional input video for Video2World conditioning. Must be `None` when `image` is provided.
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide generation. Required unless `prompt_embeds` is supplied.
            height (`int`, defaults to `704`):
                The height in pixels of the generated image.
            width (`int`, defaults to `1280`):
                The width in pixels of the generated image.
            num_frames (`int`, defaults to `93`):
                Number of output frames. Use `93` for world (video) generation; set to `1` to return a single frame.
            num_inference_steps (`int`, defaults to `35`):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, defaults to `7.0`):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
                provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple.
            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
                each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int`, defaults to `512`):
                The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If
                the prompt is shorter than this length, it will be padded.
        Examples:
        Returns:
            [`~CosmosPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where
                the first element is a list with the generated images and the second element is a list of `bool`s
                indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
        """
        # Running without the guardrail violates the NVIDIA Open Model License — refuse outright.
        if self.safety_checker is None:
            raise ValueError(
                f"You have disabled the safety checker for {self.__class__}. This is in violation of the "
                "[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). "
                f"Please ensure that you are compliant with the license agreement."
            )
        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
        # Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)
        self._guidance_scale = guidance_scale
        self._current_timestep = None
        self._interrupt = False
        device = self._execution_device
        # Text guardrail: every prompt must pass the safety check before any compute is spent.
        if self.safety_checker is not None:
            self.safety_checker.to(device)
            if prompt is not None:
                prompt_list = [prompt] if isinstance(prompt, str) else prompt
                for p in prompt_list:
                    if not self.safety_checker.check_text_safety(p):
                        raise ValueError(
                            f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the "
                            f"prompt abides by the NVIDIA Open Model License Agreement."
                        )
        # Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]
        # Encode input prompt
        (
            prompt_embeds,
            negative_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            device=device,
            max_sequence_length=max_sequence_length,
        )
        vae_dtype = self.vae.dtype
        transformer_dtype = self.transformer.dtype
        # Resolve the conditioning mode from the inputs:
        #   image -> Image2World (1 conditioning frame), no video -> Text2World (0 frames),
        #   otherwise Video2World (len(video) conditioning frames).
        num_frames_in = None
        if image is not None:
            if batch_size != 1:
                raise ValueError(f"batch_size must be 1 for image input (given {batch_size})")
            # Build a frames-first clip: the image followed by (num_frames - 1) black frames.
            image = torchvision.transforms.functional.to_tensor(image).unsqueeze(0)
            video = torch.cat([image, torch.zeros_like(image).repeat(num_frames - 1, 1, 1, 1)], dim=0)
            video = video.unsqueeze(0)
            num_frames_in = 1
        elif video is None:
            video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=torch.uint8)
            num_frames_in = 0
        else:
            num_frames_in = len(video)
            if batch_size != 1:
                raise ValueError(f"batch_size must be 1 for video input (given {batch_size})")
        assert video is not None
        video = self.video_processor.preprocess_video(video, height, width)
        # pad with last frame (for video2world)
        num_frames_out = num_frames
        if video.shape[2] < num_frames_out:
            n_pad_frames = num_frames_out - num_frames_in
            last_frame = video[0, :, -1:, :, :]  # [C, T==1, H, W]
            pad_frames = last_frame.repeat(1, 1, n_pad_frames, 1, 1)  # [B, C, T, H, W]
            video = torch.cat((video, pad_frames), dim=2)
        assert num_frames_in <= num_frames_out, f"expected ({num_frames_in=}) <= ({num_frames_out=})"
        video = video.to(device=device, dtype=vae_dtype)
        # NOTE(review): presumably one transformer input channel is reserved for the condition mask — confirm.
        num_channels_latents = self.transformer.config.in_channels - 1
        latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents(
            video=video,
            batch_size=batch_size * num_videos_per_prompt,
            num_channels_latents=num_channels_latents,
            height=height,
            width=width,
            num_frames_in=num_frames_in,
            num_frames_out=num_frames,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            dtype=torch.float32,
            device=device,
            generator=generator,
            latents=latents,
        )
        # Conditioning frames are denoised at a fixed small timestep instead of the sampled one.
        cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep
        cond_mask = cond_mask.to(transformer_dtype)
        padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
        # Denoising loop
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps
        self._num_timesteps = len(timesteps)
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        # Velocity implied by (noise - conditioning latents) on the masked conditioning region;
        # overrides the model prediction there so conditioning frames stay pinned.
        gt_velocity = (latents - cond_latent) * cond_mask
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue
                self._current_timestep = t.cpu().item()
                # NOTE: assumes sigma(t) \in [0, 1]
                sigma_t = (
                    torch.tensor(self.scheduler.sigmas[i].item())
                    .unsqueeze(0)
                    .to(device=device, dtype=transformer_dtype)
                )
                # Splice clean conditioning latents into the noisy sample, and use the fixed
                # conditional timestep on conditioning frames / the current sigma elsewhere.
                in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents
                in_latents = in_latents.to(transformer_dtype)
                in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t
                noise_pred = self.transformer(
                    hidden_states=in_latents,
                    condition_mask=cond_mask,
                    timestep=in_timestep,
                    encoder_hidden_states=prompt_embeds,
                    padding_mask=padding_mask,
                    return_dict=False,
                )[0]
                # NOTE: replace velocity (noise_pred) with gt_velocity for conditioning inputs only
                noise_pred = gt_velocity + noise_pred * (1 - cond_mask)
                if self.do_classifier_free_guidance:
                    noise_pred_neg = self.transformer(
                        hidden_states=in_latents,
                        condition_mask=cond_mask,
                        timestep=in_timestep,
                        encoder_hidden_states=negative_prompt_embeds,
                        padding_mask=padding_mask,
                        return_dict=False,
                    )[0]
                    # NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only
                    noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask)
                    # NOTE(review): guidance uses the conditional branch as the base
                    # (cond + scale * (cond - uncond)) rather than conventional uncond-based CFG;
                    # presumably this mirrors the Cosmos reference implementation — confirm before changing.
                    noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg)
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    # NOTE(review): fragile — resolves callback tensors by name from locals().
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                if XLA_AVAILABLE:
                    xm.mark_step()
        self._current_timestep = None
        if not output_type == "latent":
            # Undo the latent normalization before decoding with the VAE.
            latents_mean = self.latents_mean.to(latents.device, latents.dtype)
            latents_std = self.latents_std.to(latents.device, latents.dtype)
            latents = latents * latents_std + latents_mean
            video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
            video = self._match_num_frames(video, num_frames)
            assert self.safety_checker is not None
            self.safety_checker.to(device)
            # Round-trip through uint8 numpy frames for the video guardrail, then back to [-1, 1].
            video = self.video_processor.postprocess_video(video, output_type="np")
            video = (video * 255).astype(np.uint8)
            video_batch = []
            for vid in video:
                vid = self.safety_checker.check_video_safety(vid)
                video_batch.append(vid)
            video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1
            video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
            video = self.video_processor.postprocess_video(video, output_type=output_type)
        else:
            video = latents
        # Offload all models
        self.maybe_free_model_hooks()
        if not return_dict:
            return (video,)
        return CosmosPipelineOutput(frames=video)
    def _match_num_frames(self, video: torch.Tensor, target_num_frames: int) -> torch.Tensor:
        """Adjust the decoded video's frame count (dim 2) to `target_num_frames`.

        Repeats each frame by the temporal VAE scale factor, then pads with the
        last frame or truncates as needed. No-op when the count already matches
        or `target_num_frames <= 0`.
        """
        if target_num_frames <= 0 or video.shape[2] == target_num_frames:
            return video
        frames_per_latent = max(self.vae_scale_factor_temporal, 1)
        video = torch.repeat_interleave(video, repeats=frames_per_latent, dim=2)
        current_frames = video.shape[2]
        if current_frames < target_num_frames:
            pad = video[:, :, -1:, :, :].repeat(1, 1, target_num_frames - current_frames, 1, 1)
            video = torch.cat([video, pad], dim=2)
        elif current_frames > target_num_frames:
            video = video[:, :, :target_num_frames]
        return video

View File

@@ -217,6 +217,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
rescale_betas_zero_snr: bool = False,
use_dynamic_shifting: bool = False,
time_shift_type: Literal["exponential"] = "exponential",
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
) -> None:
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -350,7 +352,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
log_sigmas = np.log(sigmas)
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
if self.config.use_flow_sigmas:
sigmas = sigmas / (sigmas + 1)
timesteps = (sigmas * self.config.num_train_timesteps).copy()
else:
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
if self.config.final_sigmas_type == "sigma_min":
sigma_last = sigmas[-1]
elif self.config.final_sigmas_type == "zero":

View File

@@ -767,6 +767,21 @@ class ConsisIDPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class Cosmos2_5_PredictBasePipeline(metaclass=DummyObject):
    """Import-time placeholder used when `torch`/`transformers` are unavailable.

    Every construction path funnels through `requires_backends`, which is
    expected to raise an informative error about the missing backends
    (see the `requires_backends` helper for the exact behavior).
    """

    # Backends that must be installed for the real pipeline to be importable.
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])
class Cosmos2TextToImagePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -11,8 +11,7 @@ export RUN_ATTENTION_BACKEND_TESTS=yes
pytest tests/others/test_attention_backends.py
```
Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in
"native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128).
Tests were conducted on an H100 with PyTorch 2.9.1 (CUDA 12.9).
Tests for aiter backend were conducted and slices for the aiter backend tests collected on a MI355X
with torch 2025-09-25 nightly version (ad2f7315ca66b42497047bb7951f696b50f1e81b) and
@@ -24,6 +23,8 @@ import os
import pytest
import torch
from ..testing_utils import numpy_cosine_similarity_distance
pytestmark = pytest.mark.skipif(
os.getenv("RUN_ATTENTION_BACKEND_TESTS", "false") == "false", reason="Feature not mature enough."
@@ -36,23 +37,28 @@ from diffusers.utils import is_torch_version # noqa: E402
FORWARD_CASES = [
(
"flash_hub",
torch.tensor([0.0820, 0.0859, 0.0918, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2461, 0.2422, 0.2695], dtype=torch.bfloat16)
torch.tensor([0.0820, 0.0859, 0.0918, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2461, 0.2422, 0.2695], dtype=torch.bfloat16),
1e-4
),
(
"_flash_3_hub",
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0977, 0.0996, 0.1016, 0.1016, 0.2188, 0.2246, 0.2344, 0.2480, 0.2539, 0.2480, 0.2441, 0.2715], dtype=torch.bfloat16),
1e-4
),
(
"native",
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734], dtype=torch.bfloat16)
),
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734], dtype=torch.bfloat16),
1e-4
),
(
"_native_cudnn",
torch.tensor([0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695], dtype=torch.bfloat16),
5e-4
),
(
"aiter",
torch.tensor([0.0781, 0.0820, 0.0879, 0.0957, 0.0898, 0.0938, 0.0957, 0.0957, 0.2285, 0.2363, 0.2461, 0.2637, 0.2695, 0.2617, 0.2617, 0.2891], dtype=torch.bfloat16),
1e-4
)
]
@@ -60,27 +66,32 @@ COMPILE_CASES = [
(
"flash_hub",
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2324, 0.2422, 0.2539, 0.2734, 0.2832, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
True
True,
1e-4
),
(
"_flash_3_hub",
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0625, 0.0605, 0.2344, 0.2461, 0.2578, 0.2734, 0.2852, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
True,
1e-4
),
(
"native",
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0605, 0.0605, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2773, 0.3066], dtype=torch.bfloat16),
True,
1e-4
),
(
"_native_cudnn",
torch.tensor([0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086], dtype=torch.bfloat16),
True,
5e-4,
),
(
"aiter",
torch.tensor([0.0391, 0.0391, 0.0430, 0.0488, 0.0469, 0.0566, 0.0586, 0.0566, 0.2402, 0.2539, 0.2637, 0.2812, 0.2930, 0.2910, 0.2891, 0.3164], dtype=torch.bfloat16),
True,
1e-4
)
]
# fmt: on
@@ -104,11 +115,11 @@ def _backend_is_probably_supported(pipe, name: str):
return False
def _check_if_slices_match(output, expected_slice):
def _check_if_slices_match(output, expected_slice, expected_diff=1e-4):
img = output.images.detach().cpu()
generated_slice = img.flatten()
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
assert torch.allclose(generated_slice, expected_slice, atol=1e-4)
assert numpy_cosine_similarity_distance(generated_slice, expected_slice) < expected_diff
@pytest.fixture(scope="session")
@@ -126,23 +137,23 @@ def pipe(device):
return pipe
@pytest.mark.parametrize("backend_name,expected_slice", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
def test_forward(pipe, backend_name, expected_slice):
@pytest.mark.parametrize("backend_name,expected_slice,expected_diff", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
def test_forward(pipe, backend_name, expected_slice, expected_diff):
out = _backend_is_probably_supported(pipe, backend_name)
if isinstance(out, bool):
pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")
modified_pipe = out[0]
out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
_check_if_slices_match(out, expected_slice)
_check_if_slices_match(out, expected_slice, expected_diff)
@pytest.mark.parametrize(
"backend_name,expected_slice,error_on_recompile",
"backend_name,expected_slice,error_on_recompile,expected_diff",
COMPILE_CASES,
ids=[c[0] for c in COMPILE_CASES],
)
def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile):
def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile, expected_diff):
if "native" in backend_name and error_on_recompile and not is_torch_version(">=", "2.9.0"):
pytest.xfail(f"Test with {backend_name=} is compatible with a higher version of torch.")
@@ -160,4 +171,4 @@ def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recom
):
out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
_check_if_slices_match(out, expected_slice)
_check_if_slices_match(out, expected_slice, expected_diff)

View File

@@ -27,7 +27,7 @@ class DummyCosmosSafetyChecker(ModelMixin, ConfigMixin):
def __init__(self) -> None:
super().__init__()
self._dtype = torch.float32
self.register_buffer("_device_tracker", torch.zeros(1, dtype=torch.float32), persistent=False)
def check_text_safety(self, prompt: str) -> bool:
return True
@@ -35,13 +35,14 @@ class DummyCosmosSafetyChecker(ModelMixin, ConfigMixin):
def check_video_safety(self, frames: np.ndarray) -> np.ndarray:
return frames
def to(self, device: Union[str, torch.device] = None, dtype: torch.dtype = None) -> None:
self._dtype = dtype
def to(self, device: Union[str, torch.device] = None, dtype: torch.dtype = None):
module = super().to(device=device, dtype=dtype)
return module
@property
def device(self) -> torch.device:
return None
return self._device_tracker.device
@property
def dtype(self) -> torch.dtype:
return self._dtype
return self._device_tracker.dtype

View File

@@ -0,0 +1,337 @@
# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import os
import tempfile
import unittest
import numpy as np
import torch
from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
from diffusers import (
AutoencoderKLWan,
Cosmos2_5_PredictBasePipeline,
CosmosTransformer3DModel,
UniPCMultistepScheduler,
)
from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
from .cosmos_guardrail import DummyCosmosSafetyChecker
enable_full_determinism()
class Cosmos2_5_PredictBaseWrapper(Cosmos2_5_PredictBasePipeline):
    """Test-only wrapper that injects a lightweight dummy safety checker.

    The real Cosmos guardrail is too heavy for CI, so when callers do not
    supply a `safety_checker` (or pass `None`), `from_pretrained` substitutes
    a `DummyCosmosSafetyChecker`, moved to the requested device/dtype.
    """

    @staticmethod
    def from_pretrained(*args, **kwargs):
        # `kwargs.get(...) is None` covers both "key absent" and "explicit None".
        if kwargs.get("safety_checker") is None:
            checker = DummyCosmosSafetyChecker()
            target_device = kwargs.get("device_map", "cpu")
            target_dtype = kwargs.get("torch_dtype")
            if target_device is not None or target_dtype is not None:
                checker = checker.to(target_device, dtype=target_dtype)
            kwargs["safety_checker"] = checker
        return Cosmos2_5_PredictBasePipeline.from_pretrained(*args, **kwargs)
class Cosmos2_5_PredictPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    """Fast CPU tests for the Cosmos Predict 2.5 base pipeline.

    Uses tiny randomly initialized components (transformer, VAE, Qwen2.5-VL
    text encoder) and a `DummyCosmosSafetyChecker` so the common
    `PipelineTesterMixin` checks can run quickly on CI.
    """

    pipeline_class = Cosmos2_5_PredictBaseWrapper
    # The pipeline takes no cross_attention_kwargs, so drop it from the
    # standard text-to-image parameter set.
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )
    supports_dduf = False
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self):
        """Build a minimal set of pipeline components for fast testing."""
        torch.manual_seed(0)
        transformer = CosmosTransformer3DModel(
            in_channels=16 + 1,  # +1 presumably for the concatenated padding mask — TODO confirm
            out_channels=16,
            num_attention_heads=2,
            attention_head_dim=16,
            num_layers=2,
            mlp_ratio=2,
            text_embed_dim=32,
            adaln_lora_dim=4,
            max_size=(4, 32, 32),
            patch_size=(1, 2, 2),
            rope_scale=(2.0, 1.0, 1.0),
            concat_padding_mask=True,
            extra_pos_embed_type="learnable",
        )

        torch.manual_seed(0)
        vae = AutoencoderKLWan(
            base_dim=3,
            z_dim=16,
            dim_mult=[1, 1, 1, 1],
            num_res_blocks=1,
            temperal_downsample=[False, True, True],
        )

        torch.manual_seed(0)
        scheduler = UniPCMultistepScheduler()

        torch.manual_seed(0)
        # Tiny Qwen2.5-VL config; vision token ids must match the tokenizer's
        # special tokens.
        config = Qwen2_5_VLConfig(
            text_config={
                "hidden_size": 16,
                "intermediate_size": 16,
                "num_hidden_layers": 2,
                "num_attention_heads": 2,
                "num_key_value_heads": 2,
                "rope_scaling": {
                    "mrope_section": [1, 1, 2],
                    "rope_type": "default",
                    "type": "default",
                },
                "rope_theta": 1000000.0,
            },
            vision_config={
                "depth": 2,
                "hidden_size": 16,
                "intermediate_size": 16,
                "num_heads": 2,
                "out_hidden_size": 16,
            },
            hidden_size=16,
            vocab_size=152064,
            vision_end_token_id=151653,
            vision_start_token_id=151652,
            vision_token_id=151654,
        )
        text_encoder = Qwen2_5_VLForConditionalGeneration(config)
        tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")

        components = {
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "safety_checker": DummyCosmosSafetyChecker(),
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        """Return small, deterministic call kwargs for the pipeline."""
        if str(device).startswith("mps"):
            # MPS does not support device-local generators.
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        inputs = {
            "prompt": "dance monkey",
            "negative_prompt": "bad quality",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 3.0,
            "height": 32,
            "width": 32,
            "num_frames": 3,
            "max_sequence_length": 16,
            "output_type": "pt",
        }

        return inputs

    def test_components_function(self):
        """`pipe.components` must expose exactly the components passed to init."""
        init_components = self.get_dummy_components()
        init_components = {k: v for k, v in init_components.items() if not isinstance(v, (str, int, float))}

        pipe = self.pipeline_class(**init_components)

        self.assertTrue(hasattr(pipe, "components"))
        self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))

    def test_inference(self):
        """Smoke test: the pipeline produces a finite video of the expected shape."""
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        video = pipe(**inputs).frames
        generated_video = video[0]
        # (frames, channels, height, width) — TODO confirm axis order against pipeline output docs
        self.assertEqual(generated_video.shape, (3, 3, 32, 32))
        self.assertTrue(torch.isfinite(generated_video).all())

    def test_callback_inputs(self):
        """Verify the step-end callback plumbing and tensor-input contract."""
        sig = inspect.signature(self.pipeline_class.__call__)
        has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
        has_callback_step_end = "callback_on_step_end" in sig.parameters

        if not (has_callback_tensor_inputs and has_callback_step_end):
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        self.assertTrue(
            hasattr(pipe, "_callback_tensor_inputs"),
            f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
        )

        def callback_inputs_subset(pipe, i, t, callback_kwargs):
            # Iterate over callback args: every provided tensor must be declared.
            for tensor_name in callback_kwargs.keys():
                assert tensor_name in pipe._callback_tensor_inputs

            return callback_kwargs

        def callback_inputs_all(pipe, i, t, callback_kwargs):
            # When all declared inputs are requested, both sets must match exactly.
            for tensor_name in pipe._callback_tensor_inputs:
                assert tensor_name in callback_kwargs

            for tensor_name in callback_kwargs.keys():
                assert tensor_name in pipe._callback_tensor_inputs

            return callback_kwargs

        inputs = self.get_dummy_inputs(torch_device)

        # Test passing in a subset of the declared tensor inputs.
        inputs["callback_on_step_end"] = callback_inputs_subset
        inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
        _ = pipe(**inputs)[0]

        # Test passing in the full set of declared tensor inputs.
        inputs["callback_on_step_end"] = callback_inputs_all
        inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
        _ = pipe(**inputs)[0]

        def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
            # Zero the latents on the last step; the callback's mutation must
            # propagate into the pipeline output.
            is_last = i == (pipe.num_timesteps - 1)
            if is_last:
                callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
            return callback_kwargs

        inputs["callback_on_step_end"] = callback_inputs_change_tensor
        inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
        output = pipe(**inputs)[0]
        assert output.abs().sum() < 1e10

    def test_inference_batch_single_identical(self):
        """Batched and single inference should match within a loose tolerance."""
        self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=1e-2)

    def test_attention_slicing_forward_pass(
        self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
    ):
        """Attention slicing must not change the pipeline's numerical output."""
        if not getattr(self, "test_attention_slicing", True):
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        # Keep the generator on CPU so the reference and sliced runs see
        # identical noise.
        generator_device = "cpu"
        inputs = self.get_dummy_inputs(generator_device)
        output_without_slicing = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=1)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing1 = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=2)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing2 = pipe(**inputs)[0]

        if test_max_difference:
            max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
            max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
            self.assertLess(
                max(max_diff1, max_diff2),
                expected_max_diff,
                "Attention slicing should not affect the inference results",
            )

    def test_save_load_optional_components(self, expected_max_difference=1e-4):
        """Run the base check with `safety_checker` temporarily made required.

        The safety checker must not be dropped during the optional-component
        round-trip, so it is removed from `_optional_components` for the
        duration of the test and restored afterwards.
        """
        self.pipeline_class._optional_components.remove("safety_checker")
        super().test_save_load_optional_components(expected_max_difference=expected_max_difference)
        self.pipeline_class._optional_components.append("safety_checker")

    def test_serialization_with_variants(self):
        """Saving with a variant must write variant-suffixed weight files."""
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        model_components = [
            component_name
            for component_name, component in pipe.components.items()
            if isinstance(component, torch.nn.Module)
        ]
        # The dummy safety checker is not serialized with the pipeline.
        model_components.remove("safety_checker")
        variant = "fp16"

        with tempfile.TemporaryDirectory() as tmpdir:
            pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False)

            with open(f"{tmpdir}/model_index.json", "r") as f:
                config = json.load(f)

            for subfolder in os.listdir(tmpdir):
                if not os.path.isfile(subfolder) and subfolder in model_components:
                    folder_path = os.path.join(tmpdir, subfolder)
                    is_folder = os.path.isdir(folder_path) and subfolder in config
                    assert is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path))

    def test_torch_dtype_dict(self):
        """A per-component `torch_dtype` dict must be honored on reload."""
        components = self.get_dummy_components()
        if not components:
            self.skipTest("No dummy components defined.")

        pipe = self.pipeline_class(**components)
        specified_key = next(iter(components.keys()))

        with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
            pipe.save_pretrained(tmpdirname, safe_serialization=False)
            torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16}
            loaded_pipe = self.pipeline_class.from_pretrained(
                tmpdirname, safety_checker=DummyCosmosSafetyChecker(), torch_dtype=torch_dtype_dict
            )

        for name, component in loaded_pipe.components.items():
            # The injected dummy safety checker is exempt from the dtype map.
            if name == "safety_checker":
                continue
            if isinstance(component, torch.nn.Module) and hasattr(component, "dtype"):
                expected_dtype = torch_dtype_dict.get(name, torch_dtype_dict.get("default", torch.float32))
                self.assertEqual(
                    component.dtype,
                    expected_dtype,
                    f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
                )

    @unittest.skip(
        "The pipeline should not be runnable without a safety checker. The test creates a pipeline without passing in "
        "a safety checker, which makes the pipeline default to the actual Cosmos Guardrail. The Cosmos Guardrail is "
        "too large and slow to run on CI."
    )
    def test_encode_prompt_works_in_isolation(self):
        pass

View File

@@ -399,3 +399,32 @@ class UniPCMultistepScheduler1DTest(UniPCMultistepSchedulerTest):
def test_exponential_sigmas(self):
    """Exercise the config round-trip check with exponential sigma spacing enabled."""
    self.check_over_configs(use_exponential_sigmas=True)
def test_flow_and_karras_sigmas(self):
    """Exercise the config round-trip check with flow-matching + Karras sigmas combined."""
    self.check_over_configs(use_flow_sigmas=True, use_karras_sigmas=True)
def test_flow_and_karras_sigmas_values(self):
    """Pin the exact sigma/timestep schedule produced by flow + Karras sigmas.

    Regression values were captured for sigma_min=0.01, sigma_max=200.0 with
    5 inference steps over 1000 training timesteps.
    """
    train_steps = 1000
    scheduler = UniPCMultistepScheduler(
        sigma_min=0.01,
        sigma_max=200.0,
        use_flow_sigmas=True,
        use_karras_sigmas=True,
        num_train_timesteps=train_steps,
    )
    scheduler.set_timesteps(num_inference_steps=5)

    # Reference schedule; the trailing 0.0 is the sigma appended by default.
    reference_sigmas = torch.tensor(
        [
            0.9950248599052429,
            0.9787454605102539,
            0.8774884343147278,
            0.3604971766471863,
            0.009900986216962337,
            0.0,
        ]
    )
    # Timesteps are the sigmas scaled to the training range, without the
    # final appended sigma.
    reference_timesteps = (reference_sigmas * train_steps).to(torch.int64)[0:-1]

    self.assertTrue(torch.allclose(scheduler.sigmas, reference_sigmas))
    self.assertTrue(torch.all(reference_timesteps == scheduler.timesteps))

View File

@@ -131,6 +131,10 @@ def torch_all_close(a, b, *args, **kwargs):
def numpy_cosine_similarity_distance(a, b):
if isinstance(a, torch.Tensor):
a = a.detach().cpu().float().numpy()
if isinstance(b, torch.Tensor):
b = b.detach().cpu().float().numpy()
similarity = np.dot(a, b) / (norm(a) * norm(b))
distance = 1.0 - similarity.mean()