Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-19 10:54:34 +08:00)

Compare commits (2 commits): modular-10 ... qwenimage-

Commits: f70010ca5d, 2f0b35fd84
@@ -70,12 +70,6 @@ output.save("output.png")
- all
- __call__

## Cosmos2_5_PredictBasePipeline

[[autodoc]] Cosmos2_5_PredictBasePipeline
- all
- __call__

## CosmosPipelineOutput

[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput
@@ -1,55 +1,11 @@
"""
# Cosmos 2 Predict

Download checkpoint
```bash
hf download nvidia/Cosmos-Predict2-2B-Text2Image
```

convert checkpoint
```bash
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2-2B-Text2Image/snapshots/acdb5fde992a73ef0355f287977d002cbfd127e0/model.pt

python scripts/convert_cosmos_to_diffusers.py \
--transformer_ckpt_path $transformer_ckpt_path \
--transformer_type Cosmos-2.0-Diffusion-2B-Text2Image \
--text_encoder_path google-t5/t5-11b \
--tokenizer_path google-t5/t5-11b \
--vae_type wan2.1 \
--output_path converted/cosmos-p2-t2i-2b \
--save_pipeline
```

# Cosmos 2.5 Predict

Download checkpoint
```bash
hf download nvidia/Cosmos-Predict2.5-2B
```

Convert checkpoint
```bash
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/pre-trained/d20b7120-df3e-4911-919d-db6e08bad31c_ema_bf16.pt

python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Predict-Base-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/cosmos-p2.5-base-2b \
--save_pipeline
```

"""

import argparse
import pathlib
import sys
from typing import Any, Dict

import torch
from accelerate import init_empty_weights
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration, T5EncoderModel, T5TokenizerFast
from transformers import T5EncoderModel, T5TokenizerFast

from diffusers import (
AutoencoderKLCosmos,
@@ -61,9 +17,7 @@ from diffusers import (
CosmosVideoToWorldPipeline,
EDMEulerScheduler,
FlowMatchEulerDiscreteScheduler,
UniPCMultistepScheduler,
)
from diffusers.pipelines.cosmos.pipeline_cosmos2_5_predict import Cosmos2_5_PredictBasePipeline

def remove_keys_(key: str, state_dict: Dict[str, Any]):
@@ -279,25 +233,6 @@ TRANSFORMER_CONFIGS = {
"concat_padding_mask": True,
"extra_pos_embed_type": None,
},
"Cosmos-2.5-Predict-Base-2B": {
"in_channels": 16 + 1,
"out_channels": 16,
"num_attention_heads": 16,
"attention_head_dim": 128,
"num_layers": 28,
"mlp_ratio": 4.0,
"text_embed_dim": 1024,
"adaln_lora_dim": 256,
"max_size": (128, 240, 240),
"patch_size": (1, 2, 2),
"rope_scale": (1.0, 3.0, 3.0),
"concat_padding_mask": True,
# NOTE: source config has pos_emb_learnable: 'True' - but params are missing
"extra_pos_embed_type": None,
"use_crossattn_projection": True,
"crossattn_proj_in_channels": 100352,
"encoder_hidden_states_channels": 1024,
},
}

VAE_KEYS_RENAME_DICT = {
@@ -399,9 +334,6 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
elif "Cosmos-2.0" in transformer_type:
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
elif "Cosmos-2.5" in transformer_type:
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
else:
assert False

@@ -415,7 +347,6 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
new_key = new_key.removeprefix(PREFIX_KEY)
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
print(key, "->", new_key, flush=True)
update_state_dict_(original_state_dict, key, new_key)

for key in list(original_state_dict.keys()):
@@ -424,21 +355,6 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
continue
handler_fn_inplace(key, original_state_dict)

expected_keys = set(transformer.state_dict().keys())
mapped_keys = set(original_state_dict.keys())
missing_keys = expected_keys - mapped_keys
unexpected_keys = mapped_keys - expected_keys
if missing_keys:
print(f"ERROR: missing keys ({len(missing_keys)}) from state_dict:", flush=True, file=sys.stderr)
for k in missing_keys:
print(k)
sys.exit(1)
if unexpected_keys:
print(f"ERROR: unexpected keys ({len(unexpected_keys)}) from state_dict:", flush=True, file=sys.stderr)
for k in unexpected_keys:
print(k)
sys.exit(2)

transformer.load_state_dict(original_state_dict, strict=True, assign=True)
return transformer

@@ -528,34 +444,6 @@ def save_pipeline_cosmos_2_0(args, transformer, vae):
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")


def save_pipeline_cosmos2_5(args, transformer, vae):
text_encoder_path = args.text_encoder_path or "nvidia/Cosmos-Reason1-7B"
tokenizer_path = args.tokenizer_path or "Qwen/Qwen2.5-VL-7B-Instruct"

text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
text_encoder_path, torch_dtype="auto", device_map="cpu"
)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

scheduler = UniPCMultistepScheduler(
use_karras_sigmas=True,
use_flow_sigmas=True,
prediction_type="flow_prediction",
sigma_max=200.0,
sigma_min=0.01,
)

pipe = Cosmos2_5_PredictBasePipeline(
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
vae=vae,
scheduler=scheduler,
safety_checker=lambda *args, **kwargs: None,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
@@ -563,10 +451,10 @@ def get_args():
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
)
parser.add_argument(
"--vae_type", type=str, default="wan2.1", choices=["wan2.1", *list(VAE_CONFIGS.keys())], help="Type of VAE"
"--vae_type", type=str, default=None, choices=["none", *list(VAE_CONFIGS.keys())], help="Type of VAE"
)
parser.add_argument("--text_encoder_path", type=str, default=None)
parser.add_argument("--tokenizer_path", type=str, default=None)
parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
parser.add_argument("--save_pipeline", action="store_true")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
@@ -589,6 +477,8 @@ if __name__ == "__main__":
if args.save_pipeline:
assert args.transformer_ckpt_path is not None
assert args.vae_type is not None
assert args.text_encoder_path is not None
assert args.tokenizer_path is not None

if args.transformer_ckpt_path is not None:
weights_only = "Cosmos-1.0" in args.transformer_type
@@ -600,26 +490,17 @@ if __name__ == "__main__":
if args.vae_type is not None:
if "Cosmos-1.0" in args.transformer_type:
vae = convert_vae(args.vae_type)
elif "Cosmos-2.0" in args.transformer_type or "Cosmos-2.5" in args.transformer_type:
else:
vae = AutoencoderKLWan.from_pretrained(
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)
else:
raise AssertionError(f"{args.transformer_type} not supported")

if not args.save_pipeline:
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

if args.save_pipeline:
if "Cosmos-1.0" in args.transformer_type:
assert args.text_encoder_path is not None
assert args.tokenizer_path is not None
save_pipeline_cosmos_1_0(args, transformer, vae)
elif "Cosmos-2.0" in args.transformer_type:
assert args.text_encoder_path is not None
assert args.tokenizer_path is not None
save_pipeline_cosmos_2_0(args, transformer, vae)
elif "Cosmos-2.5" in args.transformer_type:
save_pipeline_cosmos2_5(args, transformer, vae)
else:
raise AssertionError(f"{args.transformer_type} not supported")
assert False
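For a quick sanity check of a converted checkpoint, the pipeline written by `--save_pipeline` can be loaded back from the local output directory. A minimal sketch, mirroring the usage shown in the pipeline's example docstring further down this diff (the prompt is a placeholder and the output directory is the one from the conversion example above):

```python
import torch

from diffusers import Cosmos2_5_PredictBasePipeline
from diffusers.utils import export_to_video

# Load the pipeline saved by convert_cosmos_to_diffusers.py --save_pipeline
pipe = Cosmos2_5_PredictBasePipeline.from_pretrained(
    "converted/cosmos-p2.5-base-2b", torch_dtype=torch.bfloat16
).to("cuda")

# Text2World: no image/video conditioning, 93 output frames
video = pipe(
    image=None,
    video=None,
    prompt="A red bus pulls away from a snowy intersection at night.",  # placeholder prompt
    num_frames=93,
    generator=torch.Generator().manual_seed(1),
).frames[0]
export_to_video(video, "text2world.mp4", fps=16)
```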
@@ -463,7 +463,6 @@ else:
"CogView4ControlPipeline",
"CogView4Pipeline",
"ConsisIDPipeline",
"Cosmos2_5_PredictBasePipeline",
"Cosmos2TextToImagePipeline",
"Cosmos2VideoToWorldPipeline",
"CosmosTextToWorldPipeline",
@@ -1176,7 +1175,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogView4ControlPipeline,
CogView4Pipeline,
ConsisIDPipeline,
Cosmos2_5_PredictBasePipeline,
Cosmos2TextToImagePipeline,
Cosmos2VideoToWorldPipeline,
CosmosTextToWorldPipeline,
@@ -439,9 +439,6 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
concat_padding_mask: bool = True,
extra_pos_embed_type: Optional[str] = "learnable",
use_crossattn_projection: bool = False,
crossattn_proj_in_channels: int = 1024,
encoder_hidden_states_channels: int = 1024,
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
@@ -488,12 +485,6 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, bias=False
)

if self.config.use_crossattn_projection:
self.crossattn_proj = nn.Sequential(
nn.Linear(crossattn_proj_in_channels, encoder_hidden_states_channels, bias=True),
nn.GELU(),
)

self.gradient_checkpointing = False

def forward(
@@ -533,7 +524,6 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p_h
post_patch_width = width // p_w

hidden_states = self.patch_embed(hidden_states)
hidden_states = hidden_states.flatten(1, 3)  # [B, T, H, W, C] -> [B, THW, C]

@@ -556,9 +546,6 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
else:
assert False

if self.config.use_crossattn_projection:
encoder_hidden_states = self.crossattn_proj(encoder_hidden_states)

# 5. Transformer blocks
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
@@ -360,7 +360,7 @@ class FluxKontextCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", FluxTextEncoderStep()),
|
||||
("vae_encoder", FluxAutoVaeEncoderStep()),
|
||||
("image_encoder", FluxAutoVaeEncoderStep()),
|
||||
("denoise", FluxCoreDenoiseStep()),
|
||||
("decode", FluxDecodeStep()),
|
||||
]
|
||||
@@ -369,7 +369,7 @@ AUTO_BLOCKS = InsertableDict(
|
||||
AUTO_BLOCKS_KONTEXT = InsertableDict(
|
||||
[
|
||||
("text_encoder", FluxTextEncoderStep()),
|
||||
("vae_encoder", FluxKontextAutoVaeEncoderStep()),
|
||||
("image_encoder", FluxKontextAutoVaeEncoderStep()),
|
||||
("denoise", FluxKontextCoreDenoiseStep()),
|
||||
("decode", FluxDecodeStep()),
|
||||
]
|
||||
|
||||
File diff suppressed because it is too large
@@ -501,19 +501,15 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
|
||||
@property
|
||||
def input_names(self) -> List[str]:
|
||||
return [input_param.name for input_param in self.inputs if input_param.name is not None]
|
||||
return [input_param.name for input_param in self.inputs]
|
||||
|
||||
@property
|
||||
def intermediate_output_names(self) -> List[str]:
|
||||
return [output_param.name for output_param in self.intermediate_outputs if output_param.name is not None]
|
||||
return [output_param.name for output_param in self.intermediate_outputs]
|
||||
|
||||
@property
|
||||
def output_names(self) -> List[str]:
|
||||
return [output_param.name for output_param in self.outputs if output_param.name is not None]
|
||||
|
||||
@property
|
||||
def component_names(self) -> List[str]:
|
||||
return [component.name for component in self.expected_components]
|
||||
return [output_param.name for output_param in self.outputs]
|
||||
|
||||
@property
|
||||
def doc(self):
|
||||
@@ -1529,8 +1525,10 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
if blocks is None:
|
||||
if modular_config_dict is not None:
|
||||
blocks_class_name = modular_config_dict.get("_blocks_class_name")
|
||||
else:
|
||||
elif config_dict is not None:
|
||||
blocks_class_name = self.get_default_blocks_name(config_dict)
|
||||
else:
|
||||
blocks_class_name = None
|
||||
if blocks_class_name is not None:
|
||||
diffusers_module = importlib.import_module("diffusers")
|
||||
blocks_class = getattr(diffusers_module, blocks_class_name)
|
||||
@@ -1627,10 +1625,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
return None, config_dict
|
||||
|
||||
except EnvironmentError as e:
|
||||
raise EnvironmentError(
|
||||
f"Failed to load config from '{pretrained_model_name_or_path}'. "
|
||||
f"Could not find or load 'modular_model_index.json' or 'model_index.json'."
|
||||
) from e
|
||||
logger.debug(f" model_index.json not found in the repo: {e}")
|
||||
|
||||
return None, None
|
||||
|
||||
@@ -2555,11 +2550,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
kwargs_type = expected_input_param.kwargs_type
|
||||
if name in passed_kwargs:
|
||||
state.set(name, passed_kwargs.pop(name), kwargs_type)
|
||||
elif kwargs_type is not None and kwargs_type in passed_kwargs:
|
||||
kwargs_dict = passed_kwargs.pop(kwargs_type)
|
||||
for k, v in kwargs_dict.items():
|
||||
state.set(k, v, kwargs_type)
|
||||
elif name is not None and name not in state.values:
|
||||
elif name not in state.values:
|
||||
state.set(name, default, kwargs_type)
|
||||
|
||||
# Warn about unexpected inputs
|
||||
|
||||
@@ -30,47 +30,6 @@ from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class QwenImageAfterDenoiseStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that unpack the latents from 3D tensor (batch_size, sequence_length, channels) into 5D tensor (batch_size, channels, 1, height, width)"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
components = [
|
||||
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
|
||||
]
|
||||
|
||||
return components
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="height", required=True),
|
||||
InputParam(name="width", required=True),
|
||||
InputParam(
|
||||
name="latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents to decode, can be generated in the denoise step",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
vae_scale_factor = components.vae_scale_factor
|
||||
block_state.latents = components.pachifier.unpack_latents(
|
||||
block_state.latents, block_state.height, block_state.width, vae_scale_factor=vae_scale_factor
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageDecoderStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@@ -82,6 +41,7 @@ class QwenImageDecoderStep(ModularPipelineBlocks):
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
components = [
|
||||
ComponentSpec("vae", AutoencoderKLQwenImage),
|
||||
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
|
||||
]
|
||||
|
||||
return components
|
||||
@@ -89,6 +49,8 @@ class QwenImageDecoderStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="height", required=True),
|
||||
InputParam(name="width", required=True),
|
||||
InputParam(
|
||||
name="latents",
|
||||
required=True,
|
||||
@@ -112,12 +74,10 @@ class QwenImageDecoderStep(ModularPipelineBlocks):
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# YiYi Notes: remove support for output_type = "latents", we can just skip decode/encode step in modular
|
||||
if block_state.latents.ndim == 4:
|
||||
block_state.latents = block_state.latents.unsqueeze(dim=1)
|
||||
elif block_state.latents.ndim != 5:
|
||||
raise ValueError(
|
||||
f"expect latents to be a 4D or 5D tensor but got: {block_state.latents.shape}. Please make sure the latents are unpacked before decode step."
|
||||
)
|
||||
vae_scale_factor = components.vae_scale_factor
|
||||
block_state.latents = components.pachifier.unpack_latents(
|
||||
block_state.latents, block_state.height, block_state.width, vae_scale_factor=vae_scale_factor
|
||||
)
|
||||
block_state.latents = block_state.latents.to(components.vae.dtype)
|
||||
|
||||
latents_mean = (
|
||||
|
||||
@@ -26,12 +26,7 @@ from .before_denoise import (
|
||||
QwenImageSetTimestepsStep,
|
||||
QwenImageSetTimestepsWithStrengthStep,
|
||||
)
|
||||
from .decoders import (
|
||||
QwenImageAfterDenoiseStep,
|
||||
QwenImageDecoderStep,
|
||||
QwenImageInpaintProcessImagesOutputStep,
|
||||
QwenImageProcessImagesOutputStep,
|
||||
)
|
||||
from .decoders import QwenImageDecoderStep, QwenImageInpaintProcessImagesOutputStep, QwenImageProcessImagesOutputStep
|
||||
from .denoise import (
|
||||
QwenImageControlNetDenoiseStep,
|
||||
QwenImageDenoiseStep,
|
||||
@@ -97,7 +92,6 @@ TEXT2IMAGE_BLOCKS = InsertableDict(
|
||||
("set_timesteps", QwenImageSetTimestepsStep()),
|
||||
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
|
||||
("denoise", QwenImageDenoiseStep()),
|
||||
("after_denoise", QwenImageAfterDenoiseStep()),
|
||||
("decode", QwenImageDecodeStep()),
|
||||
]
|
||||
)
|
||||
@@ -211,7 +205,6 @@ INPAINT_BLOCKS = InsertableDict(
|
||||
("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
|
||||
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
|
||||
("denoise", QwenImageInpaintDenoiseStep()),
|
||||
("after_denoise", QwenImageAfterDenoiseStep()),
|
||||
("decode", QwenImageInpaintDecodeStep()),
|
||||
]
|
||||
)
|
||||
@@ -271,7 +264,6 @@ IMAGE2IMAGE_BLOCKS = InsertableDict(
|
||||
("prepare_img2img_latents", QwenImagePrepareLatentsWithStrengthStep()),
|
||||
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
|
||||
("denoise", QwenImageDenoiseStep()),
|
||||
("after_denoise", QwenImageAfterDenoiseStep()),
|
||||
("decode", QwenImageDecodeStep()),
|
||||
]
|
||||
)
|
||||
@@ -537,16 +529,8 @@ class QwenImageCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
QwenImageAutoBeforeDenoiseStep,
|
||||
QwenImageOptionalControlNetBeforeDenoiseStep,
|
||||
QwenImageAutoDenoiseStep,
|
||||
QwenImageAfterDenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"controlnet_input",
|
||||
"before_denoise",
|
||||
"controlnet_before_denoise",
|
||||
"denoise",
|
||||
"after_denoise",
|
||||
]
|
||||
block_names = ["input", "controlnet_input", "before_denoise", "controlnet_before_denoise", "denoise"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
@@ -669,7 +653,6 @@ EDIT_BLOCKS = InsertableDict(
|
||||
("set_timesteps", QwenImageSetTimestepsStep()),
|
||||
("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
|
||||
("denoise", QwenImageEditDenoiseStep()),
|
||||
("after_denoise", QwenImageAfterDenoiseStep()),
|
||||
("decode", QwenImageDecodeStep()),
|
||||
]
|
||||
)
|
||||
@@ -719,7 +702,6 @@ EDIT_INPAINT_BLOCKS = InsertableDict(
|
||||
("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
|
||||
("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
|
||||
("denoise", QwenImageEditInpaintDenoiseStep()),
|
||||
("after_denoise", QwenImageAfterDenoiseStep()),
|
||||
("decode", QwenImageInpaintDecodeStep()),
|
||||
]
|
||||
)
|
||||
@@ -859,9 +841,8 @@ class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
QwenImageEditAutoInputStep,
|
||||
QwenImageEditAutoBeforeDenoiseStep,
|
||||
QwenImageEditAutoDenoiseStep,
|
||||
QwenImageAfterDenoiseStep,
|
||||
]
|
||||
block_names = ["input", "before_denoise", "denoise", "after_denoise"]
|
||||
block_names = ["input", "before_denoise", "denoise"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
@@ -973,7 +954,6 @@ EDIT_PLUS_BLOCKS = InsertableDict(
|
||||
("set_timesteps", QwenImageSetTimestepsStep()),
|
||||
("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()),
|
||||
("denoise", QwenImageEditDenoiseStep()),
|
||||
("after_denoise", QwenImageAfterDenoiseStep()),
|
||||
("decode", QwenImageDecodeStep()),
|
||||
]
|
||||
)
|
||||
@@ -1057,9 +1037,8 @@ class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
QwenImageEditPlusAutoInputStep,
|
||||
QwenImageEditPlusAutoBeforeDenoiseStep,
|
||||
QwenImageEditAutoDenoiseStep,
|
||||
QwenImageAfterDenoiseStep,
|
||||
]
|
||||
block_names = ["input", "before_denoise", "denoise", "after_denoise"]
|
||||
block_names = ["input", "before_denoise", "denoise"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
|
||||
src/diffusers/modular_pipelines/qwenimage/node_utils.py (new file, 95 lines)
@@ -0,0 +1,95 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# mellon nodes
|
||||
QwenImage_NODE_TYPES_PARAMS_MAP = {
|
||||
"controlnet": {
|
||||
"inputs": [
|
||||
"control_image",
|
||||
"controlnet_conditioning_scale",
|
||||
"control_guidance_start",
|
||||
"control_guidance_end",
|
||||
"height",
|
||||
"width",
|
||||
],
|
||||
"model_inputs": [
|
||||
"controlnet",
|
||||
"vae",
|
||||
],
|
||||
"outputs": [
|
||||
"controlnet_out",
|
||||
],
|
||||
"block_names": ["controlnet_vae_encoder"],
|
||||
},
|
||||
"denoise": {
|
||||
"inputs": [
|
||||
"embeddings",
|
||||
"width",
|
||||
"height",
|
||||
"seed",
|
||||
"num_inference_steps",
|
||||
"guidance_scale",
|
||||
"image_latents",
|
||||
"strength",
|
||||
"controlnet",
|
||||
],
|
||||
"model_inputs": [
|
||||
"unet",
|
||||
"guider",
|
||||
"scheduler",
|
||||
],
|
||||
"outputs": [
|
||||
"latents",
|
||||
"latents_preview",
|
||||
],
|
||||
"block_names": ["denoise"],
|
||||
},
|
||||
"vae_encoder": {
|
||||
"inputs": [
|
||||
"image",
|
||||
"width",
|
||||
"height",
|
||||
],
|
||||
"model_inputs": [
|
||||
"vae",
|
||||
],
|
||||
"outputs": [
|
||||
"image_latents",
|
||||
],
|
||||
},
|
||||
"text_encoder": {
|
||||
"inputs": [
|
||||
"prompt",
|
||||
"negative_prompt",
|
||||
],
|
||||
"model_inputs": [
|
||||
"text_encoders",
|
||||
],
|
||||
"outputs": [
|
||||
"embeddings",
|
||||
],
|
||||
},
|
||||
"decoder": {
|
||||
"inputs": [
|
||||
"latents",
|
||||
],
|
||||
"model_inputs": [
|
||||
"vae",
|
||||
],
|
||||
"outputs": [
|
||||
"images",
|
||||
],
|
||||
},
|
||||
}
|
||||
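The map above only declares per-node parameter names; the code that consumes it is not part of this diff. As a rough, hypothetical illustration of how a UI integration might use such a map (the `collect_node_inputs` helper below is not in the PR):

```python
from diffusers.modular_pipelines.qwenimage.node_utils import QwenImage_NODE_TYPES_PARAMS_MAP


def collect_node_inputs(node_type: str, available: dict) -> dict:
    # Gather the values a node declares as inputs/model_inputs, if they are available.
    spec = QwenImage_NODE_TYPES_PARAMS_MAP[node_type]
    wanted = spec["inputs"] + spec.get("model_inputs", [])
    return {name: available[name] for name in wanted if name in available}


# e.g. the "decoder" node only consumes `latents` and a `vae`
decoder_kwargs = collect_node_inputs("decoder", {"latents": ..., "vae": ..., "prompt": "ignored"})
```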
@@ -0,0 +1,99 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
SDXL_NODE_TYPES_PARAMS_MAP = {
|
||||
"controlnet": {
|
||||
"inputs": [
|
||||
"control_image",
|
||||
"controlnet_conditioning_scale",
|
||||
"control_guidance_start",
|
||||
"control_guidance_end",
|
||||
"height",
|
||||
"width",
|
||||
],
|
||||
"model_inputs": [
|
||||
"controlnet",
|
||||
],
|
||||
"outputs": [
|
||||
"controlnet_out",
|
||||
],
|
||||
"block_names": [None],
|
||||
},
|
||||
"denoise": {
|
||||
"inputs": [
|
||||
"embeddings",
|
||||
"width",
|
||||
"height",
|
||||
"seed",
|
||||
"num_inference_steps",
|
||||
"guidance_scale",
|
||||
"image_latents",
|
||||
"strength",
|
||||
# custom adapters coming in as inputs
|
||||
"controlnet",
|
||||
# ip_adapter is optional and custom; include if available
|
||||
"ip_adapter",
|
||||
],
|
||||
"model_inputs": [
|
||||
"unet",
|
||||
"guider",
|
||||
"scheduler",
|
||||
],
|
||||
"outputs": [
|
||||
"latents",
|
||||
"latents_preview",
|
||||
],
|
||||
"block_names": ["denoise"],
|
||||
},
|
||||
"vae_encoder": {
|
||||
"inputs": [
|
||||
"image",
|
||||
"width",
|
||||
"height",
|
||||
],
|
||||
"model_inputs": [
|
||||
"vae",
|
||||
],
|
||||
"outputs": [
|
||||
"image_latents",
|
||||
],
|
||||
"block_names": ["vae_encoder"],
|
||||
},
|
||||
"text_encoder": {
|
||||
"inputs": [
|
||||
"prompt",
|
||||
"negative_prompt",
|
||||
],
|
||||
"model_inputs": [
|
||||
"text_encoders",
|
||||
],
|
||||
"outputs": [
|
||||
"embeddings",
|
||||
],
|
||||
"block_names": ["text_encoder"],
|
||||
},
|
||||
"decoder": {
|
||||
"inputs": [
|
||||
"latents",
|
||||
],
|
||||
"model_inputs": [
|
||||
"vae",
|
||||
],
|
||||
"outputs": [
|
||||
"images",
|
||||
],
|
||||
"block_names": ["decode"],
|
||||
},
|
||||
}
|
||||
@@ -119,7 +119,7 @@ class ZImageAutoDenoiseStep(AutoPipelineBlocks):
|
||||
|
||||
class ZImageAutoVaeImageEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [ZImageVaeImageEncoderStep]
|
||||
block_names = ["vae_encoder"]
|
||||
block_names = ["vae_image_encoder"]
|
||||
block_trigger_inputs = ["image"]
|
||||
|
||||
@property
|
||||
@@ -137,7 +137,7 @@ class ZImageAutoBlocks(SequentialPipelineBlocks):
|
||||
ZImageAutoDenoiseStep,
|
||||
ZImageVaeDecoderStep,
|
||||
]
|
||||
block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]
|
||||
block_names = ["text_encoder", "vae_image_encoder", "denoise", "decode"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
@@ -162,7 +162,7 @@ TEXT2IMAGE_BLOCKS = InsertableDict(
|
||||
IMAGE2IMAGE_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", ZImageTextEncoderStep),
|
||||
("vae_encoder", ZImageVaeImageEncoderStep),
|
||||
("vae_image_encoder", ZImageVaeImageEncoderStep),
|
||||
("input", ZImageTextInputStep),
|
||||
("additional_inputs", ZImageAdditionalInputsStep(image_latent_inputs=["image_latents"])),
|
||||
("prepare_latents", ZImagePrepareLatentsStep),
|
||||
@@ -178,7 +178,7 @@ IMAGE2IMAGE_BLOCKS = InsertableDict(
|
||||
AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", ZImageTextEncoderStep),
|
||||
("vae_encoder", ZImageAutoVaeImageEncoderStep),
|
||||
("vae_image_encoder", ZImageAutoVaeImageEncoderStep),
|
||||
("denoise", ZImageAutoDenoiseStep),
|
||||
("decode", ZImageVaeDecoderStep),
|
||||
]
|
||||
|
||||
@@ -165,7 +165,6 @@ else:
|
||||
_import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"]
|
||||
_import_structure["consisid"] = ["ConsisIDPipeline"]
|
||||
_import_structure["cosmos"] = [
|
||||
"Cosmos2_5_PredictBasePipeline",
|
||||
"Cosmos2TextToImagePipeline",
|
||||
"CosmosTextToWorldPipeline",
|
||||
"CosmosVideoToWorldPipeline",
|
||||
@@ -623,7 +622,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionXLControlNetXSPipeline,
|
||||
)
|
||||
from .cosmos import (
|
||||
Cosmos2_5_PredictBasePipeline,
|
||||
Cosmos2TextToImagePipeline,
|
||||
Cosmos2VideoToWorldPipeline,
|
||||
CosmosTextToWorldPipeline,
|
||||
|
||||
@@ -22,9 +22,6 @@ except OptionalDependencyNotAvailable:
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_cosmos2_5_predict"] = [
|
||||
"Cosmos2_5_PredictBasePipeline",
|
||||
]
|
||||
_import_structure["pipeline_cosmos2_text2image"] = ["Cosmos2TextToImagePipeline"]
|
||||
_import_structure["pipeline_cosmos2_video2world"] = ["Cosmos2VideoToWorldPipeline"]
|
||||
_import_structure["pipeline_cosmos_text2world"] = ["CosmosTextToWorldPipeline"]
|
||||
@@ -38,9 +35,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_cosmos2_5_predict import (
|
||||
Cosmos2_5_PredictBasePipeline,
|
||||
)
|
||||
from .pipeline_cosmos2_text2image import Cosmos2TextToImagePipeline
|
||||
from .pipeline_cosmos2_video2world import Cosmos2VideoToWorldPipeline
|
||||
from .pipeline_cosmos_text2world import CosmosTextToWorldPipeline
|
||||
|
||||
@@ -1,847 +0,0 @@
|
||||
# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms
|
||||
import torchvision.transforms.functional
|
||||
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...models import AutoencoderKLWan, CosmosTransformer3DModel
|
||||
from ...schedulers import UniPCMultistepScheduler
|
||||
from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import CosmosPipelineOutput
|
||||
|
||||
|
||||
if is_cosmos_guardrail_available():
|
||||
from cosmos_guardrail import CosmosSafetyChecker
|
||||
else:
|
||||
|
||||
class CosmosSafetyChecker:
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise ImportError(
|
||||
"`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`."
|
||||
)
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import Cosmos2_5_PredictBasePipeline
|
||||
>>> from diffusers.utils import export_to_video, load_image, load_video
|
||||
|
||||
>>> model_id = "nvidia/Cosmos-Predict2.5-2B"
|
||||
>>> pipe = Cosmos2_5_PredictBasePipeline.from_pretrained(
|
||||
... model_id, revision="diffusers/base/pre-trianed", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe = pipe.to("cuda")
|
||||
|
||||
>>> # Common negative prompt reused across modes.
|
||||
>>> negative_prompt = (
|
||||
... "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
|
||||
... "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
|
||||
... "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky "
|
||||
... "movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
|
||||
... "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
|
||||
... "Overall, the video is of poor quality."
|
||||
... )
|
||||
|
||||
>>> # Text2World: generate a 93-frame world video from text only.
|
||||
>>> prompt = (
|
||||
... "As the red light shifts to green, the red bus at the intersection begins to move forward, its headlights "
|
||||
... "cutting through the falling snow. The snowy tire tracks deepen as the vehicle inches ahead, casting fresh "
|
||||
... "lines onto the slushy road. Around it, streetlights glow warmer, illuminating the drifting flakes and wet "
|
||||
... "reflections on the asphalt. Other cars behind start to edge forward, their beams joining the scene. "
|
||||
... "The stillness of the urban street transitions into motion as the quiet snowfall is punctuated by the slow "
|
||||
... "advance of traffic through the frosty city corridor."
|
||||
... )
|
||||
>>> video = pipe(
|
||||
... image=None,
|
||||
... video=None,
|
||||
... prompt=prompt,
|
||||
... negative_prompt=negative_prompt,
|
||||
... num_frames=93,
|
||||
... generator=torch.Generator().manual_seed(1),
|
||||
... ).frames[0]
|
||||
>>> export_to_video(video, "text2world.mp4", fps=16)
|
||||
|
||||
>>> # Image2World: condition on a single image and generate a 93-frame world video.
|
||||
>>> prompt = (
|
||||
... "A high-definition video captures the precision of robotic welding in an industrial setting. "
|
||||
... "The first frame showcases a robotic arm, equipped with a welding torch, positioned over a large metal structure. "
|
||||
... "The welding process is in full swing, with bright sparks and intense light illuminating the scene, creating a vivid "
|
||||
... "display of blue and white hues. A significant amount of smoke billows around the welding area, partially obscuring "
|
||||
... "the view but emphasizing the heat and activity. The background reveals parts of the workshop environment, including a "
|
||||
... "ventilation system and various pieces of machinery, indicating a busy and functional industrial workspace. As the video "
|
||||
... "progresses, the robotic arm maintains its steady position, continuing the welding process and moving to its left. "
|
||||
... "The welding torch consistently emits sparks and light, and the smoke continues to rise, diffusing slightly as it moves upward. "
|
||||
... "The metal surface beneath the torch shows ongoing signs of heating and melting. The scene retains its industrial ambiance, with "
|
||||
... "the welding sparks and smoke dominating the visual field, underscoring the ongoing nature of the welding operation."
|
||||
... )
|
||||
>>> image = load_image(
|
||||
... "https://media.githubusercontent.com/media/nvidia-cosmos/cosmos-predict2.5/refs/heads/main/assets/base/robot_welding.jpg"
|
||||
... )
|
||||
>>> video = pipe(
|
||||
... image=image,
|
||||
... video=None,
|
||||
... prompt=prompt,
|
||||
... negative_prompt=negative_prompt,
|
||||
... num_frames=93,
|
||||
... generator=torch.Generator().manual_seed(1),
|
||||
... ).frames[0]
|
||||
>>> # export_to_video(video, "image2world.mp4", fps=16)
|
||||
|
||||
>>> # Video2World: condition on an input clip and predict a 93-frame world video.
|
||||
>>> prompt = (
|
||||
... "The video opens with an aerial view of a large-scale sand mining construction operation, showcasing extensive piles "
|
||||
... "of brown sand meticulously arranged in parallel rows. A central water channel, fed by a water pipe, flows through the "
|
||||
... "middle of these sand heaps, creating ripples and movement as it cascades down. The surrounding area features dense green "
|
||||
... "vegetation on the left, contrasting with the sandy terrain, while a body of water is visible in the background on the right. "
|
||||
... "As the video progresses, a piece of heavy machinery, likely a bulldozer, enters the frame from the right, moving slowly along "
|
||||
... "the edge of the sand piles. This machinery's presence indicates ongoing construction work in the operation. The final frame "
|
||||
... "captures the same scene, with the water continuing its flow and the bulldozer still in motion, maintaining the dynamic yet "
|
||||
... "steady pace of the construction activity."
|
||||
... )
|
||||
>>> input_video = load_video(
|
||||
... "https://github.com/nvidia-cosmos/cosmos-predict2.5/raw/refs/heads/main/assets/base/sand_mining.mp4"
|
||||
... )
|
||||
>>> video = pipe(
|
||||
... image=None,
|
||||
... video=input_video,
|
||||
... prompt=prompt,
|
||||
... negative_prompt=negative_prompt,
|
||||
... num_frames=93,
|
||||
... generator=torch.Generator().manual_seed(1),
|
||||
... ).frames[0]
|
||||
>>> export_to_video(video, "video2world.mp4", fps=16)
|
||||
|
||||
>>> # To produce an image instead of a world (video) clip, set num_frames=1 and
|
||||
>>> # save the first frame: pipe(..., num_frames=1).frames[0][0].
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class Cosmos2_5_PredictBasePipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for [Cosmos Predict2.5](https://github.com/nvidia-cosmos/cosmos-predict2.5) base model.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
Args:
|
||||
text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
|
||||
Frozen text-encoder. Cosmos Predict2.5 uses the [Qwen2.5
|
||||
VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) encoder.
|
||||
tokenizer (`AutoTokenizer`):
|
||||
Tokenizer associated with the Qwen2.5 VL encoder.
|
||||
transformer ([`CosmosTransformer3DModel`]):
|
||||
Conditional Transformer to denoise the encoded image latents.
|
||||
scheduler ([`UniPCMultistepScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLWan`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
# We mark safety_checker as optional here to get around some test failures, but it is not really optional
|
||||
_optional_components = ["safety_checker"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_encoder: Qwen2_5_VLForConditionalGeneration,
|
||||
tokenizer: AutoTokenizer,
|
||||
transformer: CosmosTransformer3DModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: UniPCMultistepScheduler,
|
||||
safety_checker: CosmosSafetyChecker = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if safety_checker is None:
|
||||
safety_checker = CosmosSafetyChecker()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
)
|
||||
|
||||
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
|
||||
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).float()
|
||||
if getattr(self.vae.config, "latents_mean", None) is not None
|
||||
else None
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).float()
|
||||
if getattr(self.vae.config, "latents_std", None) is not None
|
||||
else None
|
||||
)
|
||||
self.latents_mean = latents_mean
|
||||
self.latents_std = latents_std
|
||||
|
||||
if self.latents_mean is None or self.latents_std is None:
|
||||
raise ValueError("VAE configuration must define both `latents_mean` and `latents_std`.")
|
||||
|
||||
def _get_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
max_sequence_length: int = 512,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
input_ids_batch = []
|
||||
|
||||
for sample_idx in range(len(prompt)):
|
||||
conversations = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are a helpful assistant who will provide prompts to an image generator.",
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt[sample_idx],
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
input_ids = self.tokenizer.apply_chat_template(
|
||||
conversations,
|
||||
tokenize=True,
|
||||
add_generation_prompt=False,
|
||||
add_vision_id=False,
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
)
|
||||
input_ids = torch.LongTensor(input_ids)
|
||||
input_ids_batch.append(input_ids)
|
||||
|
||||
input_ids_batch = torch.stack(input_ids_batch, dim=0)
|
||||
|
||||
outputs = self.text_encoder(
|
||||
input_ids_batch.to(device),
|
||||
output_hidden_states=True,
|
||||
)
|
||||
hidden_states = outputs.hidden_states
|
||||
|
||||
normalized_hidden_states = []
|
||||
for layer_idx in range(1, len(hidden_states)):
|
||||
normalized_state = (hidden_states[layer_idx] - hidden_states[layer_idx].mean(dim=-1, keepdim=True)) / (
|
||||
hidden_states[layer_idx].std(dim=-1, keepdim=True) + 1e-8
|
||||
)
|
||||
normalized_hidden_states.append(normalized_state)
|
||||
|
||||
prompt_embeds = torch.cat(normalized_hidden_states, dim=-1)
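# Note: assuming the Qwen2.5-VL-7B text encoder used here (28 decoder layers,
# hidden size 3584), concatenating the normalized hidden states of all 28
# post-embedding layers gives 28 * 3584 = 100352 features per token, which is
# what the Cosmos-2.5 transformer's `crossattn_proj_in_channels=100352`
# projection maps down to `encoder_hidden_states_channels=1024`.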
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Modified from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 512,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_prompt_embeds(
|
||||
prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = self._get_prompt_embeds(
|
||||
prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = negative_prompt_embeds.shape
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Modified from diffusers.pipelines.cosmos.pipeline_cosmos2_video2world.Cosmos2VideoToWorldPipeline.prepare_latents and
|
||||
# diffusers.pipelines.cosmos.pipeline_cosmos2_text2image.Cosmos2TextToImagePipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self,
|
||||
video: Optional[torch.Tensor],
|
||||
batch_size: int,
|
||||
num_channels_latents: int = 16,
|
||||
height: int = 704,
|
||||
width: int = 1280,
|
||||
num_frames_in: int = 93,
|
||||
num_frames_out: int = 93,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
B = batch_size
|
||||
C = num_channels_latents
|
||||
T = (num_frames_out - 1) // self.vae_scale_factor_temporal + 1
|
||||
H = height // self.vae_scale_factor_spatial
|
||||
W = width // self.vae_scale_factor_spatial
|
||||
shape = (B, C, T, H, W)
|
||||
|
||||
if num_frames_in == 0:
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
cond_mask = torch.zeros((B, 1, T, H, W), dtype=latents.dtype, device=latents.device)
|
||||
cond_indicator = torch.zeros((B, 1, T, 1, 1), dtype=latents.dtype, device=latents.device)
|
||||
|
||||
cond_latents = torch.zeros_like(latents)
|
||||
|
||||
return (
|
||||
latents,
|
||||
cond_latents,
|
||||
cond_mask,
|
||||
cond_indicator,
|
||||
)
|
||||
else:
|
||||
if video is None:
|
||||
raise ValueError("`video` must be provided when `num_frames_in` is greater than 0.")
|
||||
needs_preprocessing = not (isinstance(video, torch.Tensor) and video.ndim == 5 and video.shape[1] == 3)
|
||||
if needs_preprocessing:
|
||||
video = self.video_processor.preprocess_video(video, height, width)
|
||||
video = video.to(device=device, dtype=self.vae.dtype)
|
||||
if isinstance(generator, list):
|
||||
cond_latents = [
|
||||
retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator=generator[i])
|
||||
for i in range(batch_size)
|
||||
]
|
||||
else:
|
||||
cond_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
|
||||
|
||||
cond_latents = torch.cat(cond_latents, dim=0).to(dtype)
|
||||
|
||||
latents_mean = self.latents_mean.to(device=device, dtype=dtype)
|
||||
latents_std = self.latents_std.to(device=device, dtype=dtype)
|
||||
cond_latents = (cond_latents - latents_mean) / latents_std
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
|
||||
padding_shape = (B, 1, T, H, W)
|
||||
ones_padding = latents.new_ones(padding_shape)
|
||||
zeros_padding = latents.new_zeros(padding_shape)
|
||||
|
||||
num_cond_latent_frames = (num_frames_in - 1) // self.vae_scale_factor_temporal + 1
|
||||
cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
|
||||
cond_indicator[:, :, 0:num_cond_latent_frames] = 1.0
|
||||
cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
|
||||
|
||||
return (
|
||||
latents,
|
||||
cond_latents,
|
||||
cond_mask,
|
||||
cond_indicator,
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image: PipelineImageInput | None = None,
|
||||
video: List[PipelineImageInput] | None = None,
|
||||
prompt: Union[str, List[str]] | None = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: int = 704,
|
||||
width: int = 1280,
|
||||
num_frames: int = 93,
|
||||
num_inference_steps: int = 36,
|
||||
guidance_scale: float = 7.0,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
conditional_frame_timestep: float = 0.1,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation. Supports three modes:
|
||||
|
||||
- **Text2World**: `image=None`, `video=None`, `prompt` provided. Generates a world clip.
|
||||
- **Image2World**: `image` provided, `video=None`, `prompt` provided. Conditions on a single frame.
|
||||
- **Video2World**: `video` provided, `image=None`, `prompt` provided. Conditions on an input clip.
|
||||
|
||||
Set `num_frames=93` (default) to produce a world video, or `num_frames=1` to produce a single image frame (turning the modes above into their "*2Image" variants).
|
||||
|
||||
Outputs follow `output_type` (e.g., `"pil"` returns a list of `num_frames` PIL images per prompt).
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, *optional*):
|
||||
Optional single image for Image2World conditioning. Must be `None` when `video` is provided.
|
||||
video (`List[PIL.Image.Image]`, `np.ndarray`, `torch.Tensor`, *optional*):
|
||||
Optional input video for Video2World conditioning. Must be `None` when `image` is provided.
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide generation. Required unless `prompt_embeds` is supplied.
|
||||
height (`int`, defaults to `704`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, defaults to `1280`):
|
||||
The width in pixels of the generated image.
|
||||
num_frames (`int`, defaults to `93`):
|
||||
Number of output frames. Use `93` for world (video) generation; set to `1` to return a single frame.
|
||||
num_inference_steps (`int`, defaults to `35`):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, defaults to `7.0`):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple.
|
||||
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
||||
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
||||
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
||||
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
||||
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, defaults to `512`):
|
||||
The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If
|
||||
the prompt is shorter than this length, it will be padded.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~CosmosPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where
|
||||
the first element is a list with the generated images and the second element is a list of `bool`s
|
||||
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
        if self.safety_checker is None:
            raise ValueError(
                f"You have disabled the safety checker for {self.__class__}. This is in violation of the "
                "[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). "
                f"Please ensure that you are compliant with the license agreement."
            )

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)

        self._guidance_scale = guidance_scale
        self._current_timestep = None
        self._interrupt = False

        device = self._execution_device

        if self.safety_checker is not None:
            self.safety_checker.to(device)
            if prompt is not None:
                prompt_list = [prompt] if isinstance(prompt, str) else prompt
                for p in prompt_list:
                    if not self.safety_checker.check_text_safety(p):
                        raise ValueError(
                            f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the "
                            f"prompt abides by the NVIDIA Open Model License Agreement."
                        )

        # Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # Encode input prompt
        (
            prompt_embeds,
            negative_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            device=device,
            max_sequence_length=max_sequence_length,
        )

        vae_dtype = self.vae.dtype
        transformer_dtype = self.transformer.dtype

        num_frames_in = None
        if image is not None:
            if batch_size != 1:
                raise ValueError(f"batch_size must be 1 for image input (given {batch_size})")

            # Image2World: a single conditioning frame followed by zero placeholders for the frames to generate.
            image = torchvision.transforms.functional.to_tensor(image).unsqueeze(0)
            video = torch.cat([image, torch.zeros_like(image).repeat(num_frames - 1, 1, 1, 1)], dim=0)
            video = video.unsqueeze(0)
            num_frames_in = 1
        elif video is None:
            # Text2World: no conditioning frames at all.
            video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=torch.uint8)
            num_frames_in = 0
        else:
            num_frames_in = len(video)

            if batch_size != 1:
                raise ValueError(f"batch_size must be 1 for video input (given {batch_size})")
        assert video is not None
        video = self.video_processor.preprocess_video(video, height, width)

        # pad with last frame (for video2world)
        num_frames_out = num_frames
        if video.shape[2] < num_frames_out:
            n_pad_frames = num_frames_out - num_frames_in
            last_frame = video[0, :, -1:, :, :]  # [C, T==1, H, W]
            pad_frames = last_frame.repeat(1, 1, n_pad_frames, 1, 1)  # [B, C, T, H, W]
            video = torch.cat((video, pad_frames), dim=2)

        assert num_frames_in <= num_frames_out, f"expected ({num_frames_in=}) <= ({num_frames_out=})"

        video = video.to(device=device, dtype=vae_dtype)

        # The last transformer input channel is reserved for the conditioning mask.
        num_channels_latents = self.transformer.config.in_channels - 1
        latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents(
            video=video,
            batch_size=batch_size * num_videos_per_prompt,
            num_channels_latents=num_channels_latents,
            height=height,
            width=width,
            num_frames_in=num_frames_in,
            num_frames_out=num_frames,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            dtype=torch.float32,
            device=device,
            generator=generator,
            latents=latents,
        )
        # Conditioning frames get a small, fixed timestep so the model treats them as (almost) clean.
        cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep
        cond_mask = cond_mask.to(transformer_dtype)

        padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)

        # Denoising loop
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps
        self._num_timesteps = len(timesteps)
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        gt_velocity = (latents - cond_latent) * cond_mask
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                self._current_timestep = t.cpu().item()

                # NOTE: assumes sigma(t) \in [0, 1]
                sigma_t = (
                    torch.tensor(self.scheduler.sigmas[i].item())
                    .unsqueeze(0)
                    .to(device=device, dtype=transformer_dtype)
                )

                in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents
                in_latents = in_latents.to(transformer_dtype)
                in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t
                noise_pred = self.transformer(
                    hidden_states=in_latents,
                    condition_mask=cond_mask,
                    timestep=in_timestep,
                    encoder_hidden_states=prompt_embeds,
                    padding_mask=padding_mask,
                    return_dict=False,
                )[0]
                # NOTE: replace velocity (noise_pred) with gt_velocity for conditioning inputs only
                noise_pred = gt_velocity + noise_pred * (1 - cond_mask)

                if self.do_classifier_free_guidance:
                    noise_pred_neg = self.transformer(
                        hidden_states=in_latents,
                        condition_mask=cond_mask,
                        timestep=in_timestep,
                        encoder_hidden_states=negative_prompt_embeds,
                        padding_mask=padding_mask,
                        return_dict=False,
                    )[0]
                    # NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only
                    noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask)
                    noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg)

                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # update the progress bar
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        self._current_timestep = None

        if not output_type == "latent":
            latents_mean = self.latents_mean.to(latents.device, latents.dtype)
            latents_std = self.latents_std.to(latents.device, latents.dtype)
            latents = latents * latents_std + latents_mean
            video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
            video = self._match_num_frames(video, num_frames)

            assert self.safety_checker is not None
            self.safety_checker.to(device)
            video = self.video_processor.postprocess_video(video, output_type="np")
            video = (video * 255).astype(np.uint8)
            video_batch = []
            for vid in video:
                vid = self.safety_checker.check_video_safety(vid)
                video_batch.append(vid)
            video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1
            video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
            video = self.video_processor.postprocess_video(video, output_type=output_type)
        else:
            video = latents

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (video,)

        return CosmosPipelineOutput(frames=video)

    def _match_num_frames(self, video: torch.Tensor, target_num_frames: int) -> torch.Tensor:
        if target_num_frames <= 0 or video.shape[2] == target_num_frames:
            return video

        frames_per_latent = max(self.vae_scale_factor_temporal, 1)
        video = torch.repeat_interleave(video, repeats=frames_per_latent, dim=2)

        current_frames = video.shape[2]
        if current_frames < target_num_frames:
            pad = video[:, :, -1:, :, :].repeat(1, 1, target_num_frames - current_frames, 1, 1)
            video = torch.cat([video, pad], dim=2)
        elif current_frames > target_num_frames:
            video = video[:, :, :target_num_frames]

        return video
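Taken as a whole, the sampler above is a flow-matching loop that freezes the conditioning frames (via `cond_mask` and `cond_indicator`), predicts a velocity for the remaining frames, and re-injects the ground-truth velocity for the frozen frames at every step. For orientation, a minimal usage sketch follows; the checkpoint directory is a placeholder for a locally converted Cosmos 2.5 Predict checkpoint, not a published Hub id.

```python
# Minimal sketch, assuming a locally converted checkpoint directory (placeholder path).
import torch

from diffusers import Cosmos2_5_PredictBasePipeline
from diffusers.utils import export_to_video, load_image

pipe = Cosmos2_5_PredictBasePipeline.from_pretrained(
    "path/to/converted/cosmos-2.5-predict-base-2b", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

# Image2World: condition on a single frame and generate the rest of the clip.
image = load_image("input_frame.png")
output = pipe(
    image=image,
    prompt="A robot arm picks up a red cube and places it on a shelf.",
    num_frames=93,
    num_inference_steps=36,
    guidance_scale=7.0,
    generator=torch.Generator("cuda").manual_seed(0),
)
export_to_video(output.frames[0], "world.mp4", fps=16)
```

Dropping the `image` argument gives plain Text2World generation, and `num_frames=1` returns a single frame instead of a clip.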
@@ -882,24 +882,21 @@ the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>as
            latents = latents / latents_std + latents_mean

            b, c, f, h, w = latents.shape

            latents = latents[:, :, 1:]  # remove the first frame as it is the origin input

            latents = latents.permute(0, 2, 1, 3, 4).view(-1, c, 1, h, w)

            image = self.vae.decode(latents, return_dict=False)[0]  # (b f) c 1 h w
            img = self.vae.decode(latents, return_dict=False)[0]  # (b f) c 1 h w
            img = img.squeeze(2)

            image = image.squeeze(2)

            image = self.image_processor.postprocess(image, output_type=output_type)
            images = []
            img = self.image_processor.postprocess(img, output_type=output_type)
            image = []
            for bidx in range(b):
                images.append(image[bidx * f : (bidx + 1) * f])
                image.append(img[bidx * f : (bidx + 1) * f])

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (images,)
            return (image,)

        return QwenImagePipelineOutput(images=images)
        return QwenImagePipelineOutput(images=image)

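In the hunk above, the decoded tensor is renamed to `img` so that `image` can hold the final per-prompt list of decoded layers. The shape bookkeeping is the subtle part, so here is a small, self-contained sketch of it with the VAE replaced by an identity op; all dimensions are made-up example values.

```python
# Toy illustration of the layered-decode reshape; shapes only, no real VAE involved.
import torch

b, c, f, h, w = 2, 16, 3, 8, 8   # batch, latent channels, 1 input frame + 2 layers, height, width
latents = torch.randn(b, c, f, h, w)

latents = latents[:, :, 1:]      # drop the first frame (the original input image)
f = f - 1

# fold the frame axis into the batch axis so each layer is decoded independently
flat = latents.permute(0, 2, 1, 3, 4).reshape(-1, c, 1, h, w)   # (b * f, c, 1, h, w)
img = flat.squeeze(2)            # stand-in for vae.decode(...): (b * f, c, h, w)

# regroup decoded layers per prompt, mirroring `image.append(img[bidx * f : (bidx + 1) * f])`
image = [img[bidx * f : (bidx + 1) * f] for bidx in range(b)]
assert len(image) == b and image[0].shape == (f, c, h, w)
```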
@@ -217,8 +217,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
        rescale_betas_zero_snr: bool = False,
        use_dynamic_shifting: bool = False,
        time_shift_type: Literal["exponential"] = "exponential",
        sigma_min: Optional[float] = None,
        sigma_max: Optional[float] = None,
    ) -> None:
        if self.config.use_beta_sigmas and not is_scipy_available():
            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -352,12 +350,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
            log_sigmas = np.log(sigmas)
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            if self.config.use_flow_sigmas:
                sigmas = sigmas / (sigmas + 1)
                timesteps = (sigmas * self.config.num_train_timesteps).copy()
            else:
                timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()

            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
            if self.config.final_sigmas_type == "sigma_min":
                sigma_last = sigmas[-1]
            elif self.config.final_sigmas_type == "zero":

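The flow-sigma branch shown in this hunk is a one-line reparameterization of Karras sigmas, `sigma_flow = sigma / (sigma + 1)`, with timesteps read off as `sigma_flow * num_train_timesteps`. A small numeric sketch of that mapping; the input sigmas are illustrative values, not actual scheduler output.

```python
# Hedged sketch of the flow-sigma conversion applied after _convert_to_karras.
import numpy as np

num_train_timesteps = 1000
karras_sigmas = np.array([200.0, 45.0, 8.0, 1.0, 0.01])  # descending Karras sigmas (example values)

flow_sigmas = karras_sigmas / (karras_sigmas + 1)         # maps (0, inf) into (0, 1)
timesteps = flow_sigmas * num_train_timesteps

print(flow_sigmas.round(4))   # [0.995  0.9783 0.8889 0.5    0.0099]
print(timesteps.round(1))     # [995.  978.3 888.9 500.    9.9]
```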
@@ -767,21 +767,6 @@ class ConsisIDPipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class Cosmos2_5_PredictBasePipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class Cosmos2TextToImagePipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

@@ -27,7 +27,7 @@ class DummyCosmosSafetyChecker(ModelMixin, ConfigMixin):
    def __init__(self) -> None:
        super().__init__()

        self.register_buffer("_device_tracker", torch.zeros(1, dtype=torch.float32), persistent=False)
        self._dtype = torch.float32

    def check_text_safety(self, prompt: str) -> bool:
        return True
@@ -35,14 +35,13 @@ class DummyCosmosSafetyChecker(ModelMixin, ConfigMixin):
    def check_video_safety(self, frames: np.ndarray) -> np.ndarray:
        return frames

    def to(self, device: Union[str, torch.device] = None, dtype: torch.dtype = None):
        module = super().to(device=device, dtype=dtype)
        return module
    def to(self, device: Union[str, torch.device] = None, dtype: torch.dtype = None) -> None:
        self._dtype = dtype

    @property
    def device(self) -> torch.device:
        return self._device_tracker.device
        return None

    @property
    def dtype(self) -> torch.dtype:
        return self._device_tracker.dtype
        return self._dtype

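The buffer-based version above lets the dummy guardrail report a real `device` and `dtype` without overriding `to()`: a non-persistent buffer is moved by `nn.Module.to()` like any parameter, so the properties can simply read it back. A minimal sketch of the same pattern on a bare `nn.Module`:

```python
# Minimal sketch of device/dtype tracking via a non-persistent buffer.
import torch
from torch import nn


class DeviceTracked(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # not saved in the state_dict, but moved by .to() like any other buffer
        self.register_buffer("_device_tracker", torch.zeros(1, dtype=torch.float32), persistent=False)

    @property
    def device(self) -> torch.device:
        return self._device_tracker.device

    @property
    def dtype(self) -> torch.dtype:
        return self._device_tracker.dtype


m = DeviceTracked()
m.to(dtype=torch.bfloat16)
assert m.dtype == torch.bfloat16  # the tracker follows .to() with no manual bookkeeping
```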
@@ -1,337 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLWan,
|
||||
Cosmos2_5_PredictBasePipeline,
|
||||
CosmosTransformer3DModel,
|
||||
UniPCMultistepScheduler,
|
||||
)
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin, to_np
|
||||
from .cosmos_guardrail import DummyCosmosSafetyChecker
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class Cosmos2_5_PredictBaseWrapper(Cosmos2_5_PredictBasePipeline):
|
||||
@staticmethod
|
||||
def from_pretrained(*args, **kwargs):
|
||||
if "safety_checker" not in kwargs or kwargs["safety_checker"] is None:
|
||||
safety_checker = DummyCosmosSafetyChecker()
|
||||
device_map = kwargs.get("device_map", "cpu")
|
||||
torch_dtype = kwargs.get("torch_dtype")
|
||||
if device_map is not None or torch_dtype is not None:
|
||||
safety_checker = safety_checker.to(device_map, dtype=torch_dtype)
|
||||
kwargs["safety_checker"] = safety_checker
|
||||
return Cosmos2_5_PredictBasePipeline.from_pretrained(*args, **kwargs)
|
||||
|
||||
|
||||
class Cosmos2_5_PredictPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = Cosmos2_5_PredictBaseWrapper
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"return_dict",
|
||||
"callback_on_step_end",
|
||||
"callback_on_step_end_tensor_inputs",
|
||||
]
|
||||
)
|
||||
supports_dduf = False
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = CosmosTransformer3DModel(
|
||||
in_channels=16 + 1,
|
||||
out_channels=16,
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=16,
|
||||
num_layers=2,
|
||||
mlp_ratio=2,
|
||||
text_embed_dim=32,
|
||||
adaln_lora_dim=4,
|
||||
max_size=(4, 32, 32),
|
||||
patch_size=(1, 2, 2),
|
||||
rope_scale=(2.0, 1.0, 1.0),
|
||||
concat_padding_mask=True,
|
||||
extra_pos_embed_type="learnable",
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKLWan(
|
||||
base_dim=3,
|
||||
z_dim=16,
|
||||
dim_mult=[1, 1, 1, 1],
|
||||
num_res_blocks=1,
|
||||
temperal_downsample=[False, True, True],
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = UniPCMultistepScheduler()
|
||||
|
||||
torch.manual_seed(0)
|
||||
config = Qwen2_5_VLConfig(
|
||||
text_config={
|
||||
"hidden_size": 16,
|
||||
"intermediate_size": 16,
|
||||
"num_hidden_layers": 2,
|
||||
"num_attention_heads": 2,
|
||||
"num_key_value_heads": 2,
|
||||
"rope_scaling": {
|
||||
"mrope_section": [1, 1, 2],
|
||||
"rope_type": "default",
|
||||
"type": "default",
|
||||
},
|
||||
"rope_theta": 1000000.0,
|
||||
},
|
||||
vision_config={
|
||||
"depth": 2,
|
||||
"hidden_size": 16,
|
||||
"intermediate_size": 16,
|
||||
"num_heads": 2,
|
||||
"out_hidden_size": 16,
|
||||
},
|
||||
hidden_size=16,
|
||||
vocab_size=152064,
|
||||
vision_end_token_id=151653,
|
||||
vision_start_token_id=151652,
|
||||
vision_token_id=151654,
|
||||
)
|
||||
text_encoder = Qwen2_5_VLForConditionalGeneration(config)
|
||||
tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"safety_checker": DummyCosmosSafetyChecker(),
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
inputs = {
|
||||
"prompt": "dance monkey",
|
||||
"negative_prompt": "bad quality",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 3.0,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"num_frames": 3,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "pt",
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_components_function(self):
|
||||
init_components = self.get_dummy_components()
|
||||
init_components = {k: v for k, v in init_components.items() if not isinstance(v, (str, int, float))}
|
||||
pipe = self.pipeline_class(**init_components)
|
||||
self.assertTrue(hasattr(pipe, "components"))
|
||||
self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
video = pipe(**inputs).frames
|
||||
generated_video = video[0]
|
||||
self.assertEqual(generated_video.shape, (3, 3, 32, 32))
|
||||
self.assertTrue(torch.isfinite(generated_video).all())
|
||||
|
||||
def test_callback_inputs(self):
|
||||
sig = inspect.signature(self.pipeline_class.__call__)
|
||||
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
|
||||
has_callback_step_end = "callback_on_step_end" in sig.parameters
|
||||
|
||||
if not (has_callback_tensor_inputs and has_callback_step_end):
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
self.assertTrue(
|
||||
hasattr(pipe, "_callback_tensor_inputs"),
|
||||
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
|
||||
)
|
||||
|
||||
def callback_inputs_subset(pipe, i, t, callback_kwargs):
|
||||
for tensor_name in callback_kwargs.keys():
|
||||
assert tensor_name in pipe._callback_tensor_inputs
|
||||
return callback_kwargs
|
||||
|
||||
def callback_inputs_all(pipe, i, t, callback_kwargs):
|
||||
for tensor_name in pipe._callback_tensor_inputs:
|
||||
assert tensor_name in callback_kwargs
|
||||
for tensor_name in callback_kwargs.keys():
|
||||
assert tensor_name in pipe._callback_tensor_inputs
|
||||
return callback_kwargs
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
inputs["callback_on_step_end"] = callback_inputs_subset
|
||||
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
|
||||
_ = pipe(**inputs)[0]
|
||||
|
||||
inputs["callback_on_step_end"] = callback_inputs_all
|
||||
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
|
||||
_ = pipe(**inputs)[0]
|
||||
|
||||
def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
|
||||
is_last = i == (pipe.num_timesteps - 1)
|
||||
if is_last:
|
||||
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
|
||||
return callback_kwargs
|
||||
|
||||
inputs["callback_on_step_end"] = callback_inputs_change_tensor
|
||||
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
|
||||
output = pipe(**inputs)[0]
|
||||
assert output.abs().sum() < 1e10
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=1e-2)
|
||||
|
||||
def test_attention_slicing_forward_pass(
|
||||
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
|
||||
):
|
||||
if not getattr(self, "test_attention_slicing", True):
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_without_slicing = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=1)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing1 = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=2)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing2 = pipe(**inputs)[0]
|
||||
|
||||
if test_max_difference:
|
||||
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
|
||||
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
|
||||
self.assertLess(
|
||||
max(max_diff1, max_diff2),
|
||||
expected_max_diff,
|
||||
"Attention slicing should not affect the inference results",
|
||||
)
|
||||
|
||||
def test_save_load_optional_components(self, expected_max_difference=1e-4):
|
||||
self.pipeline_class._optional_components.remove("safety_checker")
|
||||
super().test_save_load_optional_components(expected_max_difference=expected_max_difference)
|
||||
self.pipeline_class._optional_components.append("safety_checker")
|
||||
|
||||
def test_serialization_with_variants(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
model_components = [
|
||||
component_name
|
||||
for component_name, component in pipe.components.items()
|
||||
if isinstance(component, torch.nn.Module)
|
||||
]
|
||||
model_components.remove("safety_checker")
|
||||
variant = "fp16"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False)
|
||||
|
||||
with open(f"{tmpdir}/model_index.json", "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
for subfolder in os.listdir(tmpdir):
|
||||
if not os.path.isfile(subfolder) and subfolder in model_components:
|
||||
folder_path = os.path.join(tmpdir, subfolder)
|
||||
is_folder = os.path.isdir(folder_path) and subfolder in config
|
||||
assert is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path))
|
||||
|
||||
def test_torch_dtype_dict(self):
|
||||
components = self.get_dummy_components()
|
||||
if not components:
|
||||
self.skipTest("No dummy components defined.")
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
|
||||
specified_key = next(iter(components.keys()))
|
||||
|
||||
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
|
||||
pipe.save_pretrained(tmpdirname, safe_serialization=False)
|
||||
torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16}
|
||||
loaded_pipe = self.pipeline_class.from_pretrained(
|
||||
tmpdirname, safety_checker=DummyCosmosSafetyChecker(), torch_dtype=torch_dtype_dict
|
||||
)
|
||||
|
||||
for name, component in loaded_pipe.components.items():
|
||||
if name == "safety_checker":
|
||||
continue
|
||||
if isinstance(component, torch.nn.Module) and hasattr(component, "dtype"):
|
||||
expected_dtype = torch_dtype_dict.get(name, torch_dtype_dict.get("default", torch.float32))
|
||||
self.assertEqual(
|
||||
component.dtype,
|
||||
expected_dtype,
|
||||
f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
|
||||
)
|
||||
|
||||
@unittest.skip(
|
||||
"The pipeline should not be runnable without a safety checker. The test creates a pipeline without passing in "
|
||||
"a safety checker, which makes the pipeline default to the actual Cosmos Guardrail. The Cosmos Guardrail is "
|
||||
"too large and slow to run on CI."
|
||||
)
|
||||
def test_encode_prompt_works_in_isolation(self):
|
||||
pass
|
||||
tests/pipelines/qwenimage/test_qwenimage_layered.py (new file, 223 lines)
@@ -0,0 +1,223 @@
|
||||
# Copyright 2025 The HuggingFace Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import diffusers
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLQwenImage,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
QwenImageLayeredPipeline,
|
||||
QwenImageTransformer2DModel,
|
||||
)
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin, to_np
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class QwenImageLayeredPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = QwenImageLayeredPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"height", "width", "cross_attention_kwargs"}
|
||||
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
|
||||
image_params = frozenset(["image"])
|
||||
image_latents_params = frozenset(["latents"])
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"return_dict",
|
||||
"callback_on_step_end",
|
||||
"callback_on_step_end_tensor_inputs",
|
||||
]
|
||||
)
|
||||
supports_dduf = False
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
tiny_ckpt_id = "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"
|
||||
|
||||
torch.manual_seed(0)
|
||||
transformer = QwenImageTransformer2DModel(
|
||||
patch_size=2,
|
||||
in_channels=16,
|
||||
out_channels=4,
|
||||
num_layers=2,
|
||||
attention_head_dim=16,
|
||||
num_attention_heads=3,
|
||||
joint_attention_dim=16,
|
||||
guidance_embeds=False,
|
||||
axes_dims_rope=(8, 4, 4),
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
z_dim = 4
|
||||
vae = AutoencoderKLQwenImage(
|
||||
base_dim=z_dim * 6,
|
||||
z_dim=z_dim,
|
||||
dim_mult=[1, 2, 4],
|
||||
num_res_blocks=1,
|
||||
temperal_downsample=[False, True],
|
||||
latents_mean=[0.0] * z_dim,
|
||||
latents_std=[1.0] * z_dim,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
torch.manual_seed(0)
|
||||
config = Qwen2_5_VLConfig(
|
||||
text_config={
|
||||
"hidden_size": 16,
|
||||
"intermediate_size": 16,
|
||||
"num_hidden_layers": 2,
|
||||
"num_attention_heads": 2,
|
||||
"num_key_value_heads": 2,
|
||||
"rope_scaling": {
|
||||
"mrope_section": [1, 1, 2],
|
||||
"rope_type": "default",
|
||||
"type": "default",
|
||||
},
|
||||
"rope_theta": 1000000.0,
|
||||
},
|
||||
vision_config={
|
||||
"depth": 2,
|
||||
"hidden_size": 16,
|
||||
"intermediate_size": 16,
|
||||
"num_heads": 2,
|
||||
"out_hidden_size": 16,
|
||||
},
|
||||
hidden_size=16,
|
||||
vocab_size=152064,
|
||||
vision_end_token_id=151653,
|
||||
vision_start_token_id=151652,
|
||||
vision_token_id=151654,
|
||||
)
|
||||
text_encoder = Qwen2_5_VLForConditionalGeneration(config)
|
||||
tokenizer = Qwen2Tokenizer.from_pretrained(tiny_ckpt_id)
|
||||
processor = Qwen2VLProcessor.from_pretrained(tiny_ckpt_id)
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"processor": processor,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
inputs = {
|
||||
"prompt": "dance monkey",
|
||||
"image": Image.new("RGB", (32, 32)),
|
||||
"negative_prompt": "bad quality",
|
||||
"generator": generator,
|
||||
"true_cfg_scale": 1.0,
|
||||
"layers": 2,
|
||||
"num_inference_steps": 2,
|
||||
"max_sequence_length": 16,
|
||||
"resolution": 640,
|
||||
"output_type": "pt",
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
images = pipe(**inputs).images
|
||||
self.assertEqual(len(images), 1)
|
||||
|
||||
generated_layers = images[0]
|
||||
self.assertEqual(generated_layers.shape, (inputs["layers"], 3, 640, 640))
|
||||
|
||||
# fmt: off
|
||||
expected_slice_layer_0 = torch.tensor([0.5752, 0.6324, 0.4913, 0.4421, 0.4917, 0.4923, 0.4790, 0.4299, 0.4029, 0.3506, 0.3302, 0.3352, 0.3579, 0.4422, 0.5086, 0.5961])
|
||||
expected_slice_layer_1 = torch.tensor([0.5103, 0.6606, 0.5652, 0.6512, 0.5900, 0.5814, 0.5873, 0.5083, 0.5058, 0.4131, 0.4321, 0.5300, 0.3507, 0.4826, 0.4745, 0.5426])
|
||||
# fmt: on
|
||||
|
||||
layer_0_slice = torch.cat([generated_layers[0].flatten()[:8], generated_layers[0].flatten()[-8:]])
|
||||
layer_1_slice = torch.cat([generated_layers[1].flatten()[:8], generated_layers[1].flatten()[-8:]])
|
||||
|
||||
self.assertTrue(torch.allclose(layer_0_slice, expected_slice_layer_0, atol=1e-3))
|
||||
self.assertTrue(torch.allclose(layer_1_slice, expected_slice_layer_1, atol=1e-3))
|
||||
|
||||
def test_inference_batch_single_identical(self, batch_size=3, expected_max_diff=1e-1):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["generator"] = self.get_generator(0)
|
||||
|
||||
logger = diffusers.logging.get_logger(pipe.__module__)
|
||||
logger.setLevel(level=diffusers.logging.FATAL)
|
||||
|
||||
batched_inputs = {}
|
||||
batched_inputs.update(inputs)
|
||||
|
||||
for name in self.batch_params:
|
||||
if name not in inputs:
|
||||
continue
|
||||
|
||||
value = inputs[name]
|
||||
if name == "prompt":
|
||||
len_prompt = len(value)
|
||||
batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
|
||||
batched_inputs[name][-1] = 100 * "very long"
|
||||
else:
|
||||
batched_inputs[name] = batch_size * [value]
|
||||
|
||||
if "generator" in inputs:
|
||||
batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
|
||||
|
||||
if "batch_size" in inputs:
|
||||
batched_inputs["batch_size"] = batch_size
|
||||
|
||||
batched_inputs["num_inference_steps"] = inputs["num_inference_steps"]
|
||||
|
||||
output = pipe(**inputs).images
|
||||
output_batch = pipe(**batched_inputs).images
|
||||
|
||||
self.assertEqual(len(output_batch), batch_size)
|
||||
|
||||
max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
|
||||
self.assertLess(max_diff, expected_max_diff)
|
||||
@@ -399,32 +399,3 @@ class UniPCMultistepScheduler1DTest(UniPCMultistepSchedulerTest):

    def test_exponential_sigmas(self):
        self.check_over_configs(use_exponential_sigmas=True)

    def test_flow_and_karras_sigmas(self):
        self.check_over_configs(use_flow_sigmas=True, use_karras_sigmas=True)

    def test_flow_and_karras_sigmas_values(self):
        num_train_timesteps = 1000
        num_inference_steps = 5
        scheduler = UniPCMultistepScheduler(
            sigma_min=0.01,
            sigma_max=200.0,
            use_flow_sigmas=True,
            use_karras_sigmas=True,
            num_train_timesteps=num_train_timesteps,
        )
        scheduler.set_timesteps(num_inference_steps=num_inference_steps)

        expected_sigmas = [
            0.9950248599052429,
            0.9787454605102539,
            0.8774884343147278,
            0.3604971766471863,
            0.009900986216962337,
            0.0,  # 0 appended as default
        ]
        expected_sigmas = torch.tensor(expected_sigmas)
        expected_timesteps = (expected_sigmas * num_train_timesteps).to(torch.int64)
        expected_timesteps = expected_timesteps[0:-1]
        self.assertTrue(torch.allclose(scheduler.sigmas, expected_sigmas))
        self.assertTrue(torch.all(expected_timesteps == scheduler.timesteps))
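The expected values in `test_flow_and_karras_sigmas_values` follow from a Karras ramp between `sigma_max=200.0` and `sigma_min=0.01` and the flow transform `sigma / (sigma + 1)`. A hedged re-derivation, assuming the default Karras `rho=7.0` and the trailing `0.0` the scheduler appends:

```python
# Re-derivation of the expected sigmas/timesteps under the stated assumptions.
import numpy as np

sigma_min, sigma_max, rho, steps = 0.01, 200.0, 7.0, 5
num_train_timesteps = 1000

ramp = np.linspace(0, 1, steps)
karras = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
flow = karras / (karras + 1)

print(np.append(flow, 0.0))                            # ~[0.99502 0.97875 0.87749 0.36050 0.00990 0.]
print((flow * num_train_timesteps).astype(np.int64))   # ~[995 978 877 360 9]
```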