mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-23 12:54:48 +08:00
Compare commits
6 Commits
naykun-mai
...
remove-unn
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
10dfa9b722 | ||
|
|
262ce19bff | ||
|
|
f7753b1bc8 | ||
|
|
b5309683cb | ||
|
|
55463f7ace | ||
|
|
f9c1e612fb |
@@ -70,6 +70,12 @@ output.save("output.png")
|
|||||||
- all
|
- all
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
|
## Cosmos2_5_PredictBasePipeline
|
||||||
|
|
||||||
|
[[autodoc]] Cosmos2_5_PredictBasePipeline
|
||||||
|
- all
|
||||||
|
- __call__
|
||||||
|
|
||||||
## CosmosPipelineOutput
|
## CosmosPipelineOutput
|
||||||
|
|
||||||
[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput
|
[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput
|
||||||
|
|||||||
@@ -1,11 +1,55 @@
|
|||||||
|
"""
|
||||||
|
# Cosmos 2 Predict
|
||||||
|
|
||||||
|
Download checkpoint
|
||||||
|
```bash
|
||||||
|
hf download nvidia/Cosmos-Predict2-2B-Text2Image
|
||||||
|
```
|
||||||
|
|
||||||
|
convert checkpoint
|
||||||
|
```bash
|
||||||
|
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2-2B-Text2Image/snapshots/acdb5fde992a73ef0355f287977d002cbfd127e0/model.pt
|
||||||
|
|
||||||
|
python scripts/convert_cosmos_to_diffusers.py \
|
||||||
|
--transformer_ckpt_path $transformer_ckpt_path \
|
||||||
|
--transformer_type Cosmos-2.0-Diffusion-2B-Text2Image \
|
||||||
|
--text_encoder_path google-t5/t5-11b \
|
||||||
|
--tokenizer_path google-t5/t5-11b \
|
||||||
|
--vae_type wan2.1 \
|
||||||
|
--output_path converted/cosmos-p2-t2i-2b \
|
||||||
|
--save_pipeline
|
||||||
|
```
|
||||||
|
|
||||||
|
# Cosmos 2.5 Predict
|
||||||
|
|
||||||
|
Download checkpoint
|
||||||
|
```bash
|
||||||
|
hf download nvidia/Cosmos-Predict2.5-2B
|
||||||
|
```
|
||||||
|
|
||||||
|
Convert checkpoint
|
||||||
|
```bash
|
||||||
|
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/pre-trained/d20b7120-df3e-4911-919d-db6e08bad31c_ema_bf16.pt
|
||||||
|
|
||||||
|
python scripts/convert_cosmos_to_diffusers.py \
|
||||||
|
--transformer_type Cosmos-2.5-Predict-Base-2B \
|
||||||
|
--transformer_ckpt_path $transformer_ckpt_path \
|
||||||
|
--vae_type wan2.1 \
|
||||||
|
--output_path converted/cosmos-p2.5-base-2b \
|
||||||
|
--save_pipeline
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import pathlib
|
import pathlib
|
||||||
|
import sys
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate import init_empty_weights
|
from accelerate import init_empty_weights
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from transformers import T5EncoderModel, T5TokenizerFast
|
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration, T5EncoderModel, T5TokenizerFast
|
||||||
|
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
AutoencoderKLCosmos,
|
AutoencoderKLCosmos,
|
||||||
@@ -17,7 +61,9 @@ from diffusers import (
|
|||||||
CosmosVideoToWorldPipeline,
|
CosmosVideoToWorldPipeline,
|
||||||
EDMEulerScheduler,
|
EDMEulerScheduler,
|
||||||
FlowMatchEulerDiscreteScheduler,
|
FlowMatchEulerDiscreteScheduler,
|
||||||
|
UniPCMultistepScheduler,
|
||||||
)
|
)
|
||||||
|
from diffusers.pipelines.cosmos.pipeline_cosmos2_5_predict import Cosmos2_5_PredictBasePipeline
|
||||||
|
|
||||||
|
|
||||||
def remove_keys_(key: str, state_dict: Dict[str, Any]):
|
def remove_keys_(key: str, state_dict: Dict[str, Any]):
|
||||||
@@ -233,6 +279,25 @@ TRANSFORMER_CONFIGS = {
|
|||||||
"concat_padding_mask": True,
|
"concat_padding_mask": True,
|
||||||
"extra_pos_embed_type": None,
|
"extra_pos_embed_type": None,
|
||||||
},
|
},
|
||||||
|
"Cosmos-2.5-Predict-Base-2B": {
|
||||||
|
"in_channels": 16 + 1,
|
||||||
|
"out_channels": 16,
|
||||||
|
"num_attention_heads": 16,
|
||||||
|
"attention_head_dim": 128,
|
||||||
|
"num_layers": 28,
|
||||||
|
"mlp_ratio": 4.0,
|
||||||
|
"text_embed_dim": 1024,
|
||||||
|
"adaln_lora_dim": 256,
|
||||||
|
"max_size": (128, 240, 240),
|
||||||
|
"patch_size": (1, 2, 2),
|
||||||
|
"rope_scale": (1.0, 3.0, 3.0),
|
||||||
|
"concat_padding_mask": True,
|
||||||
|
# NOTE: source config has pos_emb_learnable: 'True' - but params are missing
|
||||||
|
"extra_pos_embed_type": None,
|
||||||
|
"use_crossattn_projection": True,
|
||||||
|
"crossattn_proj_in_channels": 100352,
|
||||||
|
"encoder_hidden_states_channels": 1024,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
VAE_KEYS_RENAME_DICT = {
|
VAE_KEYS_RENAME_DICT = {
|
||||||
@@ -334,6 +399,9 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
|
|||||||
elif "Cosmos-2.0" in transformer_type:
|
elif "Cosmos-2.0" in transformer_type:
|
||||||
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
|
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
|
||||||
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
|
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
|
||||||
|
elif "Cosmos-2.5" in transformer_type:
|
||||||
|
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
|
||||||
|
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
|
||||||
else:
|
else:
|
||||||
assert False
|
assert False
|
||||||
|
|
||||||
@@ -347,6 +415,7 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
|
|||||||
new_key = new_key.removeprefix(PREFIX_KEY)
|
new_key = new_key.removeprefix(PREFIX_KEY)
|
||||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||||
new_key = new_key.replace(replace_key, rename_key)
|
new_key = new_key.replace(replace_key, rename_key)
|
||||||
|
print(key, "->", new_key, flush=True)
|
||||||
update_state_dict_(original_state_dict, key, new_key)
|
update_state_dict_(original_state_dict, key, new_key)
|
||||||
|
|
||||||
for key in list(original_state_dict.keys()):
|
for key in list(original_state_dict.keys()):
|
||||||
@@ -355,6 +424,21 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
|
|||||||
continue
|
continue
|
||||||
handler_fn_inplace(key, original_state_dict)
|
handler_fn_inplace(key, original_state_dict)
|
||||||
|
|
||||||
|
expected_keys = set(transformer.state_dict().keys())
|
||||||
|
mapped_keys = set(original_state_dict.keys())
|
||||||
|
missing_keys = expected_keys - mapped_keys
|
||||||
|
unexpected_keys = mapped_keys - expected_keys
|
||||||
|
if missing_keys:
|
||||||
|
print(f"ERROR: missing keys ({len(missing_keys)} from state_dict:", flush=True, file=sys.stderr)
|
||||||
|
for k in missing_keys:
|
||||||
|
print(k)
|
||||||
|
sys.exit(1)
|
||||||
|
if unexpected_keys:
|
||||||
|
print(f"ERROR: unexpected keys ({len(unexpected_keys)}) from state_dict:", flush=True, file=sys.stderr)
|
||||||
|
for k in unexpected_keys:
|
||||||
|
print(k)
|
||||||
|
sys.exit(2)
|
||||||
|
|
||||||
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
|
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||||
return transformer
|
return transformer
|
||||||
|
|
||||||
@@ -444,6 +528,34 @@ def save_pipeline_cosmos_2_0(args, transformer, vae):
|
|||||||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||||
|
|
||||||
|
|
||||||
|
def save_pipeline_cosmos2_5(args, transformer, vae):
|
||||||
|
text_encoder_path = args.text_encoder_path or "nvidia/Cosmos-Reason1-7B"
|
||||||
|
tokenizer_path = args.tokenizer_path or "Qwen/Qwen2.5-VL-7B-Instruct"
|
||||||
|
|
||||||
|
text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||||
|
text_encoder_path, torch_dtype="auto", device_map="cpu"
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||||
|
|
||||||
|
scheduler = UniPCMultistepScheduler(
|
||||||
|
use_karras_sigmas=True,
|
||||||
|
use_flow_sigmas=True,
|
||||||
|
prediction_type="flow_prediction",
|
||||||
|
sigma_max=200.0,
|
||||||
|
sigma_min=0.01,
|
||||||
|
)
|
||||||
|
|
||||||
|
pipe = Cosmos2_5_PredictBasePipeline(
|
||||||
|
text_encoder=text_encoder,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
transformer=transformer,
|
||||||
|
vae=vae,
|
||||||
|
scheduler=scheduler,
|
||||||
|
safety_checker=lambda *args, **kwargs: None,
|
||||||
|
)
|
||||||
|
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
|
parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
|
||||||
@@ -451,10 +563,10 @@ def get_args():
|
|||||||
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
|
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vae_type", type=str, default=None, choices=["none", *list(VAE_CONFIGS.keys())], help="Type of VAE"
|
"--vae_type", type=str, default="wan2.1", choices=["wan2.1", *list(VAE_CONFIGS.keys())], help="Type of VAE"
|
||||||
)
|
)
|
||||||
parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
|
parser.add_argument("--text_encoder_path", type=str, default=None)
|
||||||
parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
|
parser.add_argument("--tokenizer_path", type=str, default=None)
|
||||||
parser.add_argument("--save_pipeline", action="store_true")
|
parser.add_argument("--save_pipeline", action="store_true")
|
||||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||||
parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
|
parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
|
||||||
@@ -477,8 +589,6 @@ if __name__ == "__main__":
|
|||||||
if args.save_pipeline:
|
if args.save_pipeline:
|
||||||
assert args.transformer_ckpt_path is not None
|
assert args.transformer_ckpt_path is not None
|
||||||
assert args.vae_type is not None
|
assert args.vae_type is not None
|
||||||
assert args.text_encoder_path is not None
|
|
||||||
assert args.tokenizer_path is not None
|
|
||||||
|
|
||||||
if args.transformer_ckpt_path is not None:
|
if args.transformer_ckpt_path is not None:
|
||||||
weights_only = "Cosmos-1.0" in args.transformer_type
|
weights_only = "Cosmos-1.0" in args.transformer_type
|
||||||
@@ -490,17 +600,26 @@ if __name__ == "__main__":
|
|||||||
if args.vae_type is not None:
|
if args.vae_type is not None:
|
||||||
if "Cosmos-1.0" in args.transformer_type:
|
if "Cosmos-1.0" in args.transformer_type:
|
||||||
vae = convert_vae(args.vae_type)
|
vae = convert_vae(args.vae_type)
|
||||||
else:
|
elif "Cosmos-2.0" in args.transformer_type or "Cosmos-2.5" in args.transformer_type:
|
||||||
vae = AutoencoderKLWan.from_pretrained(
|
vae = AutoencoderKLWan.from_pretrained(
|
||||||
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
|
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
raise AssertionError(f"{args.transformer_type} not supported")
|
||||||
|
|
||||||
if not args.save_pipeline:
|
if not args.save_pipeline:
|
||||||
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||||
|
|
||||||
if args.save_pipeline:
|
if args.save_pipeline:
|
||||||
if "Cosmos-1.0" in args.transformer_type:
|
if "Cosmos-1.0" in args.transformer_type:
|
||||||
|
assert args.text_encoder_path is not None
|
||||||
|
assert args.tokenizer_path is not None
|
||||||
save_pipeline_cosmos_1_0(args, transformer, vae)
|
save_pipeline_cosmos_1_0(args, transformer, vae)
|
||||||
elif "Cosmos-2.0" in args.transformer_type:
|
elif "Cosmos-2.0" in args.transformer_type:
|
||||||
|
assert args.text_encoder_path is not None
|
||||||
|
assert args.tokenizer_path is not None
|
||||||
save_pipeline_cosmos_2_0(args, transformer, vae)
|
save_pipeline_cosmos_2_0(args, transformer, vae)
|
||||||
|
elif "Cosmos-2.5" in args.transformer_type:
|
||||||
|
save_pipeline_cosmos2_5(args, transformer, vae)
|
||||||
else:
|
else:
|
||||||
assert False
|
raise AssertionError(f"{args.transformer_type} not supported")
|
||||||
|
|||||||
@@ -279,6 +279,7 @@ else:
|
|||||||
"WanAnimateTransformer3DModel",
|
"WanAnimateTransformer3DModel",
|
||||||
"WanTransformer3DModel",
|
"WanTransformer3DModel",
|
||||||
"WanVACETransformer3DModel",
|
"WanVACETransformer3DModel",
|
||||||
|
"ZImageControlNetModel",
|
||||||
"ZImageTransformer2DModel",
|
"ZImageTransformer2DModel",
|
||||||
"attention_backend",
|
"attention_backend",
|
||||||
]
|
]
|
||||||
@@ -462,6 +463,7 @@ else:
|
|||||||
"CogView4ControlPipeline",
|
"CogView4ControlPipeline",
|
||||||
"CogView4Pipeline",
|
"CogView4Pipeline",
|
||||||
"ConsisIDPipeline",
|
"ConsisIDPipeline",
|
||||||
|
"Cosmos2_5_PredictBasePipeline",
|
||||||
"Cosmos2TextToImagePipeline",
|
"Cosmos2TextToImagePipeline",
|
||||||
"Cosmos2VideoToWorldPipeline",
|
"Cosmos2VideoToWorldPipeline",
|
||||||
"CosmosTextToWorldPipeline",
|
"CosmosTextToWorldPipeline",
|
||||||
@@ -564,6 +566,7 @@ else:
|
|||||||
"QwenImageEditPlusPipeline",
|
"QwenImageEditPlusPipeline",
|
||||||
"QwenImageImg2ImgPipeline",
|
"QwenImageImg2ImgPipeline",
|
||||||
"QwenImageInpaintPipeline",
|
"QwenImageInpaintPipeline",
|
||||||
|
"QwenImageLayeredPipeline",
|
||||||
"QwenImagePipeline",
|
"QwenImagePipeline",
|
||||||
"ReduxImageEncoder",
|
"ReduxImageEncoder",
|
||||||
"SanaControlNetPipeline",
|
"SanaControlNetPipeline",
|
||||||
@@ -669,6 +672,8 @@ else:
|
|||||||
"WuerstchenCombinedPipeline",
|
"WuerstchenCombinedPipeline",
|
||||||
"WuerstchenDecoderPipeline",
|
"WuerstchenDecoderPipeline",
|
||||||
"WuerstchenPriorPipeline",
|
"WuerstchenPriorPipeline",
|
||||||
|
"ZImageControlNetInpaintPipeline",
|
||||||
|
"ZImageControlNetPipeline",
|
||||||
"ZImageImg2ImgPipeline",
|
"ZImageImg2ImgPipeline",
|
||||||
"ZImagePipeline",
|
"ZImagePipeline",
|
||||||
]
|
]
|
||||||
@@ -1016,6 +1021,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
WanAnimateTransformer3DModel,
|
WanAnimateTransformer3DModel,
|
||||||
WanTransformer3DModel,
|
WanTransformer3DModel,
|
||||||
WanVACETransformer3DModel,
|
WanVACETransformer3DModel,
|
||||||
|
ZImageControlNetModel,
|
||||||
ZImageTransformer2DModel,
|
ZImageTransformer2DModel,
|
||||||
attention_backend,
|
attention_backend,
|
||||||
)
|
)
|
||||||
@@ -1170,6 +1176,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
CogView4ControlPipeline,
|
CogView4ControlPipeline,
|
||||||
CogView4Pipeline,
|
CogView4Pipeline,
|
||||||
ConsisIDPipeline,
|
ConsisIDPipeline,
|
||||||
|
Cosmos2_5_PredictBasePipeline,
|
||||||
Cosmos2TextToImagePipeline,
|
Cosmos2TextToImagePipeline,
|
||||||
Cosmos2VideoToWorldPipeline,
|
Cosmos2VideoToWorldPipeline,
|
||||||
CosmosTextToWorldPipeline,
|
CosmosTextToWorldPipeline,
|
||||||
@@ -1272,6 +1279,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
QwenImageEditPlusPipeline,
|
QwenImageEditPlusPipeline,
|
||||||
QwenImageImg2ImgPipeline,
|
QwenImageImg2ImgPipeline,
|
||||||
QwenImageInpaintPipeline,
|
QwenImageInpaintPipeline,
|
||||||
|
QwenImageLayeredPipeline,
|
||||||
QwenImagePipeline,
|
QwenImagePipeline,
|
||||||
ReduxImageEncoder,
|
ReduxImageEncoder,
|
||||||
SanaControlNetPipeline,
|
SanaControlNetPipeline,
|
||||||
@@ -1375,6 +1383,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
WuerstchenCombinedPipeline,
|
WuerstchenCombinedPipeline,
|
||||||
WuerstchenDecoderPipeline,
|
WuerstchenDecoderPipeline,
|
||||||
WuerstchenPriorPipeline,
|
WuerstchenPriorPipeline,
|
||||||
|
ZImageControlNetInpaintPipeline,
|
||||||
|
ZImageControlNetPipeline,
|
||||||
ZImageImg2ImgPipeline,
|
ZImageImg2ImgPipeline,
|
||||||
ZImagePipeline,
|
ZImagePipeline,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ if is_torch_available():
|
|||||||
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
|
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
|
||||||
from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
|
from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
|
||||||
from .guider_utils import BaseGuidance
|
from .guider_utils import BaseGuidance
|
||||||
|
from .magnitude_aware_guidance import MagnitudeAwareGuidance
|
||||||
from .perturbed_attention_guidance import PerturbedAttentionGuidance
|
from .perturbed_attention_guidance import PerturbedAttentionGuidance
|
||||||
from .skip_layer_guidance import SkipLayerGuidance
|
from .skip_layer_guidance import SkipLayerGuidance
|
||||||
from .smoothed_energy_guidance import SmoothedEnergyGuidance
|
from .smoothed_energy_guidance import SmoothedEnergyGuidance
|
||||||
|
|||||||
159
src/diffusers/guiders/magnitude_aware_guidance.py
Normal file
159
src/diffusers/guiders/magnitude_aware_guidance.py
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ..configuration_utils import register_to_config
|
||||||
|
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..modular_pipelines.modular_pipeline import BlockState
|
||||||
|
|
||||||
|
|
||||||
|
class MagnitudeAwareGuidance(BaseGuidance):
|
||||||
|
"""
|
||||||
|
Magnitude-Aware Mitigation for Boosted Guidance (MAMBO-G): https://huggingface.co/papers/2508.03442
|
||||||
|
|
||||||
|
Args:
|
||||||
|
guidance_scale (`float`, defaults to `10.0`):
|
||||||
|
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||||
|
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||||
|
deterioration of image quality.
|
||||||
|
alpha (`float`, defaults to `8.0`):
|
||||||
|
The alpha parameter for the magnitude-aware guidance. Higher values cause more aggressive supression of
|
||||||
|
guidance scale when the magnitude of the guidance update is large.
|
||||||
|
guidance_rescale (`float`, defaults to `0.0`):
|
||||||
|
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||||
|
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||||
|
Flawed](https://huggingface.co/papers/2305.08891).
|
||||||
|
use_original_formulation (`bool`, defaults to `False`):
|
||||||
|
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||||
|
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||||
|
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||||
|
start (`float`, defaults to `0.0`):
|
||||||
|
The fraction of the total number of denoising steps after which guidance starts.
|
||||||
|
stop (`float`, defaults to `1.0`):
|
||||||
|
The fraction of the total number of denoising steps after which guidance stops.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_input_predictions = ["pred_cond", "pred_uncond"]
|
||||||
|
|
||||||
|
@register_to_config
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
guidance_scale: float = 10.0,
|
||||||
|
alpha: float = 8.0,
|
||||||
|
guidance_rescale: float = 0.0,
|
||||||
|
use_original_formulation: bool = False,
|
||||||
|
start: float = 0.0,
|
||||||
|
stop: float = 1.0,
|
||||||
|
enabled: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__(start, stop, enabled)
|
||||||
|
|
||||||
|
self.guidance_scale = guidance_scale
|
||||||
|
self.alpha = alpha
|
||||||
|
self.guidance_rescale = guidance_rescale
|
||||||
|
self.use_original_formulation = use_original_formulation
|
||||||
|
|
||||||
|
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||||
|
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||||
|
data_batches = []
|
||||||
|
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||||
|
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||||
|
data_batches.append(data_batch)
|
||||||
|
return data_batches
|
||||||
|
|
||||||
|
def prepare_inputs_from_block_state(
|
||||||
|
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||||
|
) -> List["BlockState"]:
|
||||||
|
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||||
|
data_batches = []
|
||||||
|
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||||
|
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||||
|
data_batches.append(data_batch)
|
||||||
|
return data_batches
|
||||||
|
|
||||||
|
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
|
||||||
|
pred = None
|
||||||
|
|
||||||
|
if not self._is_mambo_g_enabled():
|
||||||
|
pred = pred_cond
|
||||||
|
else:
|
||||||
|
pred = mambo_guidance(
|
||||||
|
pred_cond,
|
||||||
|
pred_uncond,
|
||||||
|
self.guidance_scale,
|
||||||
|
self.alpha,
|
||||||
|
self.use_original_formulation,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.guidance_rescale > 0.0:
|
||||||
|
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||||
|
|
||||||
|
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_conditional(self) -> bool:
|
||||||
|
return self._count_prepared == 1
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_conditions(self) -> int:
|
||||||
|
num_conditions = 1
|
||||||
|
if self._is_mambo_g_enabled():
|
||||||
|
num_conditions += 1
|
||||||
|
return num_conditions
|
||||||
|
|
||||||
|
def _is_mambo_g_enabled(self) -> bool:
|
||||||
|
if not self._enabled:
|
||||||
|
return False
|
||||||
|
|
||||||
|
is_within_range = True
|
||||||
|
if self._num_inference_steps is not None:
|
||||||
|
skip_start_step = int(self._start * self._num_inference_steps)
|
||||||
|
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||||
|
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||||
|
|
||||||
|
is_close = False
|
||||||
|
if self.use_original_formulation:
|
||||||
|
is_close = math.isclose(self.guidance_scale, 0.0)
|
||||||
|
else:
|
||||||
|
is_close = math.isclose(self.guidance_scale, 1.0)
|
||||||
|
|
||||||
|
return is_within_range and not is_close
|
||||||
|
|
||||||
|
|
||||||
|
def mambo_guidance(
|
||||||
|
pred_cond: torch.Tensor,
|
||||||
|
pred_uncond: torch.Tensor,
|
||||||
|
guidance_scale: float,
|
||||||
|
alpha: float = 8.0,
|
||||||
|
use_original_formulation: bool = False,
|
||||||
|
):
|
||||||
|
dim = list(range(1, len(pred_cond.shape)))
|
||||||
|
diff = pred_cond - pred_uncond
|
||||||
|
ratio = torch.norm(diff, dim=dim, keepdim=True) / torch.norm(pred_uncond, dim=dim, keepdim=True)
|
||||||
|
guidance_scale_final = (
|
||||||
|
guidance_scale * torch.exp(-alpha * ratio)
|
||||||
|
if use_original_formulation
|
||||||
|
else 1.0 + (guidance_scale - 1.0) * torch.exp(-alpha * ratio)
|
||||||
|
)
|
||||||
|
pred = pred_cond if use_original_formulation else pred_uncond
|
||||||
|
pred = pred + guidance_scale_final * diff
|
||||||
|
|
||||||
|
return pred
|
||||||
@@ -49,6 +49,7 @@ from .single_file_utils import (
|
|||||||
convert_stable_cascade_unet_single_file_to_diffusers,
|
convert_stable_cascade_unet_single_file_to_diffusers,
|
||||||
convert_wan_transformer_to_diffusers,
|
convert_wan_transformer_to_diffusers,
|
||||||
convert_wan_vae_to_diffusers,
|
convert_wan_vae_to_diffusers,
|
||||||
|
convert_z_image_controlnet_checkpoint_to_diffusers,
|
||||||
convert_z_image_transformer_checkpoint_to_diffusers,
|
convert_z_image_transformer_checkpoint_to_diffusers,
|
||||||
create_controlnet_diffusers_config_from_ldm,
|
create_controlnet_diffusers_config_from_ldm,
|
||||||
create_unet_diffusers_config_from_ldm,
|
create_unet_diffusers_config_from_ldm,
|
||||||
@@ -172,11 +173,18 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
|||||||
"checkpoint_mapping_fn": convert_z_image_transformer_checkpoint_to_diffusers,
|
"checkpoint_mapping_fn": convert_z_image_transformer_checkpoint_to_diffusers,
|
||||||
"default_subfolder": "transformer",
|
"default_subfolder": "transformer",
|
||||||
},
|
},
|
||||||
|
"ZImageControlNetModel": {
|
||||||
|
"checkpoint_mapping_fn": convert_z_image_controlnet_checkpoint_to_diffusers,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
|
def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
|
||||||
return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))
|
model_state_dict_keys = set(model_state_dict.keys())
|
||||||
|
checkpoint_state_dict_keys = set(checkpoint_state_dict.keys())
|
||||||
|
is_subset = model_state_dict_keys.issubset(checkpoint_state_dict_keys)
|
||||||
|
is_match = model_state_dict_keys == checkpoint_state_dict_keys
|
||||||
|
return not (is_subset and is_match)
|
||||||
|
|
||||||
|
|
||||||
def _get_single_file_loadable_mapping_class(cls):
|
def _get_single_file_loadable_mapping_class(cls):
|
||||||
|
|||||||
@@ -121,6 +121,8 @@ CHECKPOINT_KEY_NAMES = {
|
|||||||
"instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
|
"instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
|
||||||
"lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
|
"lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
|
||||||
"z-image-turbo": "cap_embedder.0.weight",
|
"z-image-turbo": "cap_embedder.0.weight",
|
||||||
|
"z-image-turbo-controlnet": "control_all_x_embedder.2-1.weight",
|
||||||
|
"z-image-turbo-controlnet-2.x": "control_layers.14.adaLN_modulation.0.weight",
|
||||||
"sana": [
|
"sana": [
|
||||||
"blocks.0.cross_attn.q_linear.weight",
|
"blocks.0.cross_attn.q_linear.weight",
|
||||||
"blocks.0.cross_attn.q_linear.bias",
|
"blocks.0.cross_attn.q_linear.bias",
|
||||||
@@ -220,6 +222,8 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
|||||||
"cosmos-2.0-v2w-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Video2World"},
|
"cosmos-2.0-v2w-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Video2World"},
|
||||||
"cosmos-2.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Video2World"},
|
"cosmos-2.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Video2World"},
|
||||||
"z-image-turbo": {"pretrained_model_name_or_path": "Tongyi-MAI/Z-Image-Turbo"},
|
"z-image-turbo": {"pretrained_model_name_or_path": "Tongyi-MAI/Z-Image-Turbo"},
|
||||||
|
"z-image-turbo-controlnet": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union"},
|
||||||
|
"z-image-turbo-controlnet-2.x": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.1"},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Use to configure model sample size when original config is provided
|
# Use to configure model sample size when original config is provided
|
||||||
@@ -779,6 +783,12 @@ def infer_diffusers_model_type(checkpoint):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 2.0 model.")
|
raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 2.0 model.")
|
||||||
|
|
||||||
|
elif CHECKPOINT_KEY_NAMES["z-image-turbo-controlnet-2.x"] in checkpoint:
|
||||||
|
model_type = "z-image-turbo-controlnet-2.x"
|
||||||
|
|
||||||
|
elif CHECKPOINT_KEY_NAMES["z-image-turbo-controlnet"] in checkpoint:
|
||||||
|
model_type = "z-image-turbo-controlnet"
|
||||||
|
|
||||||
else:
|
else:
|
||||||
model_type = "v1"
|
model_type = "v1"
|
||||||
|
|
||||||
@@ -3885,3 +3895,17 @@ def convert_z_image_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
|||||||
handler_fn_inplace(key, converted_state_dict)
|
handler_fn_inplace(key, converted_state_dict)
|
||||||
|
|
||||||
return converted_state_dict
|
return converted_state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def convert_z_image_controlnet_checkpoint_to_diffusers(checkpoint, config, **kwargs):
|
||||||
|
if config["add_control_noise_refiner"] is None:
|
||||||
|
return checkpoint
|
||||||
|
elif config["add_control_noise_refiner"] == "control_noise_refiner":
|
||||||
|
return checkpoint
|
||||||
|
elif config["add_control_noise_refiner"] == "control_layers":
|
||||||
|
converted_state_dict = {
|
||||||
|
key: checkpoint.pop(key) for key in list(checkpoint.keys()) if not key.startswith("control_noise_refiner.")
|
||||||
|
}
|
||||||
|
return converted_state_dict
|
||||||
|
else:
|
||||||
|
raise ValueError("Unknown Z-Image Turbo ControlNet type.")
|
||||||
|
|||||||
@@ -66,6 +66,7 @@ if is_torch_available():
|
|||||||
_import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"]
|
_import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"]
|
||||||
_import_structure["controlnets.controlnet_union"] = ["ControlNetUnionModel"]
|
_import_structure["controlnets.controlnet_union"] = ["ControlNetUnionModel"]
|
||||||
_import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
|
_import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
|
||||||
|
_import_structure["controlnets.controlnet_z_image"] = ["ZImageControlNetModel"]
|
||||||
_import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"]
|
_import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"]
|
||||||
_import_structure["controlnets.multicontrolnet_union"] = ["MultiControlNetUnionModel"]
|
_import_structure["controlnets.multicontrolnet_union"] = ["MultiControlNetUnionModel"]
|
||||||
_import_structure["embeddings"] = ["ImageProjection"]
|
_import_structure["embeddings"] = ["ImageProjection"]
|
||||||
@@ -181,6 +182,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
SD3MultiControlNetModel,
|
SD3MultiControlNetModel,
|
||||||
SparseControlNetModel,
|
SparseControlNetModel,
|
||||||
UNetControlNetXSModel,
|
UNetControlNetXSModel,
|
||||||
|
ZImageControlNetModel,
|
||||||
)
|
)
|
||||||
from .embeddings import ImageProjection
|
from .embeddings import ImageProjection
|
||||||
from .modeling_utils import ModelMixin
|
from .modeling_utils import ModelMixin
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ from ...utils.accelerate_utils import apply_forward_hook
|
|||||||
from ..activations import get_activation
|
from ..activations import get_activation
|
||||||
from ..modeling_outputs import AutoencoderKLOutput
|
from ..modeling_outputs import AutoencoderKLOutput
|
||||||
from ..modeling_utils import ModelMixin
|
from ..modeling_utils import ModelMixin
|
||||||
from .vae import DecoderOutput, DiagonalGaussianDistribution
|
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
@@ -410,7 +410,7 @@ class HunyuanImageDecoder2D(nn.Module):
|
|||||||
return h
|
return h
|
||||||
|
|
||||||
|
|
||||||
class AutoencoderKLHunyuanImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
class AutoencoderKLHunyuanImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
|
||||||
r"""
|
r"""
|
||||||
A VAE model for 2D images with spatial tiling support.
|
A VAE model for 2D images with spatial tiling support.
|
||||||
|
|
||||||
@@ -486,27 +486,6 @@ class AutoencoderKLHunyuanImage(ModelMixin, ConfigMixin, FromOriginalModelMixin)
|
|||||||
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
|
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
|
||||||
self.tile_latent_min_size = self.tile_sample_min_size // self.config.spatial_compression_ratio
|
self.tile_latent_min_size = self.tile_sample_min_size // self.config.spatial_compression_ratio
|
||||||
|
|
||||||
def disable_tiling(self) -> None:
|
|
||||||
r"""
|
|
||||||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
|
||||||
decoding in one step.
|
|
||||||
"""
|
|
||||||
self.use_tiling = False
|
|
||||||
|
|
||||||
def enable_slicing(self) -> None:
|
|
||||||
r"""
|
|
||||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
|
||||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
|
||||||
"""
|
|
||||||
self.use_slicing = True
|
|
||||||
|
|
||||||
def disable_slicing(self) -> None:
|
|
||||||
r"""
|
|
||||||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
|
||||||
decoding in one step.
|
|
||||||
"""
|
|
||||||
self.use_slicing = False
|
|
||||||
|
|
||||||
def _encode(self, x: torch.Tensor):
|
def _encode(self, x: torch.Tensor):
|
||||||
|
|
||||||
batch_size, num_channels, height, width = x.shape
|
batch_size, num_channels, height, width = x.shape
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ from ...utils.accelerate_utils import apply_forward_hook
|
|||||||
from ..activations import get_activation
|
from ..activations import get_activation
|
||||||
from ..modeling_outputs import AutoencoderKLOutput
|
from ..modeling_outputs import AutoencoderKLOutput
|
||||||
from ..modeling_utils import ModelMixin
|
from ..modeling_utils import ModelMixin
|
||||||
from .vae import DecoderOutput, DiagonalGaussianDistribution
|
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
@@ -584,7 +584,7 @@ class HunyuanImageRefinerDecoder3D(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
|
class AutoencoderKLHunyuanImageRefiner(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||||
r"""
|
r"""
|
||||||
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
|
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
|
||||||
HunyuanImage-2.1 Refiner.
|
HunyuanImage-2.1 Refiner.
|
||||||
@@ -685,27 +685,6 @@ class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
|
|||||||
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
|
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
|
||||||
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
|
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
|
||||||
|
|
||||||
def disable_tiling(self) -> None:
|
|
||||||
r"""
|
|
||||||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
|
||||||
decoding in one step.
|
|
||||||
"""
|
|
||||||
self.use_tiling = False
|
|
||||||
|
|
||||||
def enable_slicing(self) -> None:
|
|
||||||
r"""
|
|
||||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
|
||||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
|
||||||
"""
|
|
||||||
self.use_slicing = True
|
|
||||||
|
|
||||||
def disable_slicing(self) -> None:
|
|
||||||
r"""
|
|
||||||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
|
||||||
decoding in one step.
|
|
||||||
"""
|
|
||||||
self.use_slicing = False
|
|
||||||
|
|
||||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
_, _, _, height, width = x.shape
|
_, _, _, height, width = x.shape
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ from ...utils.accelerate_utils import apply_forward_hook
|
|||||||
from ..activations import get_activation
|
from ..activations import get_activation
|
||||||
from ..modeling_outputs import AutoencoderKLOutput
|
from ..modeling_outputs import AutoencoderKLOutput
|
||||||
from ..modeling_utils import ModelMixin
|
from ..modeling_utils import ModelMixin
|
||||||
from .vae import DecoderOutput, DiagonalGaussianDistribution
|
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
@@ -625,7 +625,7 @@ class HunyuanVideo15Decoder3D(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class AutoencoderKLHunyuanVideo15(ModelMixin, ConfigMixin):
|
class AutoencoderKLHunyuanVideo15(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||||
r"""
|
r"""
|
||||||
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
|
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
|
||||||
HunyuanVideo-1.5.
|
HunyuanVideo-1.5.
|
||||||
@@ -723,27 +723,6 @@ class AutoencoderKLHunyuanVideo15(ModelMixin, ConfigMixin):
|
|||||||
self.tile_latent_min_width = tile_latent_min_width or self.tile_latent_min_width
|
self.tile_latent_min_width = tile_latent_min_width or self.tile_latent_min_width
|
||||||
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
|
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
|
||||||
|
|
||||||
def disable_tiling(self) -> None:
|
|
||||||
r"""
|
|
||||||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
|
||||||
decoding in one step.
|
|
||||||
"""
|
|
||||||
self.use_tiling = False
|
|
||||||
|
|
||||||
def enable_slicing(self) -> None:
|
|
||||||
r"""
|
|
||||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
|
||||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
|
||||||
"""
|
|
||||||
self.use_slicing = True
|
|
||||||
|
|
||||||
def disable_slicing(self) -> None:
|
|
||||||
r"""
|
|
||||||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
|
||||||
decoding in one step.
|
|
||||||
"""
|
|
||||||
self.use_slicing = False
|
|
||||||
|
|
||||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
_, _, _, height, width = x.shape
|
_, _, _, height, width = x.shape
|
||||||
|
|
||||||
|
|||||||
@@ -394,6 +394,7 @@ class QwenImageEncoder3d(nn.Module):
|
|||||||
attn_scales=[],
|
attn_scales=[],
|
||||||
temperal_downsample=[True, True, False],
|
temperal_downsample=[True, True, False],
|
||||||
dropout=0.0,
|
dropout=0.0,
|
||||||
|
input_channels=3,
|
||||||
non_linearity: str = "silu",
|
non_linearity: str = "silu",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -410,7 +411,7 @@ class QwenImageEncoder3d(nn.Module):
|
|||||||
scale = 1.0
|
scale = 1.0
|
||||||
|
|
||||||
# init block
|
# init block
|
||||||
self.conv_in = QwenImageCausalConv3d(3, dims[0], 3, padding=1)
|
self.conv_in = QwenImageCausalConv3d(input_channels, dims[0], 3, padding=1)
|
||||||
|
|
||||||
# downsample blocks
|
# downsample blocks
|
||||||
self.down_blocks = nn.ModuleList([])
|
self.down_blocks = nn.ModuleList([])
|
||||||
@@ -570,6 +571,7 @@ class QwenImageDecoder3d(nn.Module):
|
|||||||
attn_scales=[],
|
attn_scales=[],
|
||||||
temperal_upsample=[False, True, True],
|
temperal_upsample=[False, True, True],
|
||||||
dropout=0.0,
|
dropout=0.0,
|
||||||
|
input_channels=3,
|
||||||
non_linearity: str = "silu",
|
non_linearity: str = "silu",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -621,7 +623,7 @@ class QwenImageDecoder3d(nn.Module):
|
|||||||
|
|
||||||
# output blocks
|
# output blocks
|
||||||
self.norm_out = QwenImageRMS_norm(out_dim, images=False)
|
self.norm_out = QwenImageRMS_norm(out_dim, images=False)
|
||||||
self.conv_out = QwenImageCausalConv3d(out_dim, 3, 3, padding=1)
|
self.conv_out = QwenImageCausalConv3d(out_dim, input_channels, 3, padding=1)
|
||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
@@ -684,6 +686,7 @@ class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
|||||||
attn_scales: List[float] = [],
|
attn_scales: List[float] = [],
|
||||||
temperal_downsample: List[bool] = [False, True, True],
|
temperal_downsample: List[bool] = [False, True, True],
|
||||||
dropout: float = 0.0,
|
dropout: float = 0.0,
|
||||||
|
input_channels: int = 3,
|
||||||
latents_mean: List[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921],
|
latents_mean: List[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921],
|
||||||
latents_std: List[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160],
|
latents_std: List[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160],
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -695,13 +698,13 @@ class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
|||||||
self.temperal_upsample = temperal_downsample[::-1]
|
self.temperal_upsample = temperal_downsample[::-1]
|
||||||
|
|
||||||
self.encoder = QwenImageEncoder3d(
|
self.encoder = QwenImageEncoder3d(
|
||||||
base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
|
base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout, input_channels
|
||||||
)
|
)
|
||||||
self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1)
|
self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||||
self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1)
|
self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1)
|
||||||
|
|
||||||
self.decoder = QwenImageDecoder3d(
|
self.decoder = QwenImageDecoder3d(
|
||||||
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
|
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout, input_channels
|
||||||
)
|
)
|
||||||
|
|
||||||
self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
|
self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ if is_torch_available():
|
|||||||
)
|
)
|
||||||
from .controlnet_union import ControlNetUnionModel
|
from .controlnet_union import ControlNetUnionModel
|
||||||
from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel
|
from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel
|
||||||
|
from .controlnet_z_image import ZImageControlNetModel
|
||||||
from .multicontrolnet import MultiControlNetModel
|
from .multicontrolnet import MultiControlNetModel
|
||||||
from .multicontrolnet_union import MultiControlNetUnionModel
|
from .multicontrolnet_union import MultiControlNetUnionModel
|
||||||
|
|
||||||
|
|||||||
824
src/diffusers/models/controlnets/controlnet_z_image.py
Normal file
824
src/diffusers/models/controlnets/controlnet_z_image.py
Normal file
@@ -0,0 +1,824 @@
|
|||||||
|
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import List, Literal, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
|
from ...configuration_utils import ConfigMixin, register_to_config
|
||||||
|
from ...loaders import PeftAdapterMixin
|
||||||
|
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||||
|
from ...models.attention_processor import Attention
|
||||||
|
from ...models.normalization import RMSNorm
|
||||||
|
from ...utils.torch_utils import maybe_allow_in_graph
|
||||||
|
from ..attention_dispatch import dispatch_attention_fn
|
||||||
|
from ..controlnets.controlnet import zero_module
|
||||||
|
from ..modeling_utils import ModelMixin
|
||||||
|
|
||||||
|
|
||||||
|
ADALN_EMBED_DIM = 256
|
||||||
|
SEQ_MULTI_OF = 32
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from diffusers.models.transformers.transformer_z_image.TimestepEmbedder
|
||||||
|
class TimestepEmbedder(nn.Module):
|
||||||
|
def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
|
||||||
|
super().__init__()
|
||||||
|
if mid_size is None:
|
||||||
|
mid_size = out_size
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
nn.Linear(frequency_embedding_size, mid_size, bias=True),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(mid_size, out_size, bias=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.frequency_embedding_size = frequency_embedding_size
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def timestep_embedding(t, dim, max_period=10000):
|
||||||
|
with torch.amp.autocast("cuda", enabled=False):
|
||||||
|
half = dim // 2
|
||||||
|
freqs = torch.exp(
|
||||||
|
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
|
||||||
|
)
|
||||||
|
args = t[:, None].float() * freqs[None]
|
||||||
|
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||||
|
if dim % 2:
|
||||||
|
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
def forward(self, t):
|
||||||
|
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||||
|
weight_dtype = self.mlp[0].weight.dtype
|
||||||
|
compute_dtype = getattr(self.mlp[0], "compute_dtype", None)
|
||||||
|
if weight_dtype.is_floating_point:
|
||||||
|
t_freq = t_freq.to(weight_dtype)
|
||||||
|
elif compute_dtype is not None:
|
||||||
|
t_freq = t_freq.to(compute_dtype)
|
||||||
|
t_emb = self.mlp(t_freq)
|
||||||
|
return t_emb
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from diffusers.models.transformers.transformer_z_image.ZSingleStreamAttnProcessor
|
||||||
|
class ZSingleStreamAttnProcessor:
|
||||||
|
"""
|
||||||
|
Processor for Z-Image single stream attention that adapts the existing Attention class to match the behavior of the
|
||||||
|
original Z-ImageAttention module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_attention_backend = None
|
||||||
|
_parallel_config = None
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
if not hasattr(F, "scaled_dot_product_attention"):
|
||||||
|
raise ImportError(
|
||||||
|
"ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
attn: Attention,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
freqs_cis: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
query = attn.to_q(hidden_states)
|
||||||
|
key = attn.to_k(hidden_states)
|
||||||
|
value = attn.to_v(hidden_states)
|
||||||
|
|
||||||
|
query = query.unflatten(-1, (attn.heads, -1))
|
||||||
|
key = key.unflatten(-1, (attn.heads, -1))
|
||||||
|
value = value.unflatten(-1, (attn.heads, -1))
|
||||||
|
|
||||||
|
# Apply Norms
|
||||||
|
if attn.norm_q is not None:
|
||||||
|
query = attn.norm_q(query)
|
||||||
|
if attn.norm_k is not None:
|
||||||
|
key = attn.norm_k(key)
|
||||||
|
|
||||||
|
# Apply RoPE
|
||||||
|
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||||
|
with torch.amp.autocast("cuda", enabled=False):
|
||||||
|
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
|
||||||
|
freqs_cis = freqs_cis.unsqueeze(2)
|
||||||
|
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
|
||||||
|
return x_out.type_as(x_in) # todo
|
||||||
|
|
||||||
|
if freqs_cis is not None:
|
||||||
|
query = apply_rotary_emb(query, freqs_cis)
|
||||||
|
key = apply_rotary_emb(key, freqs_cis)
|
||||||
|
|
||||||
|
# Cast to correct dtype
|
||||||
|
dtype = query.dtype
|
||||||
|
query, key = query.to(dtype), key.to(dtype)
|
||||||
|
|
||||||
|
# From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
|
||||||
|
if attention_mask is not None and attention_mask.ndim == 2:
|
||||||
|
attention_mask = attention_mask[:, None, None, :]
|
||||||
|
|
||||||
|
# Compute joint attention
|
||||||
|
hidden_states = dispatch_attention_fn(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
attn_mask=attention_mask,
|
||||||
|
dropout_p=0.0,
|
||||||
|
is_causal=False,
|
||||||
|
backend=self._attention_backend,
|
||||||
|
parallel_config=self._parallel_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reshape back
|
||||||
|
hidden_states = hidden_states.flatten(2, 3)
|
||||||
|
hidden_states = hidden_states.to(dtype)
|
||||||
|
|
||||||
|
output = attn.to_out[0](hidden_states)
|
||||||
|
if len(attn.to_out) > 1: # dropout
|
||||||
|
output = attn.to_out[1](output)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from diffusers.models.transformers.transformer_z_image.FeedForward
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
def __init__(self, dim: int, hidden_dim: int):
|
||||||
|
super().__init__()
|
||||||
|
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||||
|
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||||
|
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||||
|
|
||||||
|
def _forward_silu_gating(self, x1, x3):
|
||||||
|
return F.silu(x1) * x3
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
|
||||||
|
|
||||||
|
|
||||||
|
@maybe_allow_in_graph
|
||||||
|
# Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformerBlock
|
||||||
|
class ZImageTransformerBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
layer_id: int,
|
||||||
|
dim: int,
|
||||||
|
n_heads: int,
|
||||||
|
n_kv_heads: int,
|
||||||
|
norm_eps: float,
|
||||||
|
qk_norm: bool,
|
||||||
|
modulation=True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.head_dim = dim // n_heads
|
||||||
|
|
||||||
|
# Refactored to use diffusers Attention with custom processor
|
||||||
|
# Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm
|
||||||
|
self.attention = Attention(
|
||||||
|
query_dim=dim,
|
||||||
|
cross_attention_dim=None,
|
||||||
|
dim_head=dim // n_heads,
|
||||||
|
heads=n_heads,
|
||||||
|
qk_norm="rms_norm" if qk_norm else None,
|
||||||
|
eps=1e-5,
|
||||||
|
bias=False,
|
||||||
|
out_bias=False,
|
||||||
|
processor=ZSingleStreamAttnProcessor(),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
|
||||||
|
self.layer_id = layer_id
|
||||||
|
|
||||||
|
self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
|
||||||
|
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
|
||||||
|
|
||||||
|
self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
|
||||||
|
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
|
||||||
|
|
||||||
|
self.modulation = modulation
|
||||||
|
if modulation:
|
||||||
|
self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True))
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
attn_mask: torch.Tensor,
|
||||||
|
freqs_cis: torch.Tensor,
|
||||||
|
adaln_input: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
if self.modulation:
|
||||||
|
assert adaln_input is not None
|
||||||
|
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
|
||||||
|
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
|
||||||
|
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
|
||||||
|
|
||||||
|
# Attention block
|
||||||
|
attn_out = self.attention(
|
||||||
|
self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis
|
||||||
|
)
|
||||||
|
x = x + gate_msa * self.attention_norm2(attn_out)
|
||||||
|
|
||||||
|
# FFN block
|
||||||
|
x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp))
|
||||||
|
else:
|
||||||
|
# Attention block
|
||||||
|
attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis)
|
||||||
|
x = x + self.attention_norm2(attn_out)
|
||||||
|
|
||||||
|
# FFN block
|
||||||
|
x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from diffusers.models.transformers.transformer_z_image.RopeEmbedder
|
||||||
|
class RopeEmbedder:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
theta: float = 256.0,
|
||||||
|
axes_dims: List[int] = (16, 56, 56),
|
||||||
|
axes_lens: List[int] = (64, 128, 128),
|
||||||
|
):
|
||||||
|
self.theta = theta
|
||||||
|
self.axes_dims = axes_dims
|
||||||
|
self.axes_lens = axes_lens
|
||||||
|
assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length"
|
||||||
|
self.freqs_cis = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0):
|
||||||
|
with torch.device("cpu"):
|
||||||
|
freqs_cis = []
|
||||||
|
for i, (d, e) in enumerate(zip(dim, end)):
|
||||||
|
freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
|
||||||
|
timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
|
||||||
|
freqs = torch.outer(timestep, freqs).float()
|
||||||
|
freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64
|
||||||
|
freqs_cis.append(freqs_cis_i)
|
||||||
|
|
||||||
|
return freqs_cis
|
||||||
|
|
||||||
|
def __call__(self, ids: torch.Tensor):
|
||||||
|
assert ids.ndim == 2
|
||||||
|
assert ids.shape[-1] == len(self.axes_dims)
|
||||||
|
device = ids.device
|
||||||
|
|
||||||
|
if self.freqs_cis is None:
|
||||||
|
self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
|
||||||
|
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
|
||||||
|
else:
|
||||||
|
# Ensure freqs_cis are on the same device as ids
|
||||||
|
if self.freqs_cis[0].device != device:
|
||||||
|
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for i in range(len(self.axes_dims)):
|
||||||
|
index = ids[:, i]
|
||||||
|
result.append(self.freqs_cis[i][index])
|
||||||
|
return torch.cat(result, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
@maybe_allow_in_graph
|
||||||
|
class ZImageControlTransformerBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
layer_id: int,
|
||||||
|
dim: int,
|
||||||
|
n_heads: int,
|
||||||
|
n_kv_heads: int,
|
||||||
|
norm_eps: float,
|
||||||
|
qk_norm: bool,
|
||||||
|
modulation=True,
|
||||||
|
block_id=0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.head_dim = dim // n_heads
|
||||||
|
|
||||||
|
# Refactored to use diffusers Attention with custom processor
|
||||||
|
# Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm
|
||||||
|
self.attention = Attention(
|
||||||
|
query_dim=dim,
|
||||||
|
cross_attention_dim=None,
|
||||||
|
dim_head=dim // n_heads,
|
||||||
|
heads=n_heads,
|
||||||
|
qk_norm="rms_norm" if qk_norm else None,
|
||||||
|
eps=1e-5,
|
||||||
|
bias=False,
|
||||||
|
out_bias=False,
|
||||||
|
processor=ZSingleStreamAttnProcessor(),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
|
||||||
|
self.layer_id = layer_id
|
||||||
|
|
||||||
|
self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
|
||||||
|
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
|
||||||
|
|
||||||
|
self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
|
||||||
|
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
|
||||||
|
|
||||||
|
self.modulation = modulation
|
||||||
|
if modulation:
|
||||||
|
self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True))
|
||||||
|
|
||||||
|
# Control variant start
|
||||||
|
self.block_id = block_id
|
||||||
|
if block_id == 0:
|
||||||
|
self.before_proj = zero_module(nn.Linear(self.dim, self.dim))
|
||||||
|
self.after_proj = zero_module(nn.Linear(self.dim, self.dim))
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
c: torch.Tensor,
|
||||||
|
x: torch.Tensor,
|
||||||
|
attn_mask: torch.Tensor,
|
||||||
|
freqs_cis: torch.Tensor,
|
||||||
|
adaln_input: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
# Control
|
||||||
|
if self.block_id == 0:
|
||||||
|
c = self.before_proj(c) + x
|
||||||
|
all_c = []
|
||||||
|
else:
|
||||||
|
all_c = list(torch.unbind(c))
|
||||||
|
c = all_c.pop(-1)
|
||||||
|
|
||||||
|
# Compared to `ZImageTransformerBlock` x -> c
|
||||||
|
if self.modulation:
|
||||||
|
assert adaln_input is not None
|
||||||
|
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
|
||||||
|
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
|
||||||
|
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
|
||||||
|
|
||||||
|
# Attention block
|
||||||
|
attn_out = self.attention(
|
||||||
|
self.attention_norm1(c) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis
|
||||||
|
)
|
||||||
|
c = c + gate_msa * self.attention_norm2(attn_out)
|
||||||
|
|
||||||
|
# FFN block
|
||||||
|
c = c + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(c) * scale_mlp))
|
||||||
|
else:
|
||||||
|
# Attention block
|
||||||
|
attn_out = self.attention(self.attention_norm1(c), attention_mask=attn_mask, freqs_cis=freqs_cis)
|
||||||
|
c = c + self.attention_norm2(attn_out)
|
||||||
|
|
||||||
|
# FFN block
|
||||||
|
c = c + self.ffn_norm2(self.feed_forward(self.ffn_norm1(c)))
|
||||||
|
|
||||||
|
# Control
|
||||||
|
c_skip = self.after_proj(c)
|
||||||
|
all_c += [c_skip, c]
|
||||||
|
c = torch.stack(all_c)
|
||||||
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||||
|
_supports_gradient_checkpointing = True
|
||||||
|
|
||||||
|
@register_to_config
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
control_layers_places: List[int] = None,
|
||||||
|
control_refiner_layers_places: List[int] = None,
|
||||||
|
control_in_dim=None,
|
||||||
|
add_control_noise_refiner: Optional[Literal["control_layers", "control_noise_refiner"]] = None,
|
||||||
|
all_patch_size=(2,),
|
||||||
|
all_f_patch_size=(1,),
|
||||||
|
dim=3840,
|
||||||
|
n_refiner_layers=2,
|
||||||
|
n_heads=30,
|
||||||
|
n_kv_heads=30,
|
||||||
|
norm_eps=1e-5,
|
||||||
|
qk_norm=True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.control_layers_places = control_layers_places
|
||||||
|
self.control_in_dim = control_in_dim
|
||||||
|
self.control_refiner_layers_places = control_refiner_layers_places
|
||||||
|
self.add_control_noise_refiner = add_control_noise_refiner
|
||||||
|
|
||||||
|
assert 0 in self.control_layers_places
|
||||||
|
|
||||||
|
# control blocks
|
||||||
|
self.control_layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
ZImageControlTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=i)
|
||||||
|
for i in self.control_layers_places
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# control patch embeddings
|
||||||
|
all_x_embedder = {}
|
||||||
|
for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)):
|
||||||
|
x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True)
|
||||||
|
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
|
||||||
|
|
||||||
|
self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
|
||||||
|
if self.add_control_noise_refiner == "control_layers":
|
||||||
|
self.control_noise_refiner = None
|
||||||
|
elif self.add_control_noise_refiner == "control_noise_refiner":
|
||||||
|
self.control_noise_refiner = nn.ModuleList(
|
||||||
|
[
|
||||||
|
ZImageControlTransformerBlock(
|
||||||
|
1000 + layer_id,
|
||||||
|
dim,
|
||||||
|
n_heads,
|
||||||
|
n_kv_heads,
|
||||||
|
norm_eps,
|
||||||
|
qk_norm,
|
||||||
|
modulation=True,
|
||||||
|
block_id=layer_id,
|
||||||
|
)
|
||||||
|
for layer_id in range(n_refiner_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.control_noise_refiner = nn.ModuleList(
|
||||||
|
[
|
||||||
|
ZImageTransformerBlock(
|
||||||
|
1000 + layer_id,
|
||||||
|
dim,
|
||||||
|
n_heads,
|
||||||
|
n_kv_heads,
|
||||||
|
norm_eps,
|
||||||
|
qk_norm,
|
||||||
|
modulation=True,
|
||||||
|
)
|
||||||
|
for layer_id in range(n_refiner_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.t_scale: Optional[float] = None
|
||||||
|
self.t_embedder: Optional[TimestepEmbedder] = None
|
||||||
|
self.all_x_embedder: Optional[nn.ModuleDict] = None
|
||||||
|
self.cap_embedder: Optional[nn.Sequential] = None
|
||||||
|
self.rope_embedder: Optional[RopeEmbedder] = None
|
||||||
|
self.noise_refiner: Optional[nn.ModuleList] = None
|
||||||
|
self.context_refiner: Optional[nn.ModuleList] = None
|
||||||
|
self.x_pad_token: Optional[nn.Parameter] = None
|
||||||
|
self.cap_pad_token: Optional[nn.Parameter] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_transformer(cls, controlnet, transformer):
|
||||||
|
controlnet.t_scale = transformer.t_scale
|
||||||
|
controlnet.t_embedder = transformer.t_embedder
|
||||||
|
controlnet.all_x_embedder = transformer.all_x_embedder
|
||||||
|
controlnet.cap_embedder = transformer.cap_embedder
|
||||||
|
controlnet.rope_embedder = transformer.rope_embedder
|
||||||
|
controlnet.noise_refiner = transformer.noise_refiner
|
||||||
|
controlnet.context_refiner = transformer.context_refiner
|
||||||
|
controlnet.x_pad_token = transformer.x_pad_token
|
||||||
|
controlnet.cap_pad_token = transformer.cap_pad_token
|
||||||
|
return controlnet
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
# Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel.create_coordinate_grid
|
||||||
|
def create_coordinate_grid(size, start=None, device=None):
|
||||||
|
if start is None:
|
||||||
|
start = (0 for _ in size)
|
||||||
|
|
||||||
|
axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
|
||||||
|
grids = torch.meshgrid(axes, indexing="ij")
|
||||||
|
return torch.stack(grids, dim=-1)
|
||||||
|
|
||||||
|
# Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel.patchify_and_embed
|
||||||
|
def patchify_and_embed(
|
||||||
|
self,
|
||||||
|
all_image: List[torch.Tensor],
|
||||||
|
all_cap_feats: List[torch.Tensor],
|
||||||
|
patch_size: int,
|
||||||
|
f_patch_size: int,
|
||||||
|
):
|
||||||
|
pH = pW = patch_size
|
||||||
|
pF = f_patch_size
|
||||||
|
device = all_image[0].device
|
||||||
|
|
||||||
|
all_image_out = []
|
||||||
|
all_image_size = []
|
||||||
|
all_image_pos_ids = []
|
||||||
|
all_image_pad_mask = []
|
||||||
|
all_cap_pos_ids = []
|
||||||
|
all_cap_pad_mask = []
|
||||||
|
all_cap_feats_out = []
|
||||||
|
|
||||||
|
for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)):
|
||||||
|
### Process Caption
|
||||||
|
cap_ori_len = len(cap_feat)
|
||||||
|
cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF
|
||||||
|
# padded position ids
|
||||||
|
cap_padded_pos_ids = self.create_coordinate_grid(
|
||||||
|
size=(cap_ori_len + cap_padding_len, 1, 1),
|
||||||
|
start=(1, 0, 0),
|
||||||
|
device=device,
|
||||||
|
).flatten(0, 2)
|
||||||
|
all_cap_pos_ids.append(cap_padded_pos_ids)
|
||||||
|
# pad mask
|
||||||
|
cap_pad_mask = torch.cat(
|
||||||
|
[
|
||||||
|
torch.zeros((cap_ori_len,), dtype=torch.bool, device=device),
|
||||||
|
torch.ones((cap_padding_len,), dtype=torch.bool, device=device),
|
||||||
|
],
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
all_cap_pad_mask.append(
|
||||||
|
cap_pad_mask if cap_padding_len > 0 else torch.zeros((cap_ori_len,), dtype=torch.bool, device=device)
|
||||||
|
)
|
||||||
|
|
||||||
|
# padded feature
|
||||||
|
cap_padded_feat = torch.cat([cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], dim=0)
|
||||||
|
all_cap_feats_out.append(cap_padded_feat)
|
||||||
|
|
||||||
|
### Process Image
|
||||||
|
C, F, H, W = image.size()
|
||||||
|
all_image_size.append((F, H, W))
|
||||||
|
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
|
||||||
|
|
||||||
|
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
|
||||||
|
# "c f pf h ph w pw -> (f h w) (pf ph pw c)"
|
||||||
|
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
|
||||||
|
|
||||||
|
image_ori_len = len(image)
|
||||||
|
image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
|
||||||
|
|
||||||
|
image_ori_pos_ids = self.create_coordinate_grid(
|
||||||
|
size=(F_tokens, H_tokens, W_tokens),
|
||||||
|
start=(cap_ori_len + cap_padding_len + 1, 0, 0),
|
||||||
|
device=device,
|
||||||
|
).flatten(0, 2)
|
||||||
|
image_padded_pos_ids = torch.cat(
|
||||||
|
[
|
||||||
|
image_ori_pos_ids,
|
||||||
|
self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
|
||||||
|
.flatten(0, 2)
|
||||||
|
.repeat(image_padding_len, 1),
|
||||||
|
],
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
all_image_pos_ids.append(image_padded_pos_ids if image_padding_len > 0 else image_ori_pos_ids)
|
||||||
|
# pad mask
|
||||||
|
image_pad_mask = torch.cat(
|
||||||
|
[
|
||||||
|
torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
|
||||||
|
torch.ones((image_padding_len,), dtype=torch.bool, device=device),
|
||||||
|
],
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
all_image_pad_mask.append(
|
||||||
|
image_pad_mask
|
||||||
|
if image_padding_len > 0
|
||||||
|
else torch.zeros((image_ori_len,), dtype=torch.bool, device=device)
|
||||||
|
)
|
||||||
|
# padded feature
|
||||||
|
image_padded_feat = torch.cat(
|
||||||
|
[image, image[-1:].repeat(image_padding_len, 1)],
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
all_image_out.append(image_padded_feat if image_padding_len > 0 else image)
|
||||||
|
|
||||||
|
return (
|
||||||
|
all_image_out,
|
||||||
|
all_cap_feats_out,
|
||||||
|
all_image_size,
|
||||||
|
all_image_pos_ids,
|
||||||
|
all_cap_pos_ids,
|
||||||
|
all_image_pad_mask,
|
||||||
|
all_cap_pad_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
def patchify(
|
||||||
|
self,
|
||||||
|
all_image: List[torch.Tensor],
|
||||||
|
patch_size: int,
|
||||||
|
f_patch_size: int,
|
||||||
|
):
|
||||||
|
pH = pW = patch_size
|
||||||
|
pF = f_patch_size
|
||||||
|
all_image_out = []
|
||||||
|
|
||||||
|
for i, image in enumerate(all_image):
|
||||||
|
### Process Image
|
||||||
|
C, F, H, W = image.size()
|
||||||
|
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
|
||||||
|
|
||||||
|
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
|
||||||
|
# "c f pf h ph w pw -> (f h w) (pf ph pw c)"
|
||||||
|
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
|
||||||
|
|
||||||
|
image_ori_len = len(image)
|
||||||
|
image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
|
||||||
|
|
||||||
|
# padded feature
|
||||||
|
image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
|
||||||
|
all_image_out.append(image_padded_feat)
|
||||||
|
|
||||||
|
return all_image_out
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: List[torch.Tensor],
|
||||||
|
t,
|
||||||
|
cap_feats: List[torch.Tensor],
|
||||||
|
control_context: List[torch.Tensor],
|
||||||
|
conditioning_scale: float = 1.0,
|
||||||
|
patch_size=2,
|
||||||
|
f_patch_size=1,
|
||||||
|
):
|
||||||
|
if (
|
||||||
|
self.t_scale is None
|
||||||
|
or self.t_embedder is None
|
||||||
|
or self.all_x_embedder is None
|
||||||
|
or self.cap_embedder is None
|
||||||
|
or self.rope_embedder is None
|
||||||
|
or self.noise_refiner is None
|
||||||
|
or self.context_refiner is None
|
||||||
|
or self.x_pad_token is None
|
||||||
|
or self.cap_pad_token is None
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Required modules are `None`, use `from_transformer` to share required modules from `transformer`."
|
||||||
|
)
|
||||||
|
|
||||||
|
assert patch_size in self.config.all_patch_size
|
||||||
|
assert f_patch_size in self.config.all_f_patch_size
|
||||||
|
|
||||||
|
bsz = len(x)
|
||||||
|
device = x[0].device
|
||||||
|
t = t * self.t_scale
|
||||||
|
t = self.t_embedder(t)
|
||||||
|
|
||||||
|
(
|
||||||
|
x,
|
||||||
|
cap_feats,
|
||||||
|
x_size,
|
||||||
|
x_pos_ids,
|
||||||
|
cap_pos_ids,
|
||||||
|
x_inner_pad_mask,
|
||||||
|
cap_inner_pad_mask,
|
||||||
|
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
|
||||||
|
|
||||||
|
x_item_seqlens = [len(_) for _ in x]
|
||||||
|
assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
|
||||||
|
x_max_item_seqlen = max(x_item_seqlens)
|
||||||
|
|
||||||
|
control_context = self.patchify(control_context, patch_size, f_patch_size)
|
||||||
|
control_context = torch.cat(control_context, dim=0)
|
||||||
|
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context)
|
||||||
|
|
||||||
|
control_context[torch.cat(x_inner_pad_mask)] = self.x_pad_token
|
||||||
|
control_context = list(control_context.split(x_item_seqlens, dim=0))
|
||||||
|
|
||||||
|
control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0)
|
||||||
|
|
||||||
|
# x embed & refine
|
||||||
|
x = torch.cat(x, dim=0)
|
||||||
|
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
|
||||||
|
|
||||||
|
# Match t_embedder output dtype to x for layerwise casting compatibility
|
||||||
|
adaln_input = t.type_as(x)
|
||||||
|
x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
|
||||||
|
x = list(x.split(x_item_seqlens, dim=0))
|
||||||
|
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0))
|
||||||
|
|
||||||
|
x = pad_sequence(x, batch_first=True, padding_value=0.0)
|
||||||
|
x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
|
||||||
|
# Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors
|
||||||
|
x_freqs_cis = x_freqs_cis[:, : x.shape[1]]
|
||||||
|
|
||||||
|
x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
|
||||||
|
for i, seq_len in enumerate(x_item_seqlens):
|
||||||
|
x_attn_mask[i, :seq_len] = 1
|
||||||
|
|
||||||
|
if self.add_control_noise_refiner is not None:
|
||||||
|
if self.add_control_noise_refiner == "control_layers":
|
||||||
|
layers = self.control_layers
|
||||||
|
elif self.add_control_noise_refiner == "control_noise_refiner":
|
||||||
|
layers = self.control_noise_refiner
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported `add_control_noise_refiner` type: {self.add_control_noise_refiner}.")
|
||||||
|
for layer in layers:
|
||||||
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||||
|
control_context = self._gradient_checkpointing_func(
|
||||||
|
layer, control_context, x, x_attn_mask, x_freqs_cis, adaln_input
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
control_context = layer(control_context, x, x_attn_mask, x_freqs_cis, adaln_input)
|
||||||
|
|
||||||
|
hints = torch.unbind(control_context)[:-1]
|
||||||
|
control_context = torch.unbind(control_context)[-1]
|
||||||
|
noise_refiner_block_samples = {
|
||||||
|
layer_idx: hints[idx] * conditioning_scale
|
||||||
|
for idx, layer_idx in enumerate(self.control_refiner_layers_places)
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
noise_refiner_block_samples = None
|
||||||
|
|
||||||
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||||
|
for layer_idx, layer in enumerate(self.noise_refiner):
|
||||||
|
x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input)
|
||||||
|
if noise_refiner_block_samples is not None:
|
||||||
|
if layer_idx in noise_refiner_block_samples:
|
||||||
|
x = x + noise_refiner_block_samples[layer_idx]
|
||||||
|
else:
|
||||||
|
for layer_idx, layer in enumerate(self.noise_refiner):
|
||||||
|
x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)
|
||||||
|
if noise_refiner_block_samples is not None:
|
||||||
|
if layer_idx in noise_refiner_block_samples:
|
||||||
|
x = x + noise_refiner_block_samples[layer_idx]
|
||||||
|
|
||||||
|
# cap embed & refine
|
||||||
|
cap_item_seqlens = [len(_) for _ in cap_feats]
|
||||||
|
cap_max_item_seqlen = max(cap_item_seqlens)
|
||||||
|
|
||||||
|
cap_feats = torch.cat(cap_feats, dim=0)
|
||||||
|
cap_feats = self.cap_embedder(cap_feats)
|
||||||
|
cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
|
||||||
|
cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
|
||||||
|
cap_freqs_cis = list(
|
||||||
|
self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0)
|
||||||
|
)
|
||||||
|
|
||||||
|
cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
|
||||||
|
cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
|
||||||
|
# Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors
|
||||||
|
cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]]
|
||||||
|
|
||||||
|
cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
|
||||||
|
for i, seq_len in enumerate(cap_item_seqlens):
|
||||||
|
cap_attn_mask[i, :seq_len] = 1
|
||||||
|
|
||||||
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||||
|
for layer in self.context_refiner:
|
||||||
|
cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis)
|
||||||
|
else:
|
||||||
|
for layer in self.context_refiner:
|
||||||
|
cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)
|
||||||
|
|
||||||
|
# unified
|
||||||
|
unified = []
|
||||||
|
unified_freqs_cis = []
|
||||||
|
for i in range(bsz):
|
||||||
|
x_len = x_item_seqlens[i]
|
||||||
|
cap_len = cap_item_seqlens[i]
|
||||||
|
unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
|
||||||
|
unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
|
||||||
|
unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
|
||||||
|
assert unified_item_seqlens == [len(_) for _ in unified]
|
||||||
|
unified_max_item_seqlen = max(unified_item_seqlens)
|
||||||
|
|
||||||
|
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
|
||||||
|
unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
|
||||||
|
unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
|
||||||
|
for i, seq_len in enumerate(unified_item_seqlens):
|
||||||
|
unified_attn_mask[i, :seq_len] = 1
|
||||||
|
|
||||||
|
## ControlNet start
|
||||||
|
if not self.add_control_noise_refiner:
|
||||||
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||||
|
for layer in self.control_noise_refiner:
|
||||||
|
control_context = self._gradient_checkpointing_func(
|
||||||
|
layer, control_context, x_attn_mask, x_freqs_cis, adaln_input
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
for layer in self.control_noise_refiner:
|
||||||
|
control_context = layer(control_context, x_attn_mask, x_freqs_cis, adaln_input)
|
||||||
|
|
||||||
|
# unified
|
||||||
|
control_context_unified = []
|
||||||
|
for i in range(bsz):
|
||||||
|
x_len = x_item_seqlens[i]
|
||||||
|
cap_len = cap_item_seqlens[i]
|
||||||
|
control_context_unified.append(torch.cat([control_context[i][:x_len], cap_feats[i][:cap_len]]))
|
||||||
|
control_context_unified = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0)
|
||||||
|
|
||||||
|
for layer in self.control_layers:
|
||||||
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||||
|
control_context_unified = self._gradient_checkpointing_func(
|
||||||
|
layer, control_context_unified, unified, unified_attn_mask, unified_freqs_cis, adaln_input
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
control_context_unified = layer(
|
||||||
|
control_context_unified, unified, unified_attn_mask, unified_freqs_cis, adaln_input
|
||||||
|
)
|
||||||
|
|
||||||
|
hints = torch.unbind(control_context_unified)[:-1]
|
||||||
|
controlnet_block_samples = {
|
||||||
|
layer_idx: hints[idx] * conditioning_scale for idx, layer_idx in enumerate(self.control_layers_places)
|
||||||
|
}
|
||||||
|
return controlnet_block_samples
|
||||||
@@ -439,6 +439,9 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|||||||
rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
|
rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
|
||||||
concat_padding_mask: bool = True,
|
concat_padding_mask: bool = True,
|
||||||
extra_pos_embed_type: Optional[str] = "learnable",
|
extra_pos_embed_type: Optional[str] = "learnable",
|
||||||
|
use_crossattn_projection: bool = False,
|
||||||
|
crossattn_proj_in_channels: int = 1024,
|
||||||
|
encoder_hidden_states_channels: int = 1024,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
hidden_size = num_attention_heads * attention_head_dim
|
hidden_size = num_attention_heads * attention_head_dim
|
||||||
@@ -485,6 +488,12 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|||||||
hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, bias=False
|
hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, bias=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.config.use_crossattn_projection:
|
||||||
|
self.crossattn_proj = nn.Sequential(
|
||||||
|
nn.Linear(crossattn_proj_in_channels, encoder_hidden_states_channels, bias=True),
|
||||||
|
nn.GELU(),
|
||||||
|
)
|
||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -524,6 +533,7 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|||||||
post_patch_num_frames = num_frames // p_t
|
post_patch_num_frames = num_frames // p_t
|
||||||
post_patch_height = height // p_h
|
post_patch_height = height // p_h
|
||||||
post_patch_width = width // p_w
|
post_patch_width = width // p_w
|
||||||
|
|
||||||
hidden_states = self.patch_embed(hidden_states)
|
hidden_states = self.patch_embed(hidden_states)
|
||||||
hidden_states = hidden_states.flatten(1, 3) # [B, T, H, W, C] -> [B, THW, C]
|
hidden_states = hidden_states.flatten(1, 3) # [B, T, H, W, C] -> [B, THW, C]
|
||||||
|
|
||||||
@@ -546,6 +556,9 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|||||||
else:
|
else:
|
||||||
assert False
|
assert False
|
||||||
|
|
||||||
|
if self.config.use_crossattn_projection:
|
||||||
|
encoder_hidden_states = self.crossattn_proj(encoder_hidden_states)
|
||||||
|
|
||||||
# 5. Transformer blocks
|
# 5. Transformer blocks
|
||||||
for block in self.transformer_blocks:
|
for block in self.transformer_blocks:
|
||||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||||
|
|||||||
@@ -143,17 +143,26 @@ def apply_rotary_emb_qwen(
|
|||||||
|
|
||||||
|
|
||||||
class QwenTimestepProjEmbeddings(nn.Module):
|
class QwenTimestepProjEmbeddings(nn.Module):
|
||||||
def __init__(self, embedding_dim):
|
def __init__(self, embedding_dim, use_additional_t_cond=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
|
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
|
||||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||||
|
self.use_additional_t_cond = use_additional_t_cond
|
||||||
|
if use_additional_t_cond:
|
||||||
|
self.addition_t_embedding = nn.Embedding(2, embedding_dim)
|
||||||
|
|
||||||
def forward(self, timestep, hidden_states):
|
def forward(self, timestep, hidden_states, addition_t_cond=None):
|
||||||
timesteps_proj = self.time_proj(timestep)
|
timesteps_proj = self.time_proj(timestep)
|
||||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) # (N, D)
|
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) # (N, D)
|
||||||
|
|
||||||
conditioning = timesteps_emb
|
conditioning = timesteps_emb
|
||||||
|
if self.use_additional_t_cond:
|
||||||
|
if addition_t_cond is None:
|
||||||
|
raise ValueError("When additional_t_cond is True, addition_t_cond must be provided.")
|
||||||
|
addition_t_emb = self.addition_t_embedding(addition_t_cond)
|
||||||
|
addition_t_emb = addition_t_emb.to(dtype=hidden_states.dtype)
|
||||||
|
conditioning = conditioning + addition_t_emb
|
||||||
|
|
||||||
return conditioning
|
return conditioning
|
||||||
|
|
||||||
@@ -259,6 +268,120 @@ class QwenEmbedRope(nn.Module):
|
|||||||
return freqs.clone().contiguous()
|
return freqs.clone().contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
class QwenEmbedLayer3DRope(nn.Module):
|
||||||
|
def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
|
||||||
|
super().__init__()
|
||||||
|
self.theta = theta
|
||||||
|
self.axes_dim = axes_dim
|
||||||
|
pos_index = torch.arange(4096)
|
||||||
|
neg_index = torch.arange(4096).flip(0) * -1 - 1
|
||||||
|
self.pos_freqs = torch.cat(
|
||||||
|
[
|
||||||
|
self.rope_params(pos_index, self.axes_dim[0], self.theta),
|
||||||
|
self.rope_params(pos_index, self.axes_dim[1], self.theta),
|
||||||
|
self.rope_params(pos_index, self.axes_dim[2], self.theta),
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
self.neg_freqs = torch.cat(
|
||||||
|
[
|
||||||
|
self.rope_params(neg_index, self.axes_dim[0], self.theta),
|
||||||
|
self.rope_params(neg_index, self.axes_dim[1], self.theta),
|
||||||
|
self.rope_params(neg_index, self.axes_dim[2], self.theta),
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.scale_rope = scale_rope
|
||||||
|
|
||||||
|
def rope_params(self, index, dim, theta=10000):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
|
||||||
|
"""
|
||||||
|
assert dim % 2 == 0
|
||||||
|
freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
|
||||||
|
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
||||||
|
return freqs
|
||||||
|
|
||||||
|
def forward(self, video_fhw, txt_seq_lens, device):
|
||||||
|
"""
|
||||||
|
Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
|
||||||
|
txt_length: [bs] a list of 1 integers representing the length of the text
|
||||||
|
"""
|
||||||
|
if self.pos_freqs.device != device:
|
||||||
|
self.pos_freqs = self.pos_freqs.to(device)
|
||||||
|
self.neg_freqs = self.neg_freqs.to(device)
|
||||||
|
|
||||||
|
if isinstance(video_fhw, list):
|
||||||
|
video_fhw = video_fhw[0]
|
||||||
|
if not isinstance(video_fhw, list):
|
||||||
|
video_fhw = [video_fhw]
|
||||||
|
|
||||||
|
vid_freqs = []
|
||||||
|
max_vid_index = 0
|
||||||
|
layer_num = len(video_fhw) - 1
|
||||||
|
for idx, fhw in enumerate(video_fhw):
|
||||||
|
frame, height, width = fhw
|
||||||
|
if idx != layer_num:
|
||||||
|
video_freq = self._compute_video_freqs(frame, height, width, idx)
|
||||||
|
else:
|
||||||
|
### For the condition image, we set the layer index to -1
|
||||||
|
video_freq = self._compute_condition_freqs(frame, height, width)
|
||||||
|
video_freq = video_freq.to(device)
|
||||||
|
vid_freqs.append(video_freq)
|
||||||
|
|
||||||
|
if self.scale_rope:
|
||||||
|
max_vid_index = max(height // 2, width // 2, max_vid_index)
|
||||||
|
else:
|
||||||
|
max_vid_index = max(height, width, max_vid_index)
|
||||||
|
|
||||||
|
max_vid_index = max(max_vid_index, layer_num)
|
||||||
|
max_len = max(txt_seq_lens)
|
||||||
|
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
|
||||||
|
vid_freqs = torch.cat(vid_freqs, dim=0)
|
||||||
|
|
||||||
|
return vid_freqs, txt_freqs
|
||||||
|
|
||||||
|
@functools.lru_cache(maxsize=None)
|
||||||
|
def _compute_video_freqs(self, frame, height, width, idx=0):
|
||||||
|
seq_lens = frame * height * width
|
||||||
|
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||||
|
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||||
|
|
||||||
|
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||||
|
if self.scale_rope:
|
||||||
|
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
|
||||||
|
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||||
|
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
|
||||||
|
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||||
|
else:
|
||||||
|
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||||
|
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||||
|
|
||||||
|
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||||
|
return freqs.clone().contiguous()
|
||||||
|
|
||||||
|
@functools.lru_cache(maxsize=None)
|
||||||
|
def _compute_condition_freqs(self, frame, height, width):
|
||||||
|
seq_lens = frame * height * width
|
||||||
|
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||||
|
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||||
|
|
||||||
|
freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||||
|
if self.scale_rope:
|
||||||
|
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
|
||||||
|
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||||
|
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
|
||||||
|
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||||
|
else:
|
||||||
|
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||||
|
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||||
|
|
||||||
|
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||||
|
return freqs.clone().contiguous()
|
||||||
|
|
||||||
|
|
||||||
class QwenDoubleStreamAttnProcessor2_0:
|
class QwenDoubleStreamAttnProcessor2_0:
|
||||||
"""
|
"""
|
||||||
Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor
|
Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor
|
||||||
@@ -578,14 +701,21 @@ class QwenImageTransformer2DModel(
|
|||||||
guidance_embeds: bool = False, # TODO: this should probably be removed
|
guidance_embeds: bool = False, # TODO: this should probably be removed
|
||||||
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
|
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
|
||||||
zero_cond_t: bool = False,
|
zero_cond_t: bool = False,
|
||||||
|
use_additional_t_cond: bool = False,
|
||||||
|
use_layer3d_rope: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.out_channels = out_channels or in_channels
|
self.out_channels = out_channels or in_channels
|
||||||
self.inner_dim = num_attention_heads * attention_head_dim
|
self.inner_dim = num_attention_heads * attention_head_dim
|
||||||
|
|
||||||
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
|
if not use_layer3d_rope:
|
||||||
|
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
|
||||||
|
else:
|
||||||
|
self.pos_embed = QwenEmbedLayer3DRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
|
||||||
|
|
||||||
self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)
|
self.time_text_embed = QwenTimestepProjEmbeddings(
|
||||||
|
embedding_dim=self.inner_dim, use_additional_t_cond=use_additional_t_cond
|
||||||
|
)
|
||||||
|
|
||||||
self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
|
self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
|
||||||
|
|
||||||
@@ -621,6 +751,7 @@ class QwenImageTransformer2DModel(
|
|||||||
guidance: torch.Tensor = None, # TODO: this should probably be removed
|
guidance: torch.Tensor = None, # TODO: this should probably be removed
|
||||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
controlnet_block_samples=None,
|
controlnet_block_samples=None,
|
||||||
|
additional_t_cond=None,
|
||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||||
"""
|
"""
|
||||||
@@ -683,9 +814,9 @@ class QwenImageTransformer2DModel(
|
|||||||
guidance = guidance.to(hidden_states.dtype) * 1000
|
guidance = guidance.to(hidden_states.dtype) * 1000
|
||||||
|
|
||||||
temb = (
|
temb = (
|
||||||
self.time_text_embed(timestep, hidden_states)
|
self.time_text_embed(timestep, hidden_states, additional_t_cond)
|
||||||
if guidance is None
|
if guidance is None
|
||||||
else self.time_text_embed(timestep, guidance, hidden_states)
|
else self.time_text_embed(timestep, guidance, hidden_states, additional_t_cond)
|
||||||
)
|
)
|
||||||
|
|
||||||
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
|
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -536,6 +536,7 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
|
|||||||
x: List[torch.Tensor],
|
x: List[torch.Tensor],
|
||||||
t,
|
t,
|
||||||
cap_feats: List[torch.Tensor],
|
cap_feats: List[torch.Tensor],
|
||||||
|
controlnet_block_samples: Optional[Dict[int, torch.Tensor]] = None,
|
||||||
patch_size=2,
|
patch_size=2,
|
||||||
f_patch_size=1,
|
f_patch_size=1,
|
||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
@@ -635,13 +636,19 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
|
|||||||
unified_attn_mask[i, :seq_len] = 1
|
unified_attn_mask[i, :seq_len] = 1
|
||||||
|
|
||||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||||
for layer in self.layers:
|
for layer_idx, layer in enumerate(self.layers):
|
||||||
unified = self._gradient_checkpointing_func(
|
unified = self._gradient_checkpointing_func(
|
||||||
layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input
|
layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input
|
||||||
)
|
)
|
||||||
|
if controlnet_block_samples is not None:
|
||||||
|
if layer_idx in controlnet_block_samples:
|
||||||
|
unified = unified + controlnet_block_samples[layer_idx]
|
||||||
else:
|
else:
|
||||||
for layer in self.layers:
|
for layer_idx, layer in enumerate(self.layers):
|
||||||
unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input)
|
unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input)
|
||||||
|
if controlnet_block_samples is not None:
|
||||||
|
if layer_idx in controlnet_block_samples:
|
||||||
|
unified = unified + controlnet_block_samples[layer_idx]
|
||||||
|
|
||||||
unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
|
unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
|
||||||
unified = list(unified.unbind(dim=0))
|
unified = list(unified.unbind(dim=0))
|
||||||
|
|||||||
@@ -360,7 +360,7 @@ class FluxKontextCoreDenoiseStep(SequentialPipelineBlocks):
|
|||||||
AUTO_BLOCKS = InsertableDict(
|
AUTO_BLOCKS = InsertableDict(
|
||||||
[
|
[
|
||||||
("text_encoder", FluxTextEncoderStep()),
|
("text_encoder", FluxTextEncoderStep()),
|
||||||
("image_encoder", FluxAutoVaeEncoderStep()),
|
("vae_encoder", FluxAutoVaeEncoderStep()),
|
||||||
("denoise", FluxCoreDenoiseStep()),
|
("denoise", FluxCoreDenoiseStep()),
|
||||||
("decode", FluxDecodeStep()),
|
("decode", FluxDecodeStep()),
|
||||||
]
|
]
|
||||||
@@ -369,7 +369,7 @@ AUTO_BLOCKS = InsertableDict(
|
|||||||
AUTO_BLOCKS_KONTEXT = InsertableDict(
|
AUTO_BLOCKS_KONTEXT = InsertableDict(
|
||||||
[
|
[
|
||||||
("text_encoder", FluxTextEncoderStep()),
|
("text_encoder", FluxTextEncoderStep()),
|
||||||
("image_encoder", FluxKontextAutoVaeEncoderStep()),
|
("vae_encoder", FluxKontextAutoVaeEncoderStep()),
|
||||||
("denoise", FluxKontextCoreDenoiseStep()),
|
("denoise", FluxKontextCoreDenoiseStep()),
|
||||||
("decode", FluxDecodeStep()),
|
("decode", FluxDecodeStep()),
|
||||||
]
|
]
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -501,15 +501,19 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def input_names(self) -> List[str]:
|
def input_names(self) -> List[str]:
|
||||||
return [input_param.name for input_param in self.inputs]
|
return [input_param.name for input_param in self.inputs if input_param.name is not None]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def intermediate_output_names(self) -> List[str]:
|
def intermediate_output_names(self) -> List[str]:
|
||||||
return [output_param.name for output_param in self.intermediate_outputs]
|
return [output_param.name for output_param in self.intermediate_outputs if output_param.name is not None]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def output_names(self) -> List[str]:
|
def output_names(self) -> List[str]:
|
||||||
return [output_param.name for output_param in self.outputs]
|
return [output_param.name for output_param in self.outputs if output_param.name is not None]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def component_names(self) -> List[str]:
|
||||||
|
return [component.name for component in self.expected_components]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def doc(self):
|
def doc(self):
|
||||||
@@ -1525,10 +1529,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
|||||||
if blocks is None:
|
if blocks is None:
|
||||||
if modular_config_dict is not None:
|
if modular_config_dict is not None:
|
||||||
blocks_class_name = modular_config_dict.get("_blocks_class_name")
|
blocks_class_name = modular_config_dict.get("_blocks_class_name")
|
||||||
elif config_dict is not None:
|
|
||||||
blocks_class_name = self.get_default_blocks_name(config_dict)
|
|
||||||
else:
|
else:
|
||||||
blocks_class_name = None
|
blocks_class_name = self.get_default_blocks_name(config_dict)
|
||||||
if blocks_class_name is not None:
|
if blocks_class_name is not None:
|
||||||
diffusers_module = importlib.import_module("diffusers")
|
diffusers_module = importlib.import_module("diffusers")
|
||||||
blocks_class = getattr(diffusers_module, blocks_class_name)
|
blocks_class = getattr(diffusers_module, blocks_class_name)
|
||||||
@@ -1625,7 +1627,10 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
|||||||
return None, config_dict
|
return None, config_dict
|
||||||
|
|
||||||
except EnvironmentError as e:
|
except EnvironmentError as e:
|
||||||
logger.debug(f" model_index.json not found in the repo: {e}")
|
raise EnvironmentError(
|
||||||
|
f"Failed to load config from '{pretrained_model_name_or_path}'. "
|
||||||
|
f"Could not find or load 'modular_model_index.json' or 'model_index.json'."
|
||||||
|
) from e
|
||||||
|
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
@@ -2550,7 +2555,11 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
|||||||
kwargs_type = expected_input_param.kwargs_type
|
kwargs_type = expected_input_param.kwargs_type
|
||||||
if name in passed_kwargs:
|
if name in passed_kwargs:
|
||||||
state.set(name, passed_kwargs.pop(name), kwargs_type)
|
state.set(name, passed_kwargs.pop(name), kwargs_type)
|
||||||
elif name not in state.values:
|
elif kwargs_type is not None and kwargs_type in passed_kwargs:
|
||||||
|
kwargs_dict = passed_kwargs.pop(kwargs_type)
|
||||||
|
for k, v in kwargs_dict.items():
|
||||||
|
state.set(k, v, kwargs_type)
|
||||||
|
elif name is not None and name not in state.values:
|
||||||
state.set(name, default, kwargs_type)
|
state.set(name, default, kwargs_type)
|
||||||
|
|
||||||
# Warn about unexpected inputs
|
# Warn about unexpected inputs
|
||||||
|
|||||||
@@ -30,6 +30,47 @@ from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class QwenImageAfterDenoiseStep(ModularPipelineBlocks):
|
||||||
|
model_name = "qwenimage"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "Step that unpack the latents from 3D tensor (batch_size, sequence_length, channels) into 5D tensor (batch_size, channels, 1, height, width)"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
|
components = [
|
||||||
|
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
|
||||||
|
]
|
||||||
|
|
||||||
|
return components
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam(name="height", required=True),
|
||||||
|
InputParam(name="width", required=True),
|
||||||
|
InputParam(
|
||||||
|
name="latents",
|
||||||
|
required=True,
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="The latents to decode, can be generated in the denoise step",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
|
||||||
|
vae_scale_factor = components.vae_scale_factor
|
||||||
|
block_state.latents = components.pachifier.unpack_latents(
|
||||||
|
block_state.latents, block_state.height, block_state.width, vae_scale_factor=vae_scale_factor
|
||||||
|
)
|
||||||
|
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
class QwenImageDecoderStep(ModularPipelineBlocks):
|
class QwenImageDecoderStep(ModularPipelineBlocks):
|
||||||
model_name = "qwenimage"
|
model_name = "qwenimage"
|
||||||
|
|
||||||
@@ -41,7 +82,6 @@ class QwenImageDecoderStep(ModularPipelineBlocks):
|
|||||||
def expected_components(self) -> List[ComponentSpec]:
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
components = [
|
components = [
|
||||||
ComponentSpec("vae", AutoencoderKLQwenImage),
|
ComponentSpec("vae", AutoencoderKLQwenImage),
|
||||||
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
return components
|
return components
|
||||||
@@ -49,8 +89,6 @@ class QwenImageDecoderStep(ModularPipelineBlocks):
|
|||||||
@property
|
@property
|
||||||
def inputs(self) -> List[InputParam]:
|
def inputs(self) -> List[InputParam]:
|
||||||
return [
|
return [
|
||||||
InputParam(name="height", required=True),
|
|
||||||
InputParam(name="width", required=True),
|
|
||||||
InputParam(
|
InputParam(
|
||||||
name="latents",
|
name="latents",
|
||||||
required=True,
|
required=True,
|
||||||
@@ -74,10 +112,12 @@ class QwenImageDecoderStep(ModularPipelineBlocks):
|
|||||||
block_state = self.get_block_state(state)
|
block_state = self.get_block_state(state)
|
||||||
|
|
||||||
# YiYi Notes: remove support for output_type = "latents', we can just skip decode/encode step in modular
|
# YiYi Notes: remove support for output_type = "latents', we can just skip decode/encode step in modular
|
||||||
vae_scale_factor = components.vae_scale_factor
|
if block_state.latents.ndim == 4:
|
||||||
block_state.latents = components.pachifier.unpack_latents(
|
block_state.latents = block_state.latents.unsqueeze(dim=1)
|
||||||
block_state.latents, block_state.height, block_state.width, vae_scale_factor=vae_scale_factor
|
elif block_state.latents.ndim != 5:
|
||||||
)
|
raise ValueError(
|
||||||
|
f"expect latents to be a 4D or 5D tensor but got: {block_state.latents.shape}. Please make sure the latents are unpacked before decode step."
|
||||||
|
)
|
||||||
block_state.latents = block_state.latents.to(components.vae.dtype)
|
block_state.latents = block_state.latents.to(components.vae.dtype)
|
||||||
|
|
||||||
latents_mean = (
|
latents_mean = (
|
||||||
|
|||||||
@@ -26,7 +26,12 @@ from .before_denoise import (
|
|||||||
QwenImageSetTimestepsStep,
|
QwenImageSetTimestepsStep,
|
||||||
QwenImageSetTimestepsWithStrengthStep,
|
QwenImageSetTimestepsWithStrengthStep,
|
||||||
)
|
)
|
||||||
from .decoders import QwenImageDecoderStep, QwenImageInpaintProcessImagesOutputStep, QwenImageProcessImagesOutputStep
|
from .decoders import (
|
||||||
|
QwenImageAfterDenoiseStep,
|
||||||
|
QwenImageDecoderStep,
|
||||||
|
QwenImageInpaintProcessImagesOutputStep,
|
||||||
|
QwenImageProcessImagesOutputStep,
|
||||||
|
)
|
||||||
from .denoise import (
|
from .denoise import (
|
||||||
QwenImageControlNetDenoiseStep,
|
QwenImageControlNetDenoiseStep,
|
||||||
QwenImageDenoiseStep,
|
QwenImageDenoiseStep,
|
||||||
@@ -92,6 +97,7 @@ TEXT2IMAGE_BLOCKS = InsertableDict(
|
|||||||
("set_timesteps", QwenImageSetTimestepsStep()),
|
("set_timesteps", QwenImageSetTimestepsStep()),
|
||||||
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
|
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
|
||||||
("denoise", QwenImageDenoiseStep()),
|
("denoise", QwenImageDenoiseStep()),
|
||||||
|
("after_denoise", QwenImageAfterDenoiseStep()),
|
||||||
("decode", QwenImageDecodeStep()),
|
("decode", QwenImageDecodeStep()),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -205,6 +211,7 @@ INPAINT_BLOCKS = InsertableDict(
|
|||||||
("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
|
("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
|
||||||
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
|
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
|
||||||
("denoise", QwenImageInpaintDenoiseStep()),
|
("denoise", QwenImageInpaintDenoiseStep()),
|
||||||
|
("after_denoise", QwenImageAfterDenoiseStep()),
|
||||||
("decode", QwenImageInpaintDecodeStep()),
|
("decode", QwenImageInpaintDecodeStep()),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -264,6 +271,7 @@ IMAGE2IMAGE_BLOCKS = InsertableDict(
|
|||||||
("prepare_img2img_latents", QwenImagePrepareLatentsWithStrengthStep()),
|
("prepare_img2img_latents", QwenImagePrepareLatentsWithStrengthStep()),
|
||||||
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
|
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
|
||||||
("denoise", QwenImageDenoiseStep()),
|
("denoise", QwenImageDenoiseStep()),
|
||||||
|
("after_denoise", QwenImageAfterDenoiseStep()),
|
||||||
("decode", QwenImageDecodeStep()),
|
("decode", QwenImageDecodeStep()),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -529,8 +537,16 @@ class QwenImageCoreDenoiseStep(SequentialPipelineBlocks):
|
|||||||
QwenImageAutoBeforeDenoiseStep,
|
QwenImageAutoBeforeDenoiseStep,
|
||||||
QwenImageOptionalControlNetBeforeDenoiseStep,
|
QwenImageOptionalControlNetBeforeDenoiseStep,
|
||||||
QwenImageAutoDenoiseStep,
|
QwenImageAutoDenoiseStep,
|
||||||
|
QwenImageAfterDenoiseStep,
|
||||||
|
]
|
||||||
|
block_names = [
|
||||||
|
"input",
|
||||||
|
"controlnet_input",
|
||||||
|
"before_denoise",
|
||||||
|
"controlnet_before_denoise",
|
||||||
|
"denoise",
|
||||||
|
"after_denoise",
|
||||||
]
|
]
|
||||||
block_names = ["input", "controlnet_input", "before_denoise", "controlnet_before_denoise", "denoise"]
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def description(self):
|
def description(self):
|
||||||
@@ -653,6 +669,7 @@ EDIT_BLOCKS = InsertableDict(
|
|||||||
("set_timesteps", QwenImageSetTimestepsStep()),
|
("set_timesteps", QwenImageSetTimestepsStep()),
|
||||||
("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
|
("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
|
||||||
("denoise", QwenImageEditDenoiseStep()),
|
("denoise", QwenImageEditDenoiseStep()),
|
||||||
|
("after_denoise", QwenImageAfterDenoiseStep()),
|
||||||
("decode", QwenImageDecodeStep()),
|
("decode", QwenImageDecodeStep()),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -702,6 +719,7 @@ EDIT_INPAINT_BLOCKS = InsertableDict(
|
|||||||
("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
|
("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
|
||||||
("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
|
("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
|
||||||
("denoise", QwenImageEditInpaintDenoiseStep()),
|
("denoise", QwenImageEditInpaintDenoiseStep()),
|
||||||
|
("after_denoise", QwenImageAfterDenoiseStep()),
|
||||||
("decode", QwenImageInpaintDecodeStep()),
|
("decode", QwenImageInpaintDecodeStep()),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -841,8 +859,9 @@ class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks):
|
|||||||
QwenImageEditAutoInputStep,
|
QwenImageEditAutoInputStep,
|
||||||
QwenImageEditAutoBeforeDenoiseStep,
|
QwenImageEditAutoBeforeDenoiseStep,
|
||||||
QwenImageEditAutoDenoiseStep,
|
QwenImageEditAutoDenoiseStep,
|
||||||
|
QwenImageAfterDenoiseStep,
|
||||||
]
|
]
|
||||||
block_names = ["input", "before_denoise", "denoise"]
|
block_names = ["input", "before_denoise", "denoise", "after_denoise"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def description(self):
|
def description(self):
|
||||||
@@ -954,6 +973,7 @@ EDIT_PLUS_BLOCKS = InsertableDict(
|
|||||||
("set_timesteps", QwenImageSetTimestepsStep()),
|
("set_timesteps", QwenImageSetTimestepsStep()),
|
||||||
("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()),
|
("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()),
|
||||||
("denoise", QwenImageEditDenoiseStep()),
|
("denoise", QwenImageEditDenoiseStep()),
|
||||||
|
("after_denoise", QwenImageAfterDenoiseStep()),
|
||||||
("decode", QwenImageDecodeStep()),
|
("decode", QwenImageDecodeStep()),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -1037,8 +1057,9 @@ class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
|
|||||||
QwenImageEditPlusAutoInputStep,
|
QwenImageEditPlusAutoInputStep,
|
||||||
QwenImageEditPlusAutoBeforeDenoiseStep,
|
QwenImageEditPlusAutoBeforeDenoiseStep,
|
||||||
QwenImageEditAutoDenoiseStep,
|
QwenImageEditAutoDenoiseStep,
|
||||||
|
QwenImageAfterDenoiseStep,
|
||||||
]
|
]
|
||||||
block_names = ["input", "before_denoise", "denoise"]
|
block_names = ["input", "before_denoise", "denoise", "after_denoise"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def description(self):
|
def description(self):
|
||||||
|
|||||||
@@ -1,95 +0,0 @@
|
|||||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
|
|
||||||
# mellon nodes
|
|
||||||
QwenImage_NODE_TYPES_PARAMS_MAP = {
|
|
||||||
"controlnet": {
|
|
||||||
"inputs": [
|
|
||||||
"control_image",
|
|
||||||
"controlnet_conditioning_scale",
|
|
||||||
"control_guidance_start",
|
|
||||||
"control_guidance_end",
|
|
||||||
"height",
|
|
||||||
"width",
|
|
||||||
],
|
|
||||||
"model_inputs": [
|
|
||||||
"controlnet",
|
|
||||||
"vae",
|
|
||||||
],
|
|
||||||
"outputs": [
|
|
||||||
"controlnet_out",
|
|
||||||
],
|
|
||||||
"block_names": ["controlnet_vae_encoder"],
|
|
||||||
},
|
|
||||||
"denoise": {
|
|
||||||
"inputs": [
|
|
||||||
"embeddings",
|
|
||||||
"width",
|
|
||||||
"height",
|
|
||||||
"seed",
|
|
||||||
"num_inference_steps",
|
|
||||||
"guidance_scale",
|
|
||||||
"image_latents",
|
|
||||||
"strength",
|
|
||||||
"controlnet",
|
|
||||||
],
|
|
||||||
"model_inputs": [
|
|
||||||
"unet",
|
|
||||||
"guider",
|
|
||||||
"scheduler",
|
|
||||||
],
|
|
||||||
"outputs": [
|
|
||||||
"latents",
|
|
||||||
"latents_preview",
|
|
||||||
],
|
|
||||||
"block_names": ["denoise"],
|
|
||||||
},
|
|
||||||
"vae_encoder": {
|
|
||||||
"inputs": [
|
|
||||||
"image",
|
|
||||||
"width",
|
|
||||||
"height",
|
|
||||||
],
|
|
||||||
"model_inputs": [
|
|
||||||
"vae",
|
|
||||||
],
|
|
||||||
"outputs": [
|
|
||||||
"image_latents",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
"text_encoder": {
|
|
||||||
"inputs": [
|
|
||||||
"prompt",
|
|
||||||
"negative_prompt",
|
|
||||||
],
|
|
||||||
"model_inputs": [
|
|
||||||
"text_encoders",
|
|
||||||
],
|
|
||||||
"outputs": [
|
|
||||||
"embeddings",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
"decoder": {
|
|
||||||
"inputs": [
|
|
||||||
"latents",
|
|
||||||
],
|
|
||||||
"model_inputs": [
|
|
||||||
"vae",
|
|
||||||
],
|
|
||||||
"outputs": [
|
|
||||||
"images",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
@@ -1,99 +0,0 @@
|
|||||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
|
|
||||||
SDXL_NODE_TYPES_PARAMS_MAP = {
|
|
||||||
"controlnet": {
|
|
||||||
"inputs": [
|
|
||||||
"control_image",
|
|
||||||
"controlnet_conditioning_scale",
|
|
||||||
"control_guidance_start",
|
|
||||||
"control_guidance_end",
|
|
||||||
"height",
|
|
||||||
"width",
|
|
||||||
],
|
|
||||||
"model_inputs": [
|
|
||||||
"controlnet",
|
|
||||||
],
|
|
||||||
"outputs": [
|
|
||||||
"controlnet_out",
|
|
||||||
],
|
|
||||||
"block_names": [None],
|
|
||||||
},
|
|
||||||
"denoise": {
|
|
||||||
"inputs": [
|
|
||||||
"embeddings",
|
|
||||||
"width",
|
|
||||||
"height",
|
|
||||||
"seed",
|
|
||||||
"num_inference_steps",
|
|
||||||
"guidance_scale",
|
|
||||||
"image_latents",
|
|
||||||
"strength",
|
|
||||||
# custom adapters coming in as inputs
|
|
||||||
"controlnet",
|
|
||||||
# ip_adapter is optional and custom; include if available
|
|
||||||
"ip_adapter",
|
|
||||||
],
|
|
||||||
"model_inputs": [
|
|
||||||
"unet",
|
|
||||||
"guider",
|
|
||||||
"scheduler",
|
|
||||||
],
|
|
||||||
"outputs": [
|
|
||||||
"latents",
|
|
||||||
"latents_preview",
|
|
||||||
],
|
|
||||||
"block_names": ["denoise"],
|
|
||||||
},
|
|
||||||
"vae_encoder": {
|
|
||||||
"inputs": [
|
|
||||||
"image",
|
|
||||||
"width",
|
|
||||||
"height",
|
|
||||||
],
|
|
||||||
"model_inputs": [
|
|
||||||
"vae",
|
|
||||||
],
|
|
||||||
"outputs": [
|
|
||||||
"image_latents",
|
|
||||||
],
|
|
||||||
"block_names": ["vae_encoder"],
|
|
||||||
},
|
|
||||||
"text_encoder": {
|
|
||||||
"inputs": [
|
|
||||||
"prompt",
|
|
||||||
"negative_prompt",
|
|
||||||
],
|
|
||||||
"model_inputs": [
|
|
||||||
"text_encoders",
|
|
||||||
],
|
|
||||||
"outputs": [
|
|
||||||
"embeddings",
|
|
||||||
],
|
|
||||||
"block_names": ["text_encoder"],
|
|
||||||
},
|
|
||||||
"decoder": {
|
|
||||||
"inputs": [
|
|
||||||
"latents",
|
|
||||||
],
|
|
||||||
"model_inputs": [
|
|
||||||
"vae",
|
|
||||||
],
|
|
||||||
"outputs": [
|
|
||||||
"images",
|
|
||||||
],
|
|
||||||
"block_names": ["decode"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
@@ -129,6 +129,10 @@ class ZImageLoopDenoiser(ModularPipelineBlocks):
|
|||||||
type_hint=int,
|
type_hint=int,
|
||||||
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
|
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||||
),
|
),
|
||||||
|
InputParam(
|
||||||
|
kwargs_type="denoiser_input_fields",
|
||||||
|
description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
|
||||||
|
),
|
||||||
]
|
]
|
||||||
guider_input_names = []
|
guider_input_names = []
|
||||||
uncond_guider_input_names = []
|
uncond_guider_input_names = []
|
||||||
|
|||||||
@@ -119,7 +119,7 @@ class ZImageAutoDenoiseStep(AutoPipelineBlocks):
|
|||||||
|
|
||||||
class ZImageAutoVaeImageEncoderStep(AutoPipelineBlocks):
|
class ZImageAutoVaeImageEncoderStep(AutoPipelineBlocks):
|
||||||
block_classes = [ZImageVaeImageEncoderStep]
|
block_classes = [ZImageVaeImageEncoderStep]
|
||||||
block_names = ["vae_image_encoder"]
|
block_names = ["vae_encoder"]
|
||||||
block_trigger_inputs = ["image"]
|
block_trigger_inputs = ["image"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -137,7 +137,7 @@ class ZImageAutoBlocks(SequentialPipelineBlocks):
|
|||||||
ZImageAutoDenoiseStep,
|
ZImageAutoDenoiseStep,
|
||||||
ZImageVaeDecoderStep,
|
ZImageVaeDecoderStep,
|
||||||
]
|
]
|
||||||
block_names = ["text_encoder", "vae_image_encoder", "denoise", "decode"]
|
block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
@@ -162,7 +162,7 @@ TEXT2IMAGE_BLOCKS = InsertableDict(
|
|||||||
IMAGE2IMAGE_BLOCKS = InsertableDict(
|
IMAGE2IMAGE_BLOCKS = InsertableDict(
|
||||||
[
|
[
|
||||||
("text_encoder", ZImageTextEncoderStep),
|
("text_encoder", ZImageTextEncoderStep),
|
||||||
("vae_image_encoder", ZImageVaeImageEncoderStep),
|
("vae_encoder", ZImageVaeImageEncoderStep),
|
||||||
("input", ZImageTextInputStep),
|
("input", ZImageTextInputStep),
|
||||||
("additional_inputs", ZImageAdditionalInputsStep(image_latent_inputs=["image_latents"])),
|
("additional_inputs", ZImageAdditionalInputsStep(image_latent_inputs=["image_latents"])),
|
||||||
("prepare_latents", ZImagePrepareLatentsStep),
|
("prepare_latents", ZImagePrepareLatentsStep),
|
||||||
@@ -178,7 +178,7 @@ IMAGE2IMAGE_BLOCKS = InsertableDict(
|
|||||||
AUTO_BLOCKS = InsertableDict(
|
AUTO_BLOCKS = InsertableDict(
|
||||||
[
|
[
|
||||||
("text_encoder", ZImageTextEncoderStep),
|
("text_encoder", ZImageTextEncoderStep),
|
||||||
("vae_image_encoder", ZImageAutoVaeImageEncoderStep),
|
("vae_encoder", ZImageAutoVaeImageEncoderStep),
|
||||||
("denoise", ZImageAutoDenoiseStep),
|
("denoise", ZImageAutoDenoiseStep),
|
||||||
("decode", ZImageVaeDecoderStep),
|
("decode", ZImageVaeDecoderStep),
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -165,6 +165,7 @@ else:
|
|||||||
_import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"]
|
_import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"]
|
||||||
_import_structure["consisid"] = ["ConsisIDPipeline"]
|
_import_structure["consisid"] = ["ConsisIDPipeline"]
|
||||||
_import_structure["cosmos"] = [
|
_import_structure["cosmos"] = [
|
||||||
|
"Cosmos2_5_PredictBasePipeline",
|
||||||
"Cosmos2TextToImagePipeline",
|
"Cosmos2TextToImagePipeline",
|
||||||
"CosmosTextToWorldPipeline",
|
"CosmosTextToWorldPipeline",
|
||||||
"CosmosVideoToWorldPipeline",
|
"CosmosVideoToWorldPipeline",
|
||||||
@@ -405,7 +406,12 @@ else:
|
|||||||
"Kandinsky5T2IPipeline",
|
"Kandinsky5T2IPipeline",
|
||||||
"Kandinsky5I2IPipeline",
|
"Kandinsky5I2IPipeline",
|
||||||
]
|
]
|
||||||
_import_structure["z_image"] = ["ZImageImg2ImgPipeline", "ZImagePipeline"]
|
_import_structure["z_image"] = [
|
||||||
|
"ZImageImg2ImgPipeline",
|
||||||
|
"ZImagePipeline",
|
||||||
|
"ZImageControlNetPipeline",
|
||||||
|
"ZImageControlNetInpaintPipeline",
|
||||||
|
]
|
||||||
_import_structure["skyreels_v2"] = [
|
_import_structure["skyreels_v2"] = [
|
||||||
"SkyReelsV2DiffusionForcingPipeline",
|
"SkyReelsV2DiffusionForcingPipeline",
|
||||||
"SkyReelsV2DiffusionForcingImageToVideoPipeline",
|
"SkyReelsV2DiffusionForcingImageToVideoPipeline",
|
||||||
@@ -422,6 +428,7 @@ else:
|
|||||||
"QwenImageEditInpaintPipeline",
|
"QwenImageEditInpaintPipeline",
|
||||||
"QwenImageControlNetInpaintPipeline",
|
"QwenImageControlNetInpaintPipeline",
|
||||||
"QwenImageControlNetPipeline",
|
"QwenImageControlNetPipeline",
|
||||||
|
"QwenImageLayeredPipeline",
|
||||||
]
|
]
|
||||||
_import_structure["chronoedit"] = ["ChronoEditPipeline"]
|
_import_structure["chronoedit"] = ["ChronoEditPipeline"]
|
||||||
try:
|
try:
|
||||||
@@ -616,6 +623,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
StableDiffusionXLControlNetXSPipeline,
|
StableDiffusionXLControlNetXSPipeline,
|
||||||
)
|
)
|
||||||
from .cosmos import (
|
from .cosmos import (
|
||||||
|
Cosmos2_5_PredictBasePipeline,
|
||||||
Cosmos2TextToImagePipeline,
|
Cosmos2TextToImagePipeline,
|
||||||
Cosmos2VideoToWorldPipeline,
|
Cosmos2VideoToWorldPipeline,
|
||||||
CosmosTextToWorldPipeline,
|
CosmosTextToWorldPipeline,
|
||||||
@@ -764,6 +772,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
QwenImageEditPlusPipeline,
|
QwenImageEditPlusPipeline,
|
||||||
QwenImageImg2ImgPipeline,
|
QwenImageImg2ImgPipeline,
|
||||||
QwenImageInpaintPipeline,
|
QwenImageInpaintPipeline,
|
||||||
|
QwenImageLayeredPipeline,
|
||||||
QwenImagePipeline,
|
QwenImagePipeline,
|
||||||
)
|
)
|
||||||
from .sana import (
|
from .sana import (
|
||||||
@@ -843,7 +852,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
WuerstchenDecoderPipeline,
|
WuerstchenDecoderPipeline,
|
||||||
WuerstchenPriorPipeline,
|
WuerstchenPriorPipeline,
|
||||||
)
|
)
|
||||||
from .z_image import ZImageImg2ImgPipeline, ZImagePipeline
|
from .z_image import (
|
||||||
|
ZImageControlNetInpaintPipeline,
|
||||||
|
ZImageControlNetPipeline,
|
||||||
|
ZImageImg2ImgPipeline,
|
||||||
|
ZImagePipeline,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_onnx_available():
|
if not is_onnx_available():
|
||||||
|
|||||||
@@ -22,6 +22,9 @@ except OptionalDependencyNotAvailable:
|
|||||||
|
|
||||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||||
else:
|
else:
|
||||||
|
_import_structure["pipeline_cosmos2_5_predict"] = [
|
||||||
|
"Cosmos2_5_PredictBasePipeline",
|
||||||
|
]
|
||||||
_import_structure["pipeline_cosmos2_text2image"] = ["Cosmos2TextToImagePipeline"]
|
_import_structure["pipeline_cosmos2_text2image"] = ["Cosmos2TextToImagePipeline"]
|
||||||
_import_structure["pipeline_cosmos2_video2world"] = ["Cosmos2VideoToWorldPipeline"]
|
_import_structure["pipeline_cosmos2_video2world"] = ["Cosmos2VideoToWorldPipeline"]
|
||||||
_import_structure["pipeline_cosmos_text2world"] = ["CosmosTextToWorldPipeline"]
|
_import_structure["pipeline_cosmos_text2world"] = ["CosmosTextToWorldPipeline"]
|
||||||
@@ -35,6 +38,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
except OptionalDependencyNotAvailable:
|
except OptionalDependencyNotAvailable:
|
||||||
from ...utils.dummy_torch_and_transformers_objects import *
|
from ...utils.dummy_torch_and_transformers_objects import *
|
||||||
else:
|
else:
|
||||||
|
from .pipeline_cosmos2_5_predict import (
|
||||||
|
Cosmos2_5_PredictBasePipeline,
|
||||||
|
)
|
||||||
from .pipeline_cosmos2_text2image import Cosmos2TextToImagePipeline
|
from .pipeline_cosmos2_text2image import Cosmos2TextToImagePipeline
|
||||||
from .pipeline_cosmos2_video2world import Cosmos2VideoToWorldPipeline
|
from .pipeline_cosmos2_video2world import Cosmos2VideoToWorldPipeline
|
||||||
from .pipeline_cosmos_text2world import CosmosTextToWorldPipeline
|
from .pipeline_cosmos_text2world import CosmosTextToWorldPipeline
|
||||||
|
|||||||
847
src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py
Normal file
847
src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py
Normal file
@@ -0,0 +1,847 @@
|
|||||||
|
# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
import torchvision.transforms
|
||||||
|
import torchvision.transforms.functional
|
||||||
|
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration
|
||||||
|
|
||||||
|
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||||
|
from ...image_processor import PipelineImageInput
|
||||||
|
from ...models import AutoencoderKLWan, CosmosTransformer3DModel
|
||||||
|
from ...schedulers import UniPCMultistepScheduler
|
||||||
|
from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
|
||||||
|
from ...utils.torch_utils import randn_tensor
|
||||||
|
from ...video_processor import VideoProcessor
|
||||||
|
from ..pipeline_utils import DiffusionPipeline
|
||||||
|
from .pipeline_output import CosmosPipelineOutput
|
||||||
|
|
||||||
|
|
||||||
|
if is_cosmos_guardrail_available():
|
||||||
|
from cosmos_guardrail import CosmosSafetyChecker
|
||||||
|
else:
|
||||||
|
|
||||||
|
class CosmosSafetyChecker:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
raise ImportError(
|
||||||
|
"`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_xla_available():
|
||||||
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
|
XLA_AVAILABLE = True
|
||||||
|
else:
|
||||||
|
XLA_AVAILABLE = False
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||||
|
def retrieve_latents(
|
||||||
|
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||||
|
):
|
||||||
|
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||||
|
return encoder_output.latent_dist.sample(generator)
|
||||||
|
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||||
|
return encoder_output.latent_dist.mode()
|
||||||
|
elif hasattr(encoder_output, "latents"):
|
||||||
|
return encoder_output.latents
|
||||||
|
else:
|
||||||
|
raise AttributeError("Could not access latents of provided encoder_output")
|
||||||
|
|
||||||
|
|
||||||
|
EXAMPLE_DOC_STRING = """
|
||||||
|
Examples:
|
||||||
|
```python
|
||||||
|
>>> import torch
|
||||||
|
>>> from diffusers import Cosmos2_5_PredictBasePipeline
|
||||||
|
>>> from diffusers.utils import export_to_video, load_image, load_video
|
||||||
|
|
||||||
|
>>> model_id = "nvidia/Cosmos-Predict2.5-2B"
|
||||||
|
>>> pipe = Cosmos2_5_PredictBasePipeline.from_pretrained(
|
||||||
|
... model_id, revision="diffusers/base/pre-trianed", torch_dtype=torch.bfloat16
|
||||||
|
... )
|
||||||
|
>>> pipe = pipe.to("cuda")
|
||||||
|
|
||||||
|
>>> # Common negative prompt reused across modes.
|
||||||
|
>>> negative_prompt = (
|
||||||
|
... "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
|
||||||
|
... "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
|
||||||
|
... "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky "
|
||||||
|
... "movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
|
||||||
|
... "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
|
||||||
|
... "Overall, the video is of poor quality."
|
||||||
|
... )
|
||||||
|
|
||||||
|
>>> # Text2World: generate a 93-frame world video from text only.
|
||||||
|
>>> prompt = (
|
||||||
|
... "As the red light shifts to green, the red bus at the intersection begins to move forward, its headlights "
|
||||||
|
... "cutting through the falling snow. The snowy tire tracks deepen as the vehicle inches ahead, casting fresh "
|
||||||
|
... "lines onto the slushy road. Around it, streetlights glow warmer, illuminating the drifting flakes and wet "
|
||||||
|
... "reflections on the asphalt. Other cars behind start to edge forward, their beams joining the scene. "
|
||||||
|
... "The stillness of the urban street transitions into motion as the quiet snowfall is punctuated by the slow "
|
||||||
|
... "advance of traffic through the frosty city corridor."
|
||||||
|
... )
|
||||||
|
>>> video = pipe(
|
||||||
|
... image=None,
|
||||||
|
... video=None,
|
||||||
|
... prompt=prompt,
|
||||||
|
... negative_prompt=negative_prompt,
|
||||||
|
... num_frames=93,
|
||||||
|
... generator=torch.Generator().manual_seed(1),
|
||||||
|
... ).frames[0]
|
||||||
|
>>> export_to_video(video, "text2world.mp4", fps=16)
|
||||||
|
|
||||||
|
>>> # Image2World: condition on a single image and generate a 93-frame world video.
|
||||||
|
>>> prompt = (
|
||||||
|
... "A high-definition video captures the precision of robotic welding in an industrial setting. "
|
||||||
|
... "The first frame showcases a robotic arm, equipped with a welding torch, positioned over a large metal structure. "
|
||||||
|
... "The welding process is in full swing, with bright sparks and intense light illuminating the scene, creating a vivid "
|
||||||
|
... "display of blue and white hues. A significant amount of smoke billows around the welding area, partially obscuring "
|
||||||
|
... "the view but emphasizing the heat and activity. The background reveals parts of the workshop environment, including a "
|
||||||
|
... "ventilation system and various pieces of machinery, indicating a busy and functional industrial workspace. As the video "
|
||||||
|
... "progresses, the robotic arm maintains its steady position, continuing the welding process and moving to its left. "
|
||||||
|
... "The welding torch consistently emits sparks and light, and the smoke continues to rise, diffusing slightly as it moves upward. "
|
||||||
|
... "The metal surface beneath the torch shows ongoing signs of heating and melting. The scene retains its industrial ambiance, with "
|
||||||
|
... "the welding sparks and smoke dominating the visual field, underscoring the ongoing nature of the welding operation."
|
||||||
|
... )
|
||||||
|
>>> image = load_image(
|
||||||
|
... "https://media.githubusercontent.com/media/nvidia-cosmos/cosmos-predict2.5/refs/heads/main/assets/base/robot_welding.jpg"
|
||||||
|
... )
|
||||||
|
>>> video = pipe(
|
||||||
|
... image=image,
|
||||||
|
... video=None,
|
||||||
|
... prompt=prompt,
|
||||||
|
... negative_prompt=negative_prompt,
|
||||||
|
... num_frames=93,
|
||||||
|
... generator=torch.Generator().manual_seed(1),
|
||||||
|
... ).frames[0]
|
||||||
|
>>> # export_to_video(video, "image2world.mp4", fps=16)
|
||||||
|
|
||||||
|
>>> # Video2World: condition on an input clip and predict a 93-frame world video.
|
||||||
|
>>> prompt = (
|
||||||
|
... "The video opens with an aerial view of a large-scale sand mining construction operation, showcasing extensive piles "
|
||||||
|
... "of brown sand meticulously arranged in parallel rows. A central water channel, fed by a water pipe, flows through the "
|
||||||
|
... "middle of these sand heaps, creating ripples and movement as it cascades down. The surrounding area features dense green "
|
||||||
|
... "vegetation on the left, contrasting with the sandy terrain, while a body of water is visible in the background on the right. "
|
||||||
|
... "As the video progresses, a piece of heavy machinery, likely a bulldozer, enters the frame from the right, moving slowly along "
|
||||||
|
... "the edge of the sand piles. This machinery's presence indicates ongoing construction work in the operation. The final frame "
|
||||||
|
... "captures the same scene, with the water continuing its flow and the bulldozer still in motion, maintaining the dynamic yet "
|
||||||
|
... "steady pace of the construction activity."
|
||||||
|
... )
|
||||||
|
>>> input_video = load_video(
|
||||||
|
... "https://github.com/nvidia-cosmos/cosmos-predict2.5/raw/refs/heads/main/assets/base/sand_mining.mp4"
|
||||||
|
... )
|
||||||
|
>>> video = pipe(
|
||||||
|
... image=None,
|
||||||
|
... video=input_video,
|
||||||
|
... prompt=prompt,
|
||||||
|
... negative_prompt=negative_prompt,
|
||||||
|
... num_frames=93,
|
||||||
|
... generator=torch.Generator().manual_seed(1),
|
||||||
|
... ).frames[0]
|
||||||
|
>>> export_to_video(video, "video2world.mp4", fps=16)
|
||||||
|
|
||||||
|
>>> # To produce an image instead of a world (video) clip, set num_frames=1 and
|
||||||
|
>>> # save the first frame: pipe(..., num_frames=1).frames[0][0].
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class Cosmos2_5_PredictBasePipeline(DiffusionPipeline):
|
||||||
|
r"""
|
||||||
|
Pipeline for [Cosmos Predict2.5](https://github.com/nvidia-cosmos/cosmos-predict2.5) base model.
|
||||||
|
|
||||||
|
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||||
|
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
|
||||||
|
Frozen text-encoder. Cosmos Predict2.5 uses the [Qwen2.5
|
||||||
|
VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) encoder.
|
||||||
|
tokenizer (`AutoTokenizer`):
|
||||||
|
Tokenizer associated with the Qwen2.5 VL encoder.
|
||||||
|
transformer ([`CosmosTransformer3DModel`]):
|
||||||
|
Conditional Transformer to denoise the encoded image latents.
|
||||||
|
scheduler ([`UniPCMultistepScheduler`]):
|
||||||
|
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||||
|
vae ([`AutoencoderKLWan`]):
|
||||||
|
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||||
|
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||||
|
# We mark safety_checker as optional here to get around some test failures, but it is not really optional
|
||||||
|
_optional_components = ["safety_checker"]
|
||||||
|
_exclude_from_cpu_offload = ["safety_checker"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
text_encoder: Qwen2_5_VLForConditionalGeneration,
|
||||||
|
tokenizer: AutoTokenizer,
|
||||||
|
transformer: CosmosTransformer3DModel,
|
||||||
|
vae: AutoencoderKLWan,
|
||||||
|
scheduler: UniPCMultistepScheduler,
|
||||||
|
safety_checker: CosmosSafetyChecker = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if safety_checker is None:
|
||||||
|
safety_checker = CosmosSafetyChecker()
|
||||||
|
|
||||||
|
self.register_modules(
|
||||||
|
vae=vae,
|
||||||
|
text_encoder=text_encoder,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
transformer=transformer,
|
||||||
|
scheduler=scheduler,
|
||||||
|
safety_checker=safety_checker,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
|
||||||
|
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
|
||||||
|
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||||
|
|
||||||
|
latents_mean = (
|
||||||
|
torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).float()
|
||||||
|
if getattr(self.vae.config, "latents_mean", None) is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
latents_std = (
|
||||||
|
torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).float()
|
||||||
|
if getattr(self.vae.config, "latents_std", None) is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
self.latents_mean = latents_mean
|
||||||
|
self.latents_std = latents_std
|
||||||
|
|
||||||
|
if self.latents_mean is None or self.latents_std is None:
|
||||||
|
raise ValueError("VAE configuration must define both `latents_mean` and `latents_std`.")
|
||||||
|
|
||||||
|
def _get_prompt_embeds(
|
||||||
|
self,
|
||||||
|
prompt: Union[str, List[str]] = None,
|
||||||
|
max_sequence_length: int = 512,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
):
|
||||||
|
device = device or self._execution_device
|
||||||
|
dtype = dtype or self.text_encoder.dtype
|
||||||
|
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||||
|
|
||||||
|
input_ids_batch = []
|
||||||
|
|
||||||
|
for sample_idx in range(len(prompt)):
|
||||||
|
conversations = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "You are a helpful assistant who will provide prompts to an image generator.",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": prompt[sample_idx],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
input_ids = self.tokenizer.apply_chat_template(
|
||||||
|
conversations,
|
||||||
|
tokenize=True,
|
||||||
|
add_generation_prompt=False,
|
||||||
|
add_vision_id=False,
|
||||||
|
max_length=max_sequence_length,
|
||||||
|
truncation=True,
|
||||||
|
padding="max_length",
|
||||||
|
)
|
||||||
|
input_ids = torch.LongTensor(input_ids)
|
||||||
|
input_ids_batch.append(input_ids)
|
||||||
|
|
||||||
|
input_ids_batch = torch.stack(input_ids_batch, dim=0)
|
||||||
|
|
||||||
|
outputs = self.text_encoder(
|
||||||
|
input_ids_batch.to(device),
|
||||||
|
output_hidden_states=True,
|
||||||
|
)
|
||||||
|
hidden_states = outputs.hidden_states
|
||||||
|
|
||||||
|
normalized_hidden_states = []
|
||||||
|
for layer_idx in range(1, len(hidden_states)):
|
||||||
|
normalized_state = (hidden_states[layer_idx] - hidden_states[layer_idx].mean(dim=-1, keepdim=True)) / (
|
||||||
|
hidden_states[layer_idx].std(dim=-1, keepdim=True) + 1e-8
|
||||||
|
)
|
||||||
|
normalized_hidden_states.append(normalized_state)
|
||||||
|
|
||||||
|
prompt_embeds = torch.cat(normalized_hidden_states, dim=-1)
|
||||||
|
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||||
|
|
||||||
|
return prompt_embeds
|
||||||
|
|
||||||
|
# Modified from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt
|
||||||
|
def encode_prompt(
|
||||||
|
self,
|
||||||
|
prompt: Union[str, List[str]],
|
||||||
|
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||||
|
do_classifier_free_guidance: bool = True,
|
||||||
|
num_videos_per_prompt: int = 1,
|
||||||
|
prompt_embeds: Optional[torch.Tensor] = None,
|
||||||
|
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||||
|
max_sequence_length: int = 512,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Encodes the prompt into text encoder hidden states.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (`str` or `List[str]`, *optional*):
|
||||||
|
prompt to be encoded
|
||||||
|
negative_prompt (`str` or `List[str]`, *optional*):
|
||||||
|
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||||
|
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||||
|
less than `1`).
|
||||||
|
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to use classifier free guidance or not.
|
||||||
|
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||||
|
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||||
|
prompt_embeds (`torch.Tensor`, *optional*):
|
||||||
|
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||||
|
provided, text embeddings will be generated from `prompt` input argument.
|
||||||
|
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||||
|
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||||
|
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||||
|
argument.
|
||||||
|
device: (`torch.device`, *optional*):
|
||||||
|
torch device
|
||||||
|
dtype: (`torch.dtype`, *optional*):
|
||||||
|
torch dtype
|
||||||
|
"""
|
||||||
|
device = device or self._execution_device
|
||||||
|
|
||||||
|
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||||
|
if prompt is not None:
|
||||||
|
batch_size = len(prompt)
|
||||||
|
else:
|
||||||
|
batch_size = prompt_embeds.shape[0]
|
||||||
|
|
||||||
|
if prompt_embeds is None:
|
||||||
|
prompt_embeds = self._get_prompt_embeds(
|
||||||
|
prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||||
|
_, seq_len, _ = prompt_embeds.shape
|
||||||
|
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||||
|
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||||
|
|
||||||
|
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||||
|
negative_prompt = negative_prompt or ""
|
||||||
|
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||||
|
|
||||||
|
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||||
|
raise TypeError(
|
||||||
|
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||||
|
f" {type(prompt)}."
|
||||||
|
)
|
||||||
|
elif batch_size != len(negative_prompt):
|
||||||
|
raise ValueError(
|
||||||
|
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||||
|
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||||
|
" the batch size of `prompt`."
|
||||||
|
)
|
||||||
|
|
||||||
|
negative_prompt_embeds = self._get_prompt_embeds(
|
||||||
|
prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||||
|
_, seq_len, _ = negative_prompt_embeds.shape
|
||||||
|
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||||
|
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||||
|
|
||||||
|
return prompt_embeds, negative_prompt_embeds
|
||||||
|
|
||||||
|
# Modified from diffusers.pipelines.cosmos.pipeline_cosmos2_video2world.Cosmos2VideoToWorldPipeline.prepare_latents and
|
||||||
|
# diffusers.pipelines.cosmos.pipeline_cosmos2_video2world.Cosmos2TextToImagePipeline.prepare_latents
|
||||||
|
def prepare_latents(
|
||||||
|
self,
|
||||||
|
video: Optional[torch.Tensor],
|
||||||
|
batch_size: int,
|
||||||
|
num_channels_latents: int = 16,
|
||||||
|
height: int = 704,
|
||||||
|
width: int = 1280,
|
||||||
|
num_frames_in: int = 93,
|
||||||
|
num_frames_out: int = 93,
|
||||||
|
do_classifier_free_guidance: bool = True,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||||
|
latents: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if isinstance(generator, list) and len(generator) != batch_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||||
|
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||||
|
)
|
||||||
|
|
||||||
|
B = batch_size
|
||||||
|
C = num_channels_latents
|
||||||
|
T = (num_frames_out - 1) // self.vae_scale_factor_temporal + 1
|
||||||
|
H = height // self.vae_scale_factor_spatial
|
||||||
|
W = width // self.vae_scale_factor_spatial
|
||||||
|
shape = (B, C, T, H, W)
|
||||||
|
|
||||||
|
if num_frames_in == 0:
|
||||||
|
if latents is None:
|
||||||
|
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
cond_mask = torch.zeros((B, 1, T, H, W), dtype=latents.dtype, device=latents.device)
|
||||||
|
cond_indicator = torch.zeros((B, 1, T, 1, 1), dtype=latents.dtype, device=latents.device)
|
||||||
|
|
||||||
|
cond_latents = torch.zeros_like(latents)
|
||||||
|
|
||||||
|
return (
|
||||||
|
latents,
|
||||||
|
cond_latents,
|
||||||
|
cond_mask,
|
||||||
|
cond_indicator,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if video is None:
|
||||||
|
raise ValueError("`video` must be provided when `num_frames_in` is greater than 0.")
|
||||||
|
needs_preprocessing = not (isinstance(video, torch.Tensor) and video.ndim == 5 and video.shape[1] == 3)
|
||||||
|
if needs_preprocessing:
|
||||||
|
video = self.video_processor.preprocess_video(video, height, width)
|
||||||
|
video = video.to(device=device, dtype=self.vae.dtype)
|
||||||
|
if isinstance(generator, list):
|
||||||
|
cond_latents = [
|
||||||
|
retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator=generator[i])
|
||||||
|
for i in range(batch_size)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
cond_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
|
||||||
|
|
||||||
|
cond_latents = torch.cat(cond_latents, dim=0).to(dtype)
|
||||||
|
|
||||||
|
latents_mean = self.latents_mean.to(device=device, dtype=dtype)
|
||||||
|
latents_std = self.latents_std.to(device=device, dtype=dtype)
|
||||||
|
cond_latents = (cond_latents - latents_mean) / latents_std
|
||||||
|
|
||||||
|
if latents is None:
|
||||||
|
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||||
|
else:
|
||||||
|
latents = latents.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
padding_shape = (B, 1, T, H, W)
|
||||||
|
ones_padding = latents.new_ones(padding_shape)
|
||||||
|
zeros_padding = latents.new_zeros(padding_shape)
|
||||||
|
|
||||||
|
num_cond_latent_frames = (num_frames_in - 1) // self.vae_scale_factor_temporal + 1
|
||||||
|
cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
|
||||||
|
cond_indicator[:, :, 0:num_cond_latent_frames] = 1.0
|
||||||
|
cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
|
||||||
|
|
||||||
|
return (
|
||||||
|
latents,
|
||||||
|
cond_latents,
|
||||||
|
cond_mask,
|
||||||
|
cond_indicator,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs
|
||||||
|
def check_inputs(
|
||||||
|
self,
|
||||||
|
prompt,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
prompt_embeds=None,
|
||||||
|
callback_on_step_end_tensor_inputs=None,
|
||||||
|
):
|
||||||
|
if height % 16 != 0 or width % 16 != 0:
|
||||||
|
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||||
|
|
||||||
|
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||||
|
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if prompt is not None and prompt_embeds is not None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||||
|
" only forward one of the two."
|
||||||
|
)
|
||||||
|
elif prompt is None and prompt_embeds is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||||
|
)
|
||||||
|
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||||
|
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def guidance_scale(self):
|
||||||
|
return self._guidance_scale
|
||||||
|
|
||||||
|
@property
|
||||||
|
def do_classifier_free_guidance(self):
|
||||||
|
return self._guidance_scale > 1.0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_timesteps(self):
|
||||||
|
return self._num_timesteps
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_timestep(self):
|
||||||
|
return self._current_timestep
|
||||||
|
|
||||||
|
@property
|
||||||
|
def interrupt(self):
|
||||||
|
return self._interrupt
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
image: PipelineImageInput | None = None,
|
||||||
|
video: List[PipelineImageInput] | None = None,
|
||||||
|
prompt: Union[str, List[str]] | None = None,
|
||||||
|
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||||
|
height: int = 704,
|
||||||
|
width: int = 1280,
|
||||||
|
num_frames: int = 93,
|
||||||
|
num_inference_steps: int = 36,
|
||||||
|
guidance_scale: float = 7.0,
|
||||||
|
num_videos_per_prompt: Optional[int] = 1,
|
||||||
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||||
|
latents: Optional[torch.Tensor] = None,
|
||||||
|
prompt_embeds: Optional[torch.Tensor] = None,
|
||||||
|
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||||
|
output_type: Optional[str] = "pil",
|
||||||
|
return_dict: bool = True,
|
||||||
|
callback_on_step_end: Optional[
|
||||||
|
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||||
|
] = None,
|
||||||
|
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||||
|
max_sequence_length: int = 512,
|
||||||
|
conditional_frame_timestep: float = 0.1,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
The call function to the pipeline for generation. Supports three modes:
|
||||||
|
|
||||||
|
- **Text2World**: `image=None`, `video=None`, `prompt` provided. Generates a world clip.
|
||||||
|
- **Image2World**: `image` provided, `video=None`, `prompt` provided. Conditions on a single frame.
|
||||||
|
- **Video2World**: `video` provided, `image=None`, `prompt` provided. Conditions on an input clip.
|
||||||
|
|
||||||
|
Set `num_frames=93` (default) to produce a world video, or `num_frames=1` to produce a single image frame (the
|
||||||
|
above in "*2Image mode").
|
||||||
|
|
||||||
|
Outputs follow `output_type` (e.g., `"pil"` returns a list of `num_frames` PIL images per prompt).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, *optional*):
|
||||||
|
Optional single image for Image2World conditioning. Must be `None` when `video` is provided.
|
||||||
|
video (`List[PIL.Image.Image]`, `np.ndarray`, `torch.Tensor`, *optional*):
|
||||||
|
Optional input video for Video2World conditioning. Must be `None` when `image` is provided.
|
||||||
|
prompt (`str` or `List[str]`, *optional*):
|
||||||
|
The prompt or prompts to guide generation. Required unless `prompt_embeds` is supplied.
|
||||||
|
height (`int`, defaults to `704`):
|
||||||
|
The height in pixels of the generated image.
|
||||||
|
width (`int`, defaults to `1280`):
|
||||||
|
The width in pixels of the generated image.
|
||||||
|
num_frames (`int`, defaults to `93`):
|
||||||
|
Number of output frames. Use `93` for world (video) generation; set to `1` to return a single frame.
|
||||||
|
num_inference_steps (`int`, defaults to `35`):
|
||||||
|
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||||
|
expense of slower inference.
|
||||||
|
guidance_scale (`float`, defaults to `7.0`):
|
||||||
|
Guidance scale as defined in [Classifier-Free Diffusion
|
||||||
|
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||||
|
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||||
|
`guidance_scale > 1`.
|
||||||
|
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||||
|
The number of images to generate per prompt.
|
||||||
|
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||||
|
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||||
|
generation deterministic.
|
||||||
|
latents (`torch.Tensor`, *optional*):
|
||||||
|
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||||
|
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||||
|
tensor is generated by sampling using the supplied random `generator`.
|
||||||
|
prompt_embeds (`torch.Tensor`, *optional*):
|
||||||
|
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||||
|
provided, text embeddings will be generated from `prompt` input argument.
|
||||||
|
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||||
|
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
|
||||||
|
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||||
|
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||||
|
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||||
|
return_dict (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple.
|
||||||
|
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
||||||
|
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
||||||
|
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
||||||
|
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
||||||
|
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
||||||
|
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||||
|
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||||
|
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||||
|
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||||
|
max_sequence_length (`int`, defaults to `512`):
|
||||||
|
The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If
|
||||||
|
the prompt is shorter than this length, it will be padded.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`~CosmosPipelineOutput`] or `tuple`:
|
||||||
|
If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where
|
||||||
|
the first element is a list with the generated images and the second element is a list of `bool`s
|
||||||
|
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
|
||||||
|
"""
|
||||||
|
if self.safety_checker is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"You have disabled the safety checker for {self.__class__}. This is in violation of the "
|
||||||
|
"[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). "
|
||||||
|
f"Please ensure that you are compliant with the license agreement."
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||||
|
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||||
|
|
||||||
|
# Check inputs. Raise error if not correct
|
||||||
|
self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)
|
||||||
|
|
||||||
|
self._guidance_scale = guidance_scale
|
||||||
|
self._current_timestep = None
|
||||||
|
self._interrupt = False
|
||||||
|
|
||||||
|
device = self._execution_device
|
||||||
|
|
||||||
|
if self.safety_checker is not None:
|
||||||
|
self.safety_checker.to(device)
|
||||||
|
if prompt is not None:
|
||||||
|
prompt_list = [prompt] if isinstance(prompt, str) else prompt
|
||||||
|
for p in prompt_list:
|
||||||
|
if not self.safety_checker.check_text_safety(p):
|
||||||
|
raise ValueError(
|
||||||
|
f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the "
|
||||||
|
f"prompt abides by the NVIDIA Open Model License Agreement."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Define call parameters
|
||||||
|
if prompt is not None and isinstance(prompt, str):
|
||||||
|
batch_size = 1
|
||||||
|
elif prompt is not None and isinstance(prompt, list):
|
||||||
|
batch_size = len(prompt)
|
||||||
|
else:
|
||||||
|
batch_size = prompt_embeds.shape[0]
|
||||||
|
|
||||||
|
# Encode input prompt
|
||||||
|
(
|
||||||
|
prompt_embeds,
|
||||||
|
negative_prompt_embeds,
|
||||||
|
) = self.encode_prompt(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||||
|
num_videos_per_prompt=num_videos_per_prompt,
|
||||||
|
prompt_embeds=prompt_embeds,
|
||||||
|
negative_prompt_embeds=negative_prompt_embeds,
|
||||||
|
device=device,
|
||||||
|
max_sequence_length=max_sequence_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
vae_dtype = self.vae.dtype
|
||||||
|
transformer_dtype = self.transformer.dtype
|
||||||
|
|
||||||
|
num_frames_in = None
|
||||||
|
if image is not None:
|
||||||
|
if batch_size != 1:
|
||||||
|
raise ValueError(f"batch_size must be 1 for image input (given {batch_size})")
|
||||||
|
|
||||||
|
image = torchvision.transforms.functional.to_tensor(image).unsqueeze(0)
|
||||||
|
video = torch.cat([image, torch.zeros_like(image).repeat(num_frames - 1, 1, 1, 1)], dim=0)
|
||||||
|
video = video.unsqueeze(0)
|
||||||
|
num_frames_in = 1
|
||||||
|
elif video is None:
|
||||||
|
video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=torch.uint8)
|
||||||
|
num_frames_in = 0
|
||||||
|
else:
|
||||||
|
num_frames_in = len(video)
|
||||||
|
|
||||||
|
if batch_size != 1:
|
||||||
|
raise ValueError(f"batch_size must be 1 for video input (given {batch_size})")
|
||||||
|
|
||||||
|
assert video is not None
|
||||||
|
video = self.video_processor.preprocess_video(video, height, width)
|
||||||
|
|
||||||
|
# pad with last frame (for video2world)
|
||||||
|
num_frames_out = num_frames
|
||||||
|
if video.shape[2] < num_frames_out:
|
||||||
|
n_pad_frames = num_frames_out - num_frames_in
|
||||||
|
last_frame = video[0, :, -1:, :, :] # [C, T==1, H, W]
|
||||||
|
pad_frames = last_frame.repeat(1, 1, n_pad_frames, 1, 1) # [B, C, T, H, W]
|
||||||
|
video = torch.cat((video, pad_frames), dim=2)
|
||||||
|
|
||||||
|
assert num_frames_in <= num_frames_out, f"expected ({num_frames_in=}) <= ({num_frames_out=})"
|
||||||
|
|
||||||
|
video = video.to(device=device, dtype=vae_dtype)
|
||||||
|
|
||||||
|
num_channels_latents = self.transformer.config.in_channels - 1
|
||||||
|
latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents(
|
||||||
|
video=video,
|
||||||
|
batch_size=batch_size * num_videos_per_prompt,
|
||||||
|
num_channels_latents=num_channels_latents,
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
num_frames_in=num_frames_in,
|
||||||
|
num_frames_out=num_frames,
|
||||||
|
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=device,
|
||||||
|
generator=generator,
|
||||||
|
latents=latents,
|
||||||
|
)
|
||||||
|
cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep
|
||||||
|
cond_mask = cond_mask.to(transformer_dtype)
|
||||||
|
|
||||||
|
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
|
||||||
|
|
||||||
|
# Denoising loop
|
||||||
|
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||||
|
timesteps = self.scheduler.timesteps
|
||||||
|
self._num_timesteps = len(timesteps)
|
||||||
|
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||||
|
|
||||||
|
gt_velocity = (latents - cond_latent) * cond_mask
|
||||||
|
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||||
|
for i, t in enumerate(timesteps):
|
||||||
|
if self.interrupt:
|
||||||
|
continue
|
||||||
|
|
||||||
|
self._current_timestep = t.cpu().item()
|
||||||
|
|
||||||
|
# NOTE: assumes sigma(t) \in [0, 1]
|
||||||
|
sigma_t = (
|
||||||
|
torch.tensor(self.scheduler.sigmas[i].item())
|
||||||
|
.unsqueeze(0)
|
||||||
|
.to(device=device, dtype=transformer_dtype)
|
||||||
|
)
|
||||||
|
|
||||||
|
in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents
|
||||||
|
in_latents = in_latents.to(transformer_dtype)
|
||||||
|
in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t
|
||||||
|
noise_pred = self.transformer(
|
||||||
|
hidden_states=in_latents,
|
||||||
|
condition_mask=cond_mask,
|
||||||
|
timestep=in_timestep,
|
||||||
|
encoder_hidden_states=prompt_embeds,
|
||||||
|
padding_mask=padding_mask,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
# NOTE: replace velocity (noise_pred) with gt_velocity for conditioning inputs only
|
||||||
|
noise_pred = gt_velocity + noise_pred * (1 - cond_mask)
|
||||||
|
|
||||||
|
if self.do_classifier_free_guidance:
|
||||||
|
noise_pred_neg = self.transformer(
|
||||||
|
hidden_states=in_latents,
|
||||||
|
condition_mask=cond_mask,
|
||||||
|
timestep=in_timestep,
|
||||||
|
encoder_hidden_states=negative_prompt_embeds,
|
||||||
|
padding_mask=padding_mask,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
# NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only
|
||||||
|
noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask)
|
||||||
|
noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg)
|
||||||
|
|
||||||
|
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||||
|
|
||||||
|
if callback_on_step_end is not None:
|
||||||
|
callback_kwargs = {}
|
||||||
|
for k in callback_on_step_end_tensor_inputs:
|
||||||
|
callback_kwargs[k] = locals()[k]
|
||||||
|
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||||
|
|
||||||
|
latents = callback_outputs.pop("latents", latents)
|
||||||
|
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||||
|
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||||
|
|
||||||
|
# call the callback, if provided
|
||||||
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||||
|
progress_bar.update()
|
||||||
|
|
||||||
|
if XLA_AVAILABLE:
|
||||||
|
xm.mark_step()
|
||||||
|
|
||||||
|
self._current_timestep = None
|
||||||
|
|
||||||
|
if not output_type == "latent":
|
||||||
|
latents_mean = self.latents_mean.to(latents.device, latents.dtype)
|
||||||
|
latents_std = self.latents_std.to(latents.device, latents.dtype)
|
||||||
|
latents = latents * latents_std + latents_mean
|
||||||
|
video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
|
||||||
|
video = self._match_num_frames(video, num_frames)
|
||||||
|
|
||||||
|
assert self.safety_checker is not None
|
||||||
|
self.safety_checker.to(device)
|
||||||
|
video = self.video_processor.postprocess_video(video, output_type="np")
|
||||||
|
video = (video * 255).astype(np.uint8)
|
||||||
|
video_batch = []
|
||||||
|
for vid in video:
|
||||||
|
vid = self.safety_checker.check_video_safety(vid)
|
||||||
|
video_batch.append(vid)
|
||||||
|
video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1
|
||||||
|
video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
|
||||||
|
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||||
|
else:
|
||||||
|
video = latents
|
||||||
|
|
||||||
|
# Offload all models
|
||||||
|
self.maybe_free_model_hooks()
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (video,)
|
||||||
|
|
||||||
|
return CosmosPipelineOutput(frames=video)
|
||||||
|
|
||||||
|
def _match_num_frames(self, video: torch.Tensor, target_num_frames: int) -> torch.Tensor:
|
||||||
|
if target_num_frames <= 0 or video.shape[2] == target_num_frames:
|
||||||
|
return video
|
||||||
|
|
||||||
|
frames_per_latent = max(self.vae_scale_factor_temporal, 1)
|
||||||
|
video = torch.repeat_interleave(video, repeats=frames_per_latent, dim=2)
|
||||||
|
|
||||||
|
current_frames = video.shape[2]
|
||||||
|
if current_frames < target_num_frames:
|
||||||
|
pad = video[:, :, -1:, :, :].repeat(1, 1, target_num_frames - current_frames, 1, 1)
|
||||||
|
video = torch.cat([video, pad], dim=2)
|
||||||
|
elif current_frames > target_num_frames:
|
||||||
|
video = video[:, :, :target_num_frames]
|
||||||
|
|
||||||
|
return video
|
||||||
@@ -31,6 +31,7 @@ else:
|
|||||||
_import_structure["pipeline_qwenimage_edit_plus"] = ["QwenImageEditPlusPipeline"]
|
_import_structure["pipeline_qwenimage_edit_plus"] = ["QwenImageEditPlusPipeline"]
|
||||||
_import_structure["pipeline_qwenimage_img2img"] = ["QwenImageImg2ImgPipeline"]
|
_import_structure["pipeline_qwenimage_img2img"] = ["QwenImageImg2ImgPipeline"]
|
||||||
_import_structure["pipeline_qwenimage_inpaint"] = ["QwenImageInpaintPipeline"]
|
_import_structure["pipeline_qwenimage_inpaint"] = ["QwenImageInpaintPipeline"]
|
||||||
|
_import_structure["pipeline_qwenimage_layered"] = ["QwenImageLayeredPipeline"]
|
||||||
|
|
||||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||||
try:
|
try:
|
||||||
@@ -47,6 +48,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
from .pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
|
from .pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
|
||||||
from .pipeline_qwenimage_img2img import QwenImageImg2ImgPipeline
|
from .pipeline_qwenimage_img2img import QwenImageImg2ImgPipeline
|
||||||
from .pipeline_qwenimage_inpaint import QwenImageInpaintPipeline
|
from .pipeline_qwenimage_inpaint import QwenImageInpaintPipeline
|
||||||
|
from .pipeline_qwenimage_layered import QwenImageLayeredPipeline
|
||||||
else:
|
else:
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|||||||
905
src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py
Normal file
905
src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py
Normal file
@@ -0,0 +1,905 @@
|
|||||||
|
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import math
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
|
||||||
|
|
||||||
|
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||||
|
from ...loaders import QwenImageLoraLoaderMixin
|
||||||
|
from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
|
||||||
|
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||||
|
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||||
|
from ...utils.torch_utils import randn_tensor
|
||||||
|
from ..pipeline_utils import DiffusionPipeline
|
||||||
|
from .pipeline_output import QwenImagePipelineOutput
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_xla_available():
|
||||||
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
|
XLA_AVAILABLE = True
|
||||||
|
else:
|
||||||
|
XLA_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
EXAMPLE_DOC_STRING = """
|
||||||
|
Examples:
|
||||||
|
```py
|
||||||
|
>>> import torch
|
||||||
|
>>> from PIL import Image
|
||||||
|
>>> from diffusers import QwenImageLayeredPipeline
|
||||||
|
>>> from diffusers.utils import load_image
|
||||||
|
|
||||||
|
>>> pipe = QwenImageLayeredPipeline.from_pretrained("Qwen/Qwen-Image-Layered", torch_dtype=torch.bfloat16)
|
||||||
|
>>> pipe.to("cuda")
|
||||||
|
>>> image = load_image(
|
||||||
|
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
|
||||||
|
... ).convert("RGBA")
|
||||||
|
>>> prompt = ""
|
||||||
|
>>> # Depending on the variant being used, the pipeline call will slightly vary.
|
||||||
|
>>> # Refer to the pipeline documentation for more details.
|
||||||
|
>>> images = pipe(
|
||||||
|
... image,
|
||||||
|
... prompt,
|
||||||
|
... num_inference_steps=50,
|
||||||
|
... true_cfg_scale=4.0,
|
||||||
|
... layers=4,
|
||||||
|
... resolution=640,
|
||||||
|
... cfg_normalize=False,
|
||||||
|
... use_en_prompt=True,
|
||||||
|
... ).images[0]
|
||||||
|
>>> for i, image in enumerate(images):
|
||||||
|
... image.save(f"{i}.out.png")
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
|
||||||
|
def calculate_shift(
|
||||||
|
image_seq_len,
|
||||||
|
base_seq_len: int = 256,
|
||||||
|
max_seq_len: int = 4096,
|
||||||
|
base_shift: float = 0.5,
|
||||||
|
max_shift: float = 1.15,
|
||||||
|
):
|
||||||
|
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||||
|
b = base_shift - m * base_seq_len
|
||||||
|
mu = image_seq_len * m + b
|
||||||
|
return mu
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||||
|
def retrieve_timesteps(
|
||||||
|
scheduler,
|
||||||
|
num_inference_steps: Optional[int] = None,
|
||||||
|
device: Optional[Union[str, torch.device]] = None,
|
||||||
|
timesteps: Optional[List[int]] = None,
|
||||||
|
sigmas: Optional[List[float]] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||||
|
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scheduler (`SchedulerMixin`):
|
||||||
|
The scheduler to get timesteps from.
|
||||||
|
num_inference_steps (`int`):
|
||||||
|
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||||
|
must be `None`.
|
||||||
|
device (`str` or `torch.device`, *optional*):
|
||||||
|
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||||
|
timesteps (`List[int]`, *optional*):
|
||||||
|
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||||
|
`num_inference_steps` and `sigmas` must be `None`.
|
||||||
|
sigmas (`List[float]`, *optional*):
|
||||||
|
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||||
|
`num_inference_steps` and `timesteps` must be `None`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||||
|
second element is the number of inference steps.
|
||||||
|
"""
|
||||||
|
if timesteps is not None and sigmas is not None:
|
||||||
|
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||||
|
if timesteps is not None:
|
||||||
|
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||||
|
if not accepts_timesteps:
|
||||||
|
raise ValueError(
|
||||||
|
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||||
|
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||||
|
)
|
||||||
|
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||||
|
timesteps = scheduler.timesteps
|
||||||
|
num_inference_steps = len(timesteps)
|
||||||
|
elif sigmas is not None:
|
||||||
|
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||||
|
if not accept_sigmas:
|
||||||
|
raise ValueError(
|
||||||
|
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||||
|
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||||
|
)
|
||||||
|
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||||
|
timesteps = scheduler.timesteps
|
||||||
|
num_inference_steps = len(timesteps)
|
||||||
|
else:
|
||||||
|
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||||
|
timesteps = scheduler.timesteps
|
||||||
|
return timesteps, num_inference_steps
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||||
|
def retrieve_latents(
|
||||||
|
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||||
|
):
|
||||||
|
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||||
|
return encoder_output.latent_dist.sample(generator)
|
||||||
|
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||||
|
return encoder_output.latent_dist.mode()
|
||||||
|
elif hasattr(encoder_output, "latents"):
|
||||||
|
return encoder_output.latents
|
||||||
|
else:
|
||||||
|
raise AttributeError("Could not access latents of provided encoder_output")
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit_plus.calculate_dimensions
|
||||||
|
def calculate_dimensions(target_area, ratio):
|
||||||
|
width = math.sqrt(target_area * ratio)
|
||||||
|
height = width / ratio
|
||||||
|
|
||||||
|
width = round(width / 32) * 32
|
||||||
|
height = round(height / 32) * 32
|
||||||
|
|
||||||
|
return width, height
|
||||||
|
|
||||||
|
|
||||||
|
class QwenImageLayeredPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||||
|
r"""
|
||||||
|
The Qwen-Image-Layered pipeline for image decomposing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
transformer ([`QwenImageTransformer2DModel`]):
|
||||||
|
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||||
|
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||||
|
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||||
|
vae ([`AutoencoderKL`]):
|
||||||
|
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||||
|
text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
|
||||||
|
[Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
|
||||||
|
[Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
|
||||||
|
tokenizer (`QwenTokenizer`):
|
||||||
|
Tokenizer of class
|
||||||
|
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||||
|
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||||
|
vae: AutoencoderKLQwenImage,
|
||||||
|
text_encoder: Qwen2_5_VLForConditionalGeneration,
|
||||||
|
tokenizer: Qwen2Tokenizer,
|
||||||
|
processor: Qwen2VLProcessor,
|
||||||
|
transformer: QwenImageTransformer2DModel,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.register_modules(
|
||||||
|
vae=vae,
|
||||||
|
text_encoder=text_encoder,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
processor=processor,
|
||||||
|
transformer=transformer,
|
||||||
|
scheduler=scheduler,
|
||||||
|
)
|
||||||
|
self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
|
||||||
|
self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16
|
||||||
|
# QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
||||||
|
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
||||||
|
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||||
|
self.vl_processor = processor
|
||||||
|
self.tokenizer_max_length = 1024
|
||||||
|
|
||||||
|
self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
|
self.prompt_template_encode_start_idx = 34
|
||||||
|
self.image_caption_prompt_cn = """<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n# 图像标注器\n你是一个专业的图像标注器。请基于输入图像,撰写图注:\n1.
|
||||||
|
使用自然、描述性的语言撰写图注,不要使用结构化形式或富文本形式。\n2. 通过加入以下内容,丰富图注细节:\n - 对象的属性:如数量、颜色、形状、大小、位置、材质、状态、动作等\n -
|
||||||
|
对象间的视觉关系:如空间关系、功能关系、动作关系、从属关系、比较关系、因果关系等\n - 环境细节:例如天气、光照、颜色、纹理、气氛等\n - 文字内容:识别图像中清晰可见的文字,不做翻译和解释,用引号在图注中强调\n3.
|
||||||
|
保持真实性与准确性:\n - 不要使用笼统的描述\n -
|
||||||
|
描述图像中所有可见的信息,但不要加入没有在图像中出现的内容\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>assistant\n"""
|
||||||
|
self.image_caption_prompt_en = """<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n# Image Annotator\nYou are a professional
|
||||||
|
image annotator. Please write an image caption based on the input image:\n1. Write the caption using natural,
|
||||||
|
descriptive language without structured formats or rich text.\n2. Enrich caption details by including: \n - Object
|
||||||
|
attributes, such as quantity, color, shape, size, material, state, position, actions, and so on\n - Vision Relations
|
||||||
|
between objects, such as spatial relations, functional relations, possessive relations, attachment relations, action
|
||||||
|
relations, comparative relations, causal relations, and so on\n - Environmental details, such as weather, lighting,
|
||||||
|
colors, textures, atmosphere, and so on\n - Identify the text clearly visible in the image, without translation or
|
||||||
|
explanation, and highlight it in the caption with quotation marks\n3. Maintain authenticity and accuracy:\n - Avoid
|
||||||
|
generalizations\n - Describe all visible information in the image, while do not add information not explicitly shown in
|
||||||
|
the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>assistant\n"""
|
||||||
|
self.default_sample_size = 128
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
|
||||||
|
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
|
||||||
|
bool_mask = mask.bool()
|
||||||
|
valid_lengths = bool_mask.sum(dim=1)
|
||||||
|
selected = hidden_states[bool_mask]
|
||||||
|
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
|
||||||
|
|
||||||
|
return split_result
|
||||||
|
|
||||||
|
def _get_qwen_prompt_embeds(
|
||||||
|
self,
|
||||||
|
prompt: Union[str, List[str]] = None,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
):
|
||||||
|
device = device or self._execution_device
|
||||||
|
dtype = dtype or self.text_encoder.dtype
|
||||||
|
|
||||||
|
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||||
|
|
||||||
|
template = self.prompt_template_encode
|
||||||
|
drop_idx = self.prompt_template_encode_start_idx
|
||||||
|
txt = [template.format(e) for e in prompt]
|
||||||
|
txt_tokens = self.tokenizer(
|
||||||
|
txt,
|
||||||
|
padding=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
).to(device)
|
||||||
|
encoder_hidden_states = self.text_encoder(
|
||||||
|
input_ids=txt_tokens.input_ids,
|
||||||
|
attention_mask=txt_tokens.attention_mask,
|
||||||
|
output_hidden_states=True,
|
||||||
|
)
|
||||||
|
hidden_states = encoder_hidden_states.hidden_states[-1]
|
||||||
|
split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
|
||||||
|
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
||||||
|
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
|
||||||
|
max_seq_len = max([e.size(0) for e in split_hidden_states])
|
||||||
|
prompt_embeds = torch.stack(
|
||||||
|
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
|
||||||
|
)
|
||||||
|
encoder_attention_mask = torch.stack(
|
||||||
|
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||||
|
|
||||||
|
return prompt_embeds, encoder_attention_mask
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
|
||||||
|
def encode_prompt(
|
||||||
|
self,
|
||||||
|
prompt: Union[str, List[str]],
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
num_images_per_prompt: int = 1,
|
||||||
|
prompt_embeds: Optional[torch.Tensor] = None,
|
||||||
|
prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||||
|
max_sequence_length: int = 1024,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (`str` or `List[str]`, *optional*):
|
||||||
|
prompt to be encoded
|
||||||
|
device: (`torch.device`):
|
||||||
|
torch device
|
||||||
|
num_images_per_prompt (`int`):
|
||||||
|
number of images that should be generated per prompt
|
||||||
|
prompt_embeds (`torch.Tensor`, *optional*):
|
||||||
|
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||||
|
provided, text embeddings will be generated from `prompt` input argument.
|
||||||
|
"""
|
||||||
|
device = device or self._execution_device
|
||||||
|
|
||||||
|
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||||
|
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
|
||||||
|
|
||||||
|
if prompt_embeds is None:
|
||||||
|
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
|
||||||
|
|
||||||
|
prompt_embeds = prompt_embeds[:, :max_sequence_length]
|
||||||
|
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
|
||||||
|
|
||||||
|
_, seq_len, _ = prompt_embeds.shape
|
||||||
|
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||||
|
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||||
|
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||||
|
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||||
|
|
||||||
|
return prompt_embeds, prompt_embeds_mask
|
||||||
|
|
||||||
|
def get_image_caption(self, prompt_image, use_en_prompt=True, device=None):
|
||||||
|
if use_en_prompt:
|
||||||
|
prompt = self.image_caption_prompt_en
|
||||||
|
else:
|
||||||
|
prompt = self.image_caption_prompt_cn
|
||||||
|
model_inputs = self.vl_processor(
|
||||||
|
text=prompt,
|
||||||
|
images=prompt_image,
|
||||||
|
padding=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
).to(device)
|
||||||
|
generated_ids = self.text_encoder.generate(**model_inputs, max_new_tokens=512)
|
||||||
|
generated_ids_trimmed = [
|
||||||
|
out_ids[len(in_ids) :] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
|
||||||
|
]
|
||||||
|
output_text = self.vl_processor.batch_decode(
|
||||||
|
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||||
|
)[0]
|
||||||
|
return output_text.strip()
|
||||||
|
|
||||||
|
def check_inputs(
|
||||||
|
self,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
negative_prompt=None,
|
||||||
|
prompt_embeds=None,
|
||||||
|
negative_prompt_embeds=None,
|
||||||
|
prompt_embeds_mask=None,
|
||||||
|
negative_prompt_embeds_mask=None,
|
||||||
|
callback_on_step_end_tensor_inputs=None,
|
||||||
|
max_sequence_length=None,
|
||||||
|
):
|
||||||
|
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
|
||||||
|
logger.warning(
|
||||||
|
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
||||||
|
)
|
||||||
|
|
||||||
|
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||||
|
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||||
|
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||||
|
)
|
||||||
|
|
||||||
|
if prompt_embeds is not None and prompt_embeds_mask is None:
|
||||||
|
raise ValueError(
|
||||||
|
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
|
||||||
|
)
|
||||||
|
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
|
||||||
|
raise ValueError(
|
||||||
|
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||||
|
)
|
||||||
|
|
||||||
|
if max_sequence_length is not None and max_sequence_length > 1024:
|
||||||
|
raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _pack_latents(latents, batch_size, num_channels_latents, height, width, layers):
|
||||||
|
latents = latents.view(batch_size, layers, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||||
|
latents = latents.permute(0, 1, 3, 5, 2, 4, 6)
|
||||||
|
latents = latents.reshape(batch_size, layers * (height // 2) * (width // 2), num_channels_latents * 4)
|
||||||
|
|
||||||
|
return latents
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _unpack_latents(latents, height, width, layers, vae_scale_factor):
|
||||||
|
batch_size, num_patches, channels = latents.shape
|
||||||
|
|
||||||
|
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||||
|
# latent height and width to be divisible by 2.
|
||||||
|
height = 2 * (int(height) // (vae_scale_factor * 2))
|
||||||
|
width = 2 * (int(width) // (vae_scale_factor * 2))
|
||||||
|
|
||||||
|
latents = latents.view(batch_size, layers + 1, height // 2, width // 2, channels // 4, 2, 2)
|
||||||
|
latents = latents.permute(0, 1, 4, 2, 5, 3, 6)
|
||||||
|
|
||||||
|
latents = latents.reshape(batch_size, layers + 1, channels // (2 * 2), height, width)
|
||||||
|
latents = latents.permute(0, 2, 1, 3, 4) # (b, c, f, h, w)
|
||||||
|
|
||||||
|
return latents
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._encode_vae_image
|
||||||
|
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
||||||
|
if isinstance(generator, list):
|
||||||
|
image_latents = [
|
||||||
|
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
|
||||||
|
for i in range(image.shape[0])
|
||||||
|
]
|
||||||
|
image_latents = torch.cat(image_latents, dim=0)
|
||||||
|
else:
|
||||||
|
image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
|
||||||
|
latents_mean = (
|
||||||
|
torch.tensor(self.vae.config.latents_mean)
|
||||||
|
.view(1, self.latent_channels, 1, 1, 1)
|
||||||
|
.to(image_latents.device, image_latents.dtype)
|
||||||
|
)
|
||||||
|
latents_std = (
|
||||||
|
torch.tensor(self.vae.config.latents_std)
|
||||||
|
.view(1, self.latent_channels, 1, 1, 1)
|
||||||
|
.to(image_latents.device, image_latents.dtype)
|
||||||
|
)
|
||||||
|
image_latents = (image_latents - latents_mean) / latents_std
|
||||||
|
|
||||||
|
return image_latents
|
||||||
|
|
||||||
|
def prepare_latents(
|
||||||
|
self,
|
||||||
|
image,
|
||||||
|
batch_size,
|
||||||
|
num_channels_latents,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
layers,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
generator,
|
||||||
|
latents=None,
|
||||||
|
):
|
||||||
|
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||||
|
# latent height and width to be divisible by 2.
|
||||||
|
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||||
|
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||||
|
|
||||||
|
shape = (
|
||||||
|
batch_size,
|
||||||
|
layers + 1,
|
||||||
|
num_channels_latents,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
) ### the generated first image is combined image
|
||||||
|
|
||||||
|
image_latents = None
|
||||||
|
if image is not None:
|
||||||
|
image = image.to(device=device, dtype=dtype)
|
||||||
|
if image.shape[1] != self.latent_channels:
|
||||||
|
image_latents = self._encode_vae_image(image=image, generator=generator)
|
||||||
|
else:
|
||||||
|
image_latents = image
|
||||||
|
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
||||||
|
# expand init_latents for batch_size
|
||||||
|
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
||||||
|
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
||||||
|
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
image_latents = torch.cat([image_latents], dim=0)
|
||||||
|
|
||||||
|
image_latent_height, image_latent_width = image_latents.shape[3:]
|
||||||
|
image_latents = image_latents.permute(0, 2, 1, 3, 4) # (b, c, f, h, w) -> (b, f, c, h, w)
|
||||||
|
image_latents = self._pack_latents(
|
||||||
|
image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width, 1
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(generator, list) and len(generator) != batch_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||||
|
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||||
|
)
|
||||||
|
if latents is None:
|
||||||
|
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||||
|
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width, layers + 1)
|
||||||
|
else:
|
||||||
|
latents = latents.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
return latents, image_latents
|
||||||
|
|
||||||
|
@property
|
||||||
|
def guidance_scale(self):
|
||||||
|
return self._guidance_scale
|
||||||
|
|
||||||
|
@property
|
||||||
|
def attention_kwargs(self):
|
||||||
|
return self._attention_kwargs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_timesteps(self):
|
||||||
|
return self._num_timesteps
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_timestep(self):
|
||||||
|
return self._current_timestep
|
||||||
|
|
||||||
|
@property
|
||||||
|
def interrupt(self):
|
||||||
|
return self._interrupt
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
image: Optional[PipelineImageInput] = None,
|
||||||
|
prompt: Union[str, List[str]] = None,
|
||||||
|
negative_prompt: Union[str, List[str]] = None,
|
||||||
|
true_cfg_scale: float = 4.0,
|
||||||
|
layers: Optional[int] = 4,
|
||||||
|
num_inference_steps: int = 50,
|
||||||
|
sigmas: Optional[List[float]] = None,
|
||||||
|
guidance_scale: Optional[float] = None,
|
||||||
|
num_images_per_prompt: int = 1,
|
||||||
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||||
|
latents: Optional[torch.Tensor] = None,
|
||||||
|
prompt_embeds: Optional[torch.Tensor] = None,
|
||||||
|
prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||||
|
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||||
|
negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||||
|
output_type: Optional[str] = "pil",
|
||||||
|
return_dict: bool = True,
|
||||||
|
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||||
|
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||||
|
max_sequence_length: int = 512,
|
||||||
|
resolution: int = 640,
|
||||||
|
cfg_normalize: bool = False,
|
||||||
|
use_en_prompt: bool = False,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Function invoked when calling the pipeline for generation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||||
|
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
||||||
|
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
|
||||||
|
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
|
||||||
|
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
|
||||||
|
latents as `image`, but if passing latents directly it is not encoded again.
|
||||||
|
prompt (`str` or `List[str]`, *optional*):
|
||||||
|
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||||
|
instead.
|
||||||
|
negative_prompt (`str` or `List[str]`, *optional*):
|
||||||
|
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||||
|
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
|
||||||
|
not greater than `1`).
|
||||||
|
true_cfg_scale (`float`, *optional*, defaults to 1.0):
|
||||||
|
true_cfg_scale (`float`, *optional*, defaults to 1.0): Guidance scale as defined in [Classifier-Free
|
||||||
|
Diffusion Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of
|
||||||
|
equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is
|
||||||
|
enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale
|
||||||
|
encourages to generate images that are closely linked to the text `prompt`, usually at the expense of
|
||||||
|
lower image quality.
|
||||||
|
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||||
|
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||||
|
expense of slower inference.
|
||||||
|
sigmas (`List[float]`, *optional*):
|
||||||
|
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||||
|
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||||
|
will be used.
|
||||||
|
guidance_scale (`float`, *optional*, defaults to None):
|
||||||
|
A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
|
||||||
|
where the guidance scale is applied during inference through noise prediction rescaling, guidance
|
||||||
|
distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
|
||||||
|
scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images
|
||||||
|
that are closely linked to the text `prompt`, usually at the expense of lower image quality. This
|
||||||
|
parameter in the pipeline is there to support future guidance-distilled models when they come up. It is
|
||||||
|
ignored when not using guidance distilled models. To enable traditional classifier-free guidance,
|
||||||
|
please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should
|
||||||
|
enable classifier-free guidance computations).
|
||||||
|
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||||
|
The number of images to generate per prompt.
|
||||||
|
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||||
|
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||||
|
to make generation deterministic.
|
||||||
|
latents (`torch.Tensor`, *optional*):
|
||||||
|
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||||
|
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||||
|
tensor will be generated by sampling using the supplied random `generator`.
|
||||||
|
prompt_embeds (`torch.Tensor`, *optional*):
|
||||||
|
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||||
|
provided, text embeddings will be generated from `prompt` input argument.
|
||||||
|
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||||
|
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||||
|
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||||
|
argument.
|
||||||
|
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||||
|
The output format of the generate image. Choose between
|
||||||
|
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||||
|
return_dict (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
|
||||||
|
attention_kwargs (`dict`, *optional*):
|
||||||
|
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||||
|
`self.processor` in
|
||||||
|
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||||
|
callback_on_step_end (`Callable`, *optional*):
|
||||||
|
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||||
|
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||||
|
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||||
|
`callback_on_step_end_tensor_inputs`.
|
||||||
|
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||||
|
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||||
|
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||||
|
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||||
|
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
||||||
|
resolution (`int`, *optional*, defaults to 640):
|
||||||
|
using different bucket in (640, 1024) to determin the condition and output resolution
|
||||||
|
cfg_normalize (`bool`, *optional*, defaults to `False`)
|
||||||
|
whether enable cfg normalization.
|
||||||
|
use_en_prompt (`bool`, *optional*, defaults to `False`)
|
||||||
|
automatic caption language if user does not provide caption
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
|
||||||
|
[`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||||
|
returning a tuple, the first element is a list with the generated images.
|
||||||
|
"""
|
||||||
|
image_size = image[0].size if isinstance(image, list) else image.size
|
||||||
|
assert resolution in [640, 1024], f"resolution must be either 640 or 1024, but got {resolution}"
|
||||||
|
calculated_width, calculated_height = calculate_dimensions(
|
||||||
|
resolution * resolution, image_size[0] / image_size[1]
|
||||||
|
)
|
||||||
|
height = calculated_height
|
||||||
|
width = calculated_width
|
||||||
|
|
||||||
|
multiple_of = self.vae_scale_factor * 2
|
||||||
|
width = width // multiple_of * multiple_of
|
||||||
|
height = height // multiple_of * multiple_of
|
||||||
|
|
||||||
|
# 1. Check inputs. Raise error if not correct
|
||||||
|
self.check_inputs(
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
prompt_embeds=prompt_embeds,
|
||||||
|
negative_prompt_embeds=negative_prompt_embeds,
|
||||||
|
prompt_embeds_mask=prompt_embeds_mask,
|
||||||
|
negative_prompt_embeds_mask=negative_prompt_embeds_mask,
|
||||||
|
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||||
|
max_sequence_length=max_sequence_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._guidance_scale = guidance_scale
|
||||||
|
self._attention_kwargs = attention_kwargs
|
||||||
|
self._current_timestep = None
|
||||||
|
self._interrupt = False
|
||||||
|
|
||||||
|
device = self._execution_device
|
||||||
|
# 2. Preprocess image
|
||||||
|
if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
|
||||||
|
image = self.image_processor.resize(image, calculated_height, calculated_width)
|
||||||
|
prompt_image = image
|
||||||
|
image = self.image_processor.preprocess(image, calculated_height, calculated_width)
|
||||||
|
image = image.unsqueeze(2)
|
||||||
|
image = image.to(dtype=self.text_encoder.dtype)
|
||||||
|
|
||||||
|
if prompt is None or prompt == "" or prompt == " ":
|
||||||
|
prompt = self.get_image_caption(prompt_image, use_en_prompt=use_en_prompt, device=device)
|
||||||
|
|
||||||
|
# 3. Define call parameters
|
||||||
|
if prompt is not None and isinstance(prompt, str):
|
||||||
|
batch_size = 1
|
||||||
|
elif prompt is not None and isinstance(prompt, list):
|
||||||
|
batch_size = len(prompt)
|
||||||
|
else:
|
||||||
|
batch_size = prompt_embeds.shape[0]
|
||||||
|
|
||||||
|
has_neg_prompt = negative_prompt is not None or (
|
||||||
|
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
if true_cfg_scale > 1 and not has_neg_prompt:
|
||||||
|
logger.warning(
|
||||||
|
f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
|
||||||
|
)
|
||||||
|
elif true_cfg_scale <= 1 and has_neg_prompt:
|
||||||
|
logger.warning(
|
||||||
|
" negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
|
||||||
|
)
|
||||||
|
|
||||||
|
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
|
||||||
|
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
|
||||||
|
prompt=prompt,
|
||||||
|
prompt_embeds=prompt_embeds,
|
||||||
|
prompt_embeds_mask=prompt_embeds_mask,
|
||||||
|
device=device,
|
||||||
|
num_images_per_prompt=num_images_per_prompt,
|
||||||
|
max_sequence_length=max_sequence_length,
|
||||||
|
)
|
||||||
|
if do_true_cfg:
|
||||||
|
negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
|
||||||
|
prompt=negative_prompt,
|
||||||
|
prompt_embeds=negative_prompt_embeds,
|
||||||
|
prompt_embeds_mask=negative_prompt_embeds_mask,
|
||||||
|
device=device,
|
||||||
|
num_images_per_prompt=num_images_per_prompt,
|
||||||
|
max_sequence_length=max_sequence_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Prepare latent variables
|
||||||
|
num_channels_latents = self.transformer.config.in_channels // 4
|
||||||
|
latents, image_latents = self.prepare_latents(
|
||||||
|
image,
|
||||||
|
batch_size * num_images_per_prompt,
|
||||||
|
num_channels_latents,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
layers,
|
||||||
|
prompt_embeds.dtype,
|
||||||
|
device,
|
||||||
|
generator,
|
||||||
|
latents,
|
||||||
|
)
|
||||||
|
img_shapes = [
|
||||||
|
[
|
||||||
|
*[
|
||||||
|
(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)
|
||||||
|
for _ in range(layers + 1)
|
||||||
|
],
|
||||||
|
(1, calculated_height // self.vae_scale_factor // 2, calculated_width // self.vae_scale_factor // 2),
|
||||||
|
]
|
||||||
|
] * batch_size
|
||||||
|
|
||||||
|
# 5. Prepare timesteps
|
||||||
|
sigmas = np.linspace(1.0, 0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
|
||||||
|
image_seq_len = latents.shape[1]
|
||||||
|
base_seqlen = 256 * 256 / 16 / 16
|
||||||
|
mu = (image_latents.shape[1] / base_seqlen) ** 0.5
|
||||||
|
timesteps, num_inference_steps = retrieve_timesteps(
|
||||||
|
self.scheduler,
|
||||||
|
num_inference_steps,
|
||||||
|
device,
|
||||||
|
sigmas=sigmas,
|
||||||
|
mu=mu,
|
||||||
|
)
|
||||||
|
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||||
|
self._num_timesteps = len(timesteps)
|
||||||
|
|
||||||
|
# handle guidance
|
||||||
|
if self.transformer.config.guidance_embeds and guidance_scale is None:
|
||||||
|
raise ValueError("guidance_scale is required for guidance-distilled model.")
|
||||||
|
elif self.transformer.config.guidance_embeds:
|
||||||
|
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
||||||
|
guidance = guidance.expand(latents.shape[0])
|
||||||
|
elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
|
||||||
|
logger.warning(
|
||||||
|
f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
|
||||||
|
)
|
||||||
|
guidance = None
|
||||||
|
elif not self.transformer.config.guidance_embeds and guidance_scale is None:
|
||||||
|
guidance = None
|
||||||
|
|
||||||
|
if self.attention_kwargs is None:
|
||||||
|
self._attention_kwargs = {}
|
||||||
|
|
||||||
|
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
|
||||||
|
negative_txt_seq_lens = (
|
||||||
|
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
|
||||||
|
)
|
||||||
|
is_rgb = torch.tensor([0] * batch_size).to(device=device, dtype=torch.long)
|
||||||
|
# 6. Denoising loop
|
||||||
|
self.scheduler.set_begin_index(0)
|
||||||
|
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||||
|
for i, t in enumerate(timesteps):
|
||||||
|
if self.interrupt:
|
||||||
|
continue
|
||||||
|
|
||||||
|
self._current_timestep = t
|
||||||
|
|
||||||
|
latent_model_input = latents
|
||||||
|
if image_latents is not None:
|
||||||
|
latent_model_input = torch.cat([latents, image_latents], dim=1)
|
||||||
|
|
||||||
|
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||||
|
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||||
|
with self.transformer.cache_context("cond"):
|
||||||
|
noise_pred = self.transformer(
|
||||||
|
hidden_states=latent_model_input,
|
||||||
|
timestep=timestep / 1000,
|
||||||
|
guidance=guidance,
|
||||||
|
encoder_hidden_states_mask=prompt_embeds_mask,
|
||||||
|
encoder_hidden_states=prompt_embeds,
|
||||||
|
img_shapes=img_shapes,
|
||||||
|
txt_seq_lens=txt_seq_lens,
|
||||||
|
attention_kwargs=self.attention_kwargs,
|
||||||
|
additional_t_cond=is_rgb,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
noise_pred = noise_pred[:, : latents.size(1)]
|
||||||
|
|
||||||
|
if do_true_cfg:
|
||||||
|
with self.transformer.cache_context("uncond"):
|
||||||
|
neg_noise_pred = self.transformer(
|
||||||
|
hidden_states=latent_model_input,
|
||||||
|
timestep=timestep / 1000,
|
||||||
|
guidance=guidance,
|
||||||
|
encoder_hidden_states_mask=negative_prompt_embeds_mask,
|
||||||
|
encoder_hidden_states=negative_prompt_embeds,
|
||||||
|
img_shapes=img_shapes,
|
||||||
|
txt_seq_lens=negative_txt_seq_lens,
|
||||||
|
attention_kwargs=self.attention_kwargs,
|
||||||
|
additional_t_cond=is_rgb,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
|
||||||
|
comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
|
||||||
|
|
||||||
|
if cfg_normalize:
|
||||||
|
cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
|
||||||
|
noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
|
||||||
|
noise_pred = comb_pred * (cond_norm / noise_norm)
|
||||||
|
else:
|
||||||
|
noise_pred = comb_pred
|
||||||
|
|
||||||
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
|
latents_dtype = latents.dtype
|
||||||
|
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||||
|
|
||||||
|
if latents.dtype != latents_dtype:
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||||
|
latents = latents.to(latents_dtype)
|
||||||
|
|
||||||
|
if callback_on_step_end is not None:
|
||||||
|
callback_kwargs = {}
|
||||||
|
for k in callback_on_step_end_tensor_inputs:
|
||||||
|
callback_kwargs[k] = locals()[k]
|
||||||
|
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||||
|
|
||||||
|
latents = callback_outputs.pop("latents", latents)
|
||||||
|
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||||
|
|
||||||
|
# call the callback, if provided
|
||||||
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||||
|
progress_bar.update()
|
||||||
|
|
||||||
|
if XLA_AVAILABLE:
|
||||||
|
xm.mark_step()
|
||||||
|
|
||||||
|
self._current_timestep = None
|
||||||
|
if output_type == "latent":
|
||||||
|
image = latents
|
||||||
|
else:
|
||||||
|
latents = self._unpack_latents(latents, height, width, layers, self.vae_scale_factor)
|
||||||
|
latents = latents.to(self.vae.dtype)
|
||||||
|
latents_mean = (
|
||||||
|
torch.tensor(self.vae.config.latents_mean)
|
||||||
|
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||||
|
.to(latents.device, latents.dtype)
|
||||||
|
)
|
||||||
|
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||||
|
latents.device, latents.dtype
|
||||||
|
)
|
||||||
|
latents = latents / latents_std + latents_mean
|
||||||
|
|
||||||
|
b, c, f, h, w = latents.shape
|
||||||
|
|
||||||
|
latents = latents[:, :, 1:] # remove the first frame as it is the orgin input
|
||||||
|
|
||||||
|
latents = latents.permute(0, 2, 1, 3, 4).view(-1, c, 1, h, w)
|
||||||
|
|
||||||
|
image = self.vae.decode(latents, return_dict=False)[0] # (b f) c 1 h w
|
||||||
|
|
||||||
|
image = image.squeeze(2)
|
||||||
|
|
||||||
|
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||||
|
images = []
|
||||||
|
for bidx in range(b):
|
||||||
|
images.append(image[bidx * f : (bidx + 1) * f])
|
||||||
|
|
||||||
|
# Offload all models
|
||||||
|
self.maybe_free_model_hooks()
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (images,)
|
||||||
|
|
||||||
|
return QwenImagePipelineOutput(images=images)
|
||||||
@@ -23,6 +23,8 @@ except OptionalDependencyNotAvailable:
|
|||||||
else:
|
else:
|
||||||
_import_structure["pipeline_output"] = ["ZImagePipelineOutput"]
|
_import_structure["pipeline_output"] = ["ZImagePipelineOutput"]
|
||||||
_import_structure["pipeline_z_image"] = ["ZImagePipeline"]
|
_import_structure["pipeline_z_image"] = ["ZImagePipeline"]
|
||||||
|
_import_structure["pipeline_z_image_controlnet"] = ["ZImageControlNetPipeline"]
|
||||||
|
_import_structure["pipeline_z_image_controlnet_inpaint"] = ["ZImageControlNetInpaintPipeline"]
|
||||||
_import_structure["pipeline_z_image_img2img"] = ["ZImageImg2ImgPipeline"]
|
_import_structure["pipeline_z_image_img2img"] = ["ZImageImg2ImgPipeline"]
|
||||||
|
|
||||||
|
|
||||||
@@ -36,6 +38,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
else:
|
else:
|
||||||
from .pipeline_output import ZImagePipelineOutput
|
from .pipeline_output import ZImagePipelineOutput
|
||||||
from .pipeline_z_image import ZImagePipeline
|
from .pipeline_z_image import ZImagePipeline
|
||||||
|
from .pipeline_z_image_controlnet import ZImageControlNetPipeline
|
||||||
|
from .pipeline_z_image_controlnet_inpaint import ZImageControlNetInpaintPipeline
|
||||||
from .pipeline_z_image_img2img import ZImageImg2ImgPipeline
|
from .pipeline_z_image_img2img import ZImageImg2ImgPipeline
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|||||||
725
src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py
Normal file
725
src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py
Normal file
@@ -0,0 +1,725 @@
|
|||||||
|
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import AutoTokenizer, PreTrainedModel
|
||||||
|
|
||||||
|
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||||
|
from ...loaders import FromSingleFileMixin
|
||||||
|
from ...models.autoencoders import AutoencoderKL
|
||||||
|
from ...models.controlnets import ZImageControlNetModel
|
||||||
|
from ...models.transformers import ZImageTransformer2DModel
|
||||||
|
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||||
|
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||||
|
from ...utils import logging, replace_example_docstring
|
||||||
|
from ...utils.torch_utils import randn_tensor
|
||||||
|
from .pipeline_output import ZImagePipelineOutput
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
EXAMPLE_DOC_STRING = """
|
||||||
|
Examples:
|
||||||
|
```py
|
||||||
|
>>> import torch
|
||||||
|
>>> from diffusers import ZImageControlNetPipeline
|
||||||
|
>>> from diffusers import ZImageControlNetModel
|
||||||
|
>>> from diffusers.utils import load_image
|
||||||
|
>>> from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
|
>>> controlnet = ZImageControlNetModel.from_single_file(
|
||||||
|
... hf_hub_download(
|
||||||
|
... "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union",
|
||||||
|
... filename="Z-Image-Turbo-Fun-Controlnet-Union.safetensors",
|
||||||
|
... ),
|
||||||
|
... torch_dtype=torch.bfloat16,
|
||||||
|
... )
|
||||||
|
|
||||||
|
>>> # 2.1
|
||||||
|
>>> # controlnet = ZImageControlNetModel.from_single_file(
|
||||||
|
>>> # hf_hub_download(
|
||||||
|
>>> # "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
|
||||||
|
>>> # filename="Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors",
|
||||||
|
>>> # ),
|
||||||
|
>>> # torch_dtype=torch.bfloat16,
|
||||||
|
>>> # )
|
||||||
|
|
||||||
|
>>> # 2.0 - `config` is required
|
||||||
|
>>> # controlnet = ZImageControlNetModel.from_single_file(
|
||||||
|
>>> # hf_hub_download(
|
||||||
|
>>> # "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
|
||||||
|
>>> # filename="Z-Image-Turbo-Fun-Controlnet-Union-2.0.safetensors",
|
||||||
|
>>> # ),
|
||||||
|
>>> # torch_dtype=torch.bfloat16,
|
||||||
|
>>> # config="hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
|
||||||
|
>>> # )
|
||||||
|
|
||||||
|
>>> pipe = ZImageControlNetPipeline.from_pretrained(
|
||||||
|
... "Tongyi-MAI/Z-Image-Turbo", controlnet=controlnet, torch_dtype=torch.bfloat16
|
||||||
|
... )
|
||||||
|
>>> pipe.to("cuda")
|
||||||
|
|
||||||
|
>>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch.
|
||||||
|
>>> # (1) Use flash attention 2
|
||||||
|
>>> # pipe.transformer.set_attention_backend("flash")
|
||||||
|
>>> # (2) Use flash attention 3
|
||||||
|
>>> # pipe.transformer.set_attention_backend("_flash_3")
|
||||||
|
|
||||||
|
>>> control_image = load_image(
|
||||||
|
... "https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union/resolve/main/asset/pose.jpg?download=true"
|
||||||
|
... )
|
||||||
|
>>> prompt = "一位年轻女子站在阳光明媚的海岸线上,白裙在轻拂的海风中微微飘动。她拥有一头鲜艳的紫色长发,在风中轻盈舞动,发间系着一个精致的黑色蝴蝶结,与身后柔和的蔚蓝天空形成鲜明对比。她面容清秀,眉目精致,透着一股甜美的青春气息;神情柔和,略带羞涩,目光静静地凝望着远方的地平线,双手自然交叠于身前,仿佛沉浸在思绪之中。在她身后,是辽阔无垠、波光粼粼的大海,阳光洒在海面上,映出温暖的金色光晕。"
|
||||||
|
>>> image = pipe(
|
||||||
|
... prompt,
|
||||||
|
... control_image=control_image,
|
||||||
|
... controlnet_conditioning_scale=0.75,
|
||||||
|
... height=1728,
|
||||||
|
... width=992,
|
||||||
|
... num_inference_steps=9,
|
||||||
|
... guidance_scale=0.0,
|
||||||
|
... generator=torch.Generator("cuda").manual_seed(43),
|
||||||
|
... ).images[0]
|
||||||
|
>>> image.save("zimage.png")
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||||
|
def calculate_shift(
|
||||||
|
image_seq_len,
|
||||||
|
base_seq_len: int = 256,
|
||||||
|
max_seq_len: int = 4096,
|
||||||
|
base_shift: float = 0.5,
|
||||||
|
max_shift: float = 1.15,
|
||||||
|
):
|
||||||
|
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||||
|
b = base_shift - m * base_seq_len
|
||||||
|
mu = image_seq_len * m + b
|
||||||
|
return mu
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||||
|
def retrieve_latents(
|
||||||
|
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||||
|
):
|
||||||
|
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||||
|
return encoder_output.latent_dist.sample(generator)
|
||||||
|
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||||
|
return encoder_output.latent_dist.mode()
|
||||||
|
elif hasattr(encoder_output, "latents"):
|
||||||
|
return encoder_output.latents
|
||||||
|
else:
|
||||||
|
raise AttributeError("Could not access latents of provided encoder_output")
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||||
|
def retrieve_timesteps(
|
||||||
|
scheduler,
|
||||||
|
num_inference_steps: Optional[int] = None,
|
||||||
|
device: Optional[Union[str, torch.device]] = None,
|
||||||
|
timesteps: Optional[List[int]] = None,
|
||||||
|
sigmas: Optional[List[float]] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||||
|
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scheduler (`SchedulerMixin`):
|
||||||
|
The scheduler to get timesteps from.
|
||||||
|
num_inference_steps (`int`):
|
||||||
|
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||||
|
must be `None`.
|
||||||
|
device (`str` or `torch.device`, *optional*):
|
||||||
|
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||||
|
timesteps (`List[int]`, *optional*):
|
||||||
|
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||||
|
`num_inference_steps` and `sigmas` must be `None`.
|
||||||
|
sigmas (`List[float]`, *optional*):
|
||||||
|
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||||
|
`num_inference_steps` and `timesteps` must be `None`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||||
|
second element is the number of inference steps.
|
||||||
|
"""
|
||||||
|
if timesteps is not None and sigmas is not None:
|
||||||
|
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||||
|
if timesteps is not None:
|
||||||
|
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||||
|
if not accepts_timesteps:
|
||||||
|
raise ValueError(
|
||||||
|
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||||
|
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||||
|
)
|
||||||
|
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||||
|
timesteps = scheduler.timesteps
|
||||||
|
num_inference_steps = len(timesteps)
|
||||||
|
elif sigmas is not None:
|
||||||
|
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||||
|
if not accept_sigmas:
|
||||||
|
raise ValueError(
|
||||||
|
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||||
|
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||||
|
)
|
||||||
|
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||||
|
timesteps = scheduler.timesteps
|
||||||
|
num_inference_steps = len(timesteps)
|
||||||
|
else:
|
||||||
|
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||||
|
timesteps = scheduler.timesteps
|
||||||
|
return timesteps, num_inference_steps
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageControlNetPipeline(DiffusionPipeline, FromSingleFileMixin):
|
||||||
|
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||||
|
_optional_components = []
|
||||||
|
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||||
|
vae: AutoencoderKL,
|
||||||
|
text_encoder: PreTrainedModel,
|
||||||
|
tokenizer: AutoTokenizer,
|
||||||
|
transformer: ZImageTransformer2DModel,
|
||||||
|
controlnet: ZImageControlNetModel,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
controlnet = ZImageControlNetModel.from_transformer(controlnet, transformer)
|
||||||
|
|
||||||
|
self.register_modules(
|
||||||
|
vae=vae,
|
||||||
|
text_encoder=text_encoder,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
scheduler=scheduler,
|
||||||
|
transformer=transformer,
|
||||||
|
controlnet=controlnet,
|
||||||
|
)
|
||||||
|
self.vae_scale_factor = (
|
||||||
|
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||||
|
)
|
||||||
|
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||||
|
|
||||||
|
def encode_prompt(
|
||||||
|
self,
|
||||||
|
prompt: Union[str, List[str]],
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
do_classifier_free_guidance: bool = True,
|
||||||
|
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||||
|
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
max_sequence_length: int = 512,
|
||||||
|
):
|
||||||
|
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||||
|
prompt_embeds = self._encode_prompt(
|
||||||
|
prompt=prompt,
|
||||||
|
device=device,
|
||||||
|
prompt_embeds=prompt_embeds,
|
||||||
|
max_sequence_length=max_sequence_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
if negative_prompt is None:
|
||||||
|
negative_prompt = ["" for _ in prompt]
|
||||||
|
else:
|
||||||
|
negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||||
|
assert len(prompt) == len(negative_prompt)
|
||||||
|
negative_prompt_embeds = self._encode_prompt(
|
||||||
|
prompt=negative_prompt,
|
||||||
|
device=device,
|
||||||
|
prompt_embeds=negative_prompt_embeds,
|
||||||
|
max_sequence_length=max_sequence_length,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
negative_prompt_embeds = []
|
||||||
|
return prompt_embeds, negative_prompt_embeds
|
||||||
|
|
||||||
|
def _encode_prompt(
|
||||||
|
self,
|
||||||
|
prompt: Union[str, List[str]],
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
max_sequence_length: int = 512,
|
||||||
|
) -> List[torch.FloatTensor]:
|
||||||
|
device = device or self._execution_device
|
||||||
|
|
||||||
|
if prompt_embeds is not None:
|
||||||
|
return prompt_embeds
|
||||||
|
|
||||||
|
if isinstance(prompt, str):
|
||||||
|
prompt = [prompt]
|
||||||
|
|
||||||
|
for i, prompt_item in enumerate(prompt):
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": prompt_item},
|
||||||
|
]
|
||||||
|
prompt_item = self.tokenizer.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
tokenize=False,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
enable_thinking=True,
|
||||||
|
)
|
||||||
|
prompt[i] = prompt_item
|
||||||
|
|
||||||
|
text_inputs = self.tokenizer(
|
||||||
|
prompt,
|
||||||
|
padding="max_length",
|
||||||
|
max_length=max_sequence_length,
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
text_input_ids = text_inputs.input_ids.to(device)
|
||||||
|
prompt_masks = text_inputs.attention_mask.to(device).bool()
|
||||||
|
|
||||||
|
prompt_embeds = self.text_encoder(
|
||||||
|
input_ids=text_input_ids,
|
||||||
|
attention_mask=prompt_masks,
|
||||||
|
output_hidden_states=True,
|
||||||
|
).hidden_states[-2]
|
||||||
|
|
||||||
|
embeddings_list = []
|
||||||
|
|
||||||
|
for i in range(len(prompt_embeds)):
|
||||||
|
embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
|
||||||
|
|
||||||
|
return embeddings_list
|
||||||
|
|
||||||
|
def prepare_latents(
|
||||||
|
self,
|
||||||
|
batch_size,
|
||||||
|
num_channels_latents,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
generator,
|
||||||
|
latents=None,
|
||||||
|
):
|
||||||
|
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||||
|
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||||
|
|
||||||
|
shape = (batch_size, num_channels_latents, height, width)
|
||||||
|
|
||||||
|
if latents is None:
|
||||||
|
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||||
|
else:
|
||||||
|
if latents.shape != shape:
|
||||||
|
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||||
|
latents = latents.to(device)
|
||||||
|
return latents
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
|
||||||
|
def prepare_image(
|
||||||
|
self,
|
||||||
|
image,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
batch_size,
|
||||||
|
num_images_per_prompt,
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
|
do_classifier_free_guidance=False,
|
||||||
|
guess_mode=False,
|
||||||
|
):
|
||||||
|
if isinstance(image, torch.Tensor):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
image = self.image_processor.preprocess(image, height=height, width=width)
|
||||||
|
|
||||||
|
image_batch_size = image.shape[0]
|
||||||
|
|
||||||
|
if image_batch_size == 1:
|
||||||
|
repeat_by = batch_size
|
||||||
|
else:
|
||||||
|
# image batch size is the same as prompt batch size
|
||||||
|
repeat_by = num_images_per_prompt
|
||||||
|
|
||||||
|
image = image.repeat_interleave(repeat_by, dim=0)
|
||||||
|
|
||||||
|
image = image.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
if do_classifier_free_guidance and not guess_mode:
|
||||||
|
image = torch.cat([image] * 2)
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
@property
|
||||||
|
def guidance_scale(self):
|
||||||
|
return self._guidance_scale
|
||||||
|
|
||||||
|
@property
|
||||||
|
def do_classifier_free_guidance(self):
|
||||||
|
return self._guidance_scale > 1
|
||||||
|
|
||||||
|
@property
|
||||||
|
def joint_attention_kwargs(self):
|
||||||
|
return self._joint_attention_kwargs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_timesteps(self):
|
||||||
|
return self._num_timesteps
|
||||||
|
|
||||||
|
@property
|
||||||
|
def interrupt(self):
|
||||||
|
return self._interrupt
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
prompt: Union[str, List[str]] = None,
|
||||||
|
height: Optional[int] = None,
|
||||||
|
width: Optional[int] = None,
|
||||||
|
num_inference_steps: int = 50,
|
||||||
|
sigmas: Optional[List[float]] = None,
|
||||||
|
guidance_scale: float = 5.0,
|
||||||
|
control_image: PipelineImageInput = None,
|
||||||
|
controlnet_conditioning_scale: Union[float, List[float]] = 0.75,
|
||||||
|
cfg_normalization: bool = False,
|
||||||
|
cfg_truncation: float = 1.0,
|
||||||
|
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||||
|
num_images_per_prompt: Optional[int] = 1,
|
||||||
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||||
|
latents: Optional[torch.FloatTensor] = None,
|
||||||
|
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
output_type: Optional[str] = "pil",
|
||||||
|
return_dict: bool = True,
|
||||||
|
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||||
|
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||||
|
max_sequence_length: int = 512,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Function invoked when calling the pipeline for generation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (`str` or `List[str]`, *optional*):
|
||||||
|
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||||
|
instead.
|
||||||
|
height (`int`, *optional*, defaults to 1024):
|
||||||
|
The height in pixels of the generated image.
|
||||||
|
width (`int`, *optional*, defaults to 1024):
|
||||||
|
The width in pixels of the generated image.
|
||||||
|
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||||
|
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||||
|
expense of slower inference.
|
||||||
|
sigmas (`List[float]`, *optional*):
|
||||||
|
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||||
|
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||||
|
will be used.
|
||||||
|
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||||
|
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||||
|
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||||
|
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||||
|
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||||
|
usually at the expense of lower image quality.
|
||||||
|
cfg_normalization (`bool`, *optional*, defaults to False):
|
||||||
|
Whether to apply configuration normalization.
|
||||||
|
cfg_truncation (`float`, *optional*, defaults to 1.0):
|
||||||
|
The truncation value for configuration.
|
||||||
|
negative_prompt (`str` or `List[str]`, *optional*):
|
||||||
|
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||||
|
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||||
|
less than `1`).
|
||||||
|
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||||
|
The number of images to generate per prompt.
|
||||||
|
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||||
|
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||||
|
to make generation deterministic.
|
||||||
|
latents (`torch.FloatTensor`, *optional*):
|
||||||
|
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||||
|
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||||
|
tensor will be generated by sampling using the supplied random `generator`.
|
||||||
|
prompt_embeds (`List[torch.FloatTensor]`, *optional*):
|
||||||
|
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||||
|
provided, text embeddings will be generated from `prompt` input argument.
|
||||||
|
negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
|
||||||
|
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||||
|
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||||
|
argument.
|
||||||
|
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||||
|
The output format of the generate image. Choose between
|
||||||
|
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||||
|
return_dict (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain
|
||||||
|
tuple.
|
||||||
|
joint_attention_kwargs (`dict`, *optional*):
|
||||||
|
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||||
|
`self.processor` in
|
||||||
|
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||||
|
callback_on_step_end (`Callable`, *optional*):
|
||||||
|
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||||
|
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||||
|
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||||
|
`callback_on_step_end_tensor_inputs`.
|
||||||
|
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||||
|
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||||
|
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||||
|
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||||
|
max_sequence_length (`int`, *optional*, defaults to 512):
|
||||||
|
Maximum sequence length to use with the `prompt`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if
|
||||||
|
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
|
||||||
|
generated images.
|
||||||
|
"""
|
||||||
|
height = height or 1024
|
||||||
|
width = width or 1024
|
||||||
|
|
||||||
|
vae_scale = self.vae_scale_factor * 2
|
||||||
|
if height % vae_scale != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Height must be divisible by {vae_scale} (got {height}). "
|
||||||
|
f"Please adjust the height to a multiple of {vae_scale}."
|
||||||
|
)
|
||||||
|
if width % vae_scale != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Width must be divisible by {vae_scale} (got {width}). "
|
||||||
|
f"Please adjust the width to a multiple of {vae_scale}."
|
||||||
|
)
|
||||||
|
|
||||||
|
device = self._execution_device
|
||||||
|
|
||||||
|
self._guidance_scale = guidance_scale
|
||||||
|
self._joint_attention_kwargs = joint_attention_kwargs
|
||||||
|
self._interrupt = False
|
||||||
|
self._cfg_normalization = cfg_normalization
|
||||||
|
self._cfg_truncation = cfg_truncation
|
||||||
|
# 2. Define call parameters
|
||||||
|
if prompt is not None and isinstance(prompt, str):
|
||||||
|
batch_size = 1
|
||||||
|
elif prompt is not None and isinstance(prompt, list):
|
||||||
|
batch_size = len(prompt)
|
||||||
|
else:
|
||||||
|
batch_size = len(prompt_embeds)
|
||||||
|
|
||||||
|
# If prompt_embeds is provided and prompt is None, skip encoding
|
||||||
|
if prompt_embeds is not None and prompt is None:
|
||||||
|
if self.do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||||
|
raise ValueError(
|
||||||
|
"When `prompt_embeds` is provided without `prompt`, "
|
||||||
|
"`negative_prompt_embeds` must also be provided for classifier-free guidance."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
(
|
||||||
|
prompt_embeds,
|
||||||
|
negative_prompt_embeds,
|
||||||
|
) = self.encode_prompt(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||||
|
prompt_embeds=prompt_embeds,
|
||||||
|
negative_prompt_embeds=negative_prompt_embeds,
|
||||||
|
device=device,
|
||||||
|
max_sequence_length=max_sequence_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Prepare latent variables
|
||||||
|
num_channels_latents = self.transformer.in_channels
|
||||||
|
|
||||||
|
control_image = self.prepare_image(
|
||||||
|
image=control_image,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
batch_size=batch_size * num_images_per_prompt,
|
||||||
|
num_images_per_prompt=num_images_per_prompt,
|
||||||
|
device=device,
|
||||||
|
dtype=self.vae.dtype,
|
||||||
|
)
|
||||||
|
height, width = control_image.shape[-2:]
|
||||||
|
control_image = retrieve_latents(self.vae.encode(control_image), generator=generator, sample_mode="argmax")
|
||||||
|
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||||
|
control_image = control_image.unsqueeze(2)
|
||||||
|
|
||||||
|
if num_channels_latents != self.controlnet.config.control_in_dim:
|
||||||
|
# For model version 2.0
|
||||||
|
control_image = torch.cat(
|
||||||
|
[
|
||||||
|
control_image,
|
||||||
|
torch.zeros(
|
||||||
|
control_image.shape[0],
|
||||||
|
self.controlnet.config.control_in_dim - num_channels_latents,
|
||||||
|
*control_image.shape[2:],
|
||||||
|
).to(device=control_image.device, dtype=control_image.dtype),
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
latents = self.prepare_latents(
|
||||||
|
batch_size * num_images_per_prompt,
|
||||||
|
num_channels_latents,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
torch.float32,
|
||||||
|
device,
|
||||||
|
generator,
|
||||||
|
latents,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Repeat prompt_embeds for num_images_per_prompt
|
||||||
|
if num_images_per_prompt > 1:
|
||||||
|
prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
|
||||||
|
if self.do_classifier_free_guidance and negative_prompt_embeds:
|
||||||
|
negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
|
||||||
|
|
||||||
|
actual_batch_size = batch_size * num_images_per_prompt
|
||||||
|
image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
|
||||||
|
|
||||||
|
# 5. Prepare timesteps
|
||||||
|
mu = calculate_shift(
|
||||||
|
image_seq_len,
|
||||||
|
self.scheduler.config.get("base_image_seq_len", 256),
|
||||||
|
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||||
|
self.scheduler.config.get("base_shift", 0.5),
|
||||||
|
self.scheduler.config.get("max_shift", 1.15),
|
||||||
|
)
|
||||||
|
self.scheduler.sigma_min = 0.0
|
||||||
|
scheduler_kwargs = {"mu": mu}
|
||||||
|
timesteps, num_inference_steps = retrieve_timesteps(
|
||||||
|
self.scheduler,
|
||||||
|
num_inference_steps,
|
||||||
|
device,
|
||||||
|
sigmas=sigmas,
|
||||||
|
**scheduler_kwargs,
|
||||||
|
)
|
||||||
|
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||||
|
self._num_timesteps = len(timesteps)
|
||||||
|
|
||||||
|
# 6. Denoising loop
|
||||||
|
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||||
|
for i, t in enumerate(timesteps):
|
||||||
|
if self.interrupt:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||||
|
timestep = t.expand(latents.shape[0])
|
||||||
|
timestep = (1000 - timestep) / 1000
|
||||||
|
# Normalized time for time-aware config (0 at start, 1 at end)
|
||||||
|
t_norm = timestep[0].item()
|
||||||
|
|
||||||
|
# Handle cfg truncation
|
||||||
|
current_guidance_scale = self.guidance_scale
|
||||||
|
if (
|
||||||
|
self.do_classifier_free_guidance
|
||||||
|
and self._cfg_truncation is not None
|
||||||
|
and float(self._cfg_truncation) <= 1
|
||||||
|
):
|
||||||
|
if t_norm > self._cfg_truncation:
|
||||||
|
current_guidance_scale = 0.0
|
||||||
|
|
||||||
|
# Run CFG only if configured AND scale is non-zero
|
||||||
|
apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
|
||||||
|
|
||||||
|
if apply_cfg:
|
||||||
|
latents_typed = latents.to(self.transformer.dtype)
|
||||||
|
latent_model_input = latents_typed.repeat(2, 1, 1, 1)
|
||||||
|
prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
|
||||||
|
timestep_model_input = timestep.repeat(2)
|
||||||
|
else:
|
||||||
|
latent_model_input = latents.to(self.transformer.dtype)
|
||||||
|
prompt_embeds_model_input = prompt_embeds
|
||||||
|
timestep_model_input = timestep
|
||||||
|
|
||||||
|
latent_model_input = latent_model_input.unsqueeze(2)
|
||||||
|
latent_model_input_list = list(latent_model_input.unbind(dim=0))
|
||||||
|
|
||||||
|
controlnet_block_samples = self.controlnet(
|
||||||
|
latent_model_input_list,
|
||||||
|
timestep_model_input,
|
||||||
|
prompt_embeds_model_input,
|
||||||
|
control_image,
|
||||||
|
conditioning_scale=controlnet_conditioning_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
model_out_list = self.transformer(
|
||||||
|
latent_model_input_list,
|
||||||
|
timestep_model_input,
|
||||||
|
prompt_embeds_model_input,
|
||||||
|
controlnet_block_samples=controlnet_block_samples,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
if apply_cfg:
|
||||||
|
# Perform CFG
|
||||||
|
pos_out = model_out_list[:actual_batch_size]
|
||||||
|
neg_out = model_out_list[actual_batch_size:]
|
||||||
|
|
||||||
|
noise_pred = []
|
||||||
|
for j in range(actual_batch_size):
|
||||||
|
pos = pos_out[j].float()
|
||||||
|
neg = neg_out[j].float()
|
||||||
|
|
||||||
|
pred = pos + current_guidance_scale * (pos - neg)
|
||||||
|
|
||||||
|
# Renormalization
|
||||||
|
if self._cfg_normalization and float(self._cfg_normalization) > 0.0:
|
||||||
|
ori_pos_norm = torch.linalg.vector_norm(pos)
|
||||||
|
new_pos_norm = torch.linalg.vector_norm(pred)
|
||||||
|
max_new_norm = ori_pos_norm * float(self._cfg_normalization)
|
||||||
|
if new_pos_norm > max_new_norm:
|
||||||
|
pred = pred * (max_new_norm / new_pos_norm)
|
||||||
|
|
||||||
|
noise_pred.append(pred)
|
||||||
|
|
||||||
|
noise_pred = torch.stack(noise_pred, dim=0)
|
||||||
|
else:
|
||||||
|
noise_pred = torch.stack([t.float() for t in model_out_list], dim=0)
|
||||||
|
|
||||||
|
noise_pred = noise_pred.squeeze(2)
|
||||||
|
noise_pred = -noise_pred
|
||||||
|
|
||||||
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
|
latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
|
||||||
|
assert latents.dtype == torch.float32
|
||||||
|
|
||||||
|
if callback_on_step_end is not None:
|
||||||
|
callback_kwargs = {}
|
||||||
|
for k in callback_on_step_end_tensor_inputs:
|
||||||
|
callback_kwargs[k] = locals()[k]
|
||||||
|
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||||
|
|
||||||
|
latents = callback_outputs.pop("latents", latents)
|
||||||
|
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||||
|
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||||
|
|
||||||
|
# call the callback, if provided
|
||||||
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||||
|
progress_bar.update()
|
||||||
|
|
||||||
|
if output_type == "latent":
|
||||||
|
image = latents
|
||||||
|
|
||||||
|
else:
|
||||||
|
latents = latents.to(self.vae.dtype)
|
||||||
|
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||||
|
|
||||||
|
image = self.vae.decode(latents, return_dict=False)[0]
|
||||||
|
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||||
|
|
||||||
|
# Offload all models
|
||||||
|
self.maybe_free_model_hooks()
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (image,)
|
||||||
|
|
||||||
|
return ZImagePipelineOutput(images=image)
|
||||||
@@ -0,0 +1,747 @@
|
|||||||
|
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from transformers import AutoTokenizer, PreTrainedModel
|
||||||
|
|
||||||
|
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||||
|
from ...loaders import FromSingleFileMixin
|
||||||
|
from ...models.autoencoders import AutoencoderKL
|
||||||
|
from ...models.controlnets import ZImageControlNetModel
|
||||||
|
from ...models.transformers import ZImageTransformer2DModel
|
||||||
|
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||||
|
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||||
|
from ...utils import logging, replace_example_docstring
|
||||||
|
from ...utils.torch_utils import randn_tensor
|
||||||
|
from .pipeline_output import ZImagePipelineOutput
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
EXAMPLE_DOC_STRING = """
|
||||||
|
Examples:
|
||||||
|
```py
|
||||||
|
>>> import torch
|
||||||
|
>>> from diffusers import ZImageControlNetInpaintPipeline
|
||||||
|
>>> from diffusers import ZImageControlNetModel
|
||||||
|
>>> from diffusers.utils import load_image
|
||||||
|
>>> from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
|
>>> controlnet = ZImageControlNetModel.from_single_file(
|
||||||
|
... hf_hub_download(
|
||||||
|
... "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
|
||||||
|
... filename="Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors",
|
||||||
|
... ),
|
||||||
|
... torch_dtype=torch.bfloat16,
|
||||||
|
... )
|
||||||
|
|
||||||
|
>>> # 2.0 - `config` is required
|
||||||
|
>>> # controlnet = ZImageControlNetModel.from_single_file(
|
||||||
|
>>> # hf_hub_download(
|
||||||
|
>>> # "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
|
||||||
|
>>> # filename="Z-Image-Turbo-Fun-Controlnet-Union-2.0.safetensors",
|
||||||
|
>>> # ),
|
||||||
|
>>> # torch_dtype=torch.bfloat16,
|
||||||
|
>>> # config="hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
|
||||||
|
>>> # )
|
||||||
|
|
||||||
|
>>> pipe = ZImageControlNetInpaintPipeline.from_pretrained(
|
||||||
|
... "Tongyi-MAI/Z-Image-Turbo", controlnet=controlnet, torch_dtype=torch.bfloat16
|
||||||
|
... )
|
||||||
|
>>> pipe.to("cuda")
|
||||||
|
|
||||||
|
>>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch.
|
||||||
|
>>> # (1) Use flash attention 2
|
||||||
|
>>> # pipe.transformer.set_attention_backend("flash")
|
||||||
|
>>> # (2) Use flash attention 3
|
||||||
|
>>> # pipe.transformer.set_attention_backend("_flash_3")
|
||||||
|
|
||||||
|
>>> image = load_image(
|
||||||
|
... "https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0/resolve/main/asset/inpaint.jpg?download=true"
|
||||||
|
... )
|
||||||
|
>>> mask_image = load_image(
|
||||||
|
... "https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0/resolve/main/asset/mask.jpg?download=true"
|
||||||
|
... )
|
||||||
|
>>> control_image = load_image(
|
||||||
|
... "https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0/resolve/main/asset/pose.jpg?download=true"
|
||||||
|
... )
|
||||||
|
>>> prompt = "一位年轻女子站在阳光明媚的海岸线上,画面为全身竖构图,身体微微侧向右侧,左手自然下垂,右臂弯曲扶在腰间,她的手指清晰可见,站姿放松而略带羞涩。她身穿轻盈的白色连衣裙,裙摆在海风中轻轻飘动,布料半透、质感柔软。女子拥有一头鲜艳的及腰紫色长发,被海风吹起,在身侧轻盈飞舞,发间系着一个精致的黑色蝴蝶结,与发色形成对比。她面容清秀,眉目精致,肤色白皙细腻,表情温柔略显羞涩,微微低头,眼神静静望向远处的海平线,流露出甜美的青春气息与若有所思的神情。背景是辽阔无垠的海洋与蔚蓝天空,阳光从侧前方洒下,海面波光粼粼,泛着温暖的金色光晕,天空清澈明亮,云朵稀薄,整体色调清新唯美。"
|
||||||
|
>>> image = pipe(
|
||||||
|
... prompt,
|
||||||
|
... image=image,
|
||||||
|
... mask_image=mask_image,
|
||||||
|
... control_image=control_image,
|
||||||
|
... controlnet_conditioning_scale=0.75,
|
||||||
|
... height=1728,
|
||||||
|
... width=992,
|
||||||
|
... num_inference_steps=25,
|
||||||
|
... guidance_scale=0.0,
|
||||||
|
... generator=torch.Generator("cuda").manual_seed(43),
|
||||||
|
... ).images[0]
|
||||||
|
>>> image.save("zimage-inpaint.png")
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||||
|
def calculate_shift(
|
||||||
|
image_seq_len,
|
||||||
|
base_seq_len: int = 256,
|
||||||
|
max_seq_len: int = 4096,
|
||||||
|
base_shift: float = 0.5,
|
||||||
|
max_shift: float = 1.15,
|
||||||
|
):
|
||||||
|
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||||
|
b = base_shift - m * base_seq_len
|
||||||
|
mu = image_seq_len * m + b
|
||||||
|
return mu
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||||
|
def retrieve_latents(
|
||||||
|
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||||
|
):
|
||||||
|
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||||
|
return encoder_output.latent_dist.sample(generator)
|
||||||
|
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||||
|
return encoder_output.latent_dist.mode()
|
||||||
|
elif hasattr(encoder_output, "latents"):
|
||||||
|
return encoder_output.latents
|
||||||
|
else:
|
||||||
|
raise AttributeError("Could not access latents of provided encoder_output")
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||||
|
def retrieve_timesteps(
|
||||||
|
scheduler,
|
||||||
|
num_inference_steps: Optional[int] = None,
|
||||||
|
device: Optional[Union[str, torch.device]] = None,
|
||||||
|
timesteps: Optional[List[int]] = None,
|
||||||
|
sigmas: Optional[List[float]] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||||
|
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scheduler (`SchedulerMixin`):
|
||||||
|
The scheduler to get timesteps from.
|
||||||
|
num_inference_steps (`int`):
|
||||||
|
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||||
|
must be `None`.
|
||||||
|
device (`str` or `torch.device`, *optional*):
|
||||||
|
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||||
|
timesteps (`List[int]`, *optional*):
|
||||||
|
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||||
|
`num_inference_steps` and `sigmas` must be `None`.
|
||||||
|
sigmas (`List[float]`, *optional*):
|
||||||
|
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||||
|
`num_inference_steps` and `timesteps` must be `None`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||||
|
second element is the number of inference steps.
|
||||||
|
"""
|
||||||
|
if timesteps is not None and sigmas is not None:
|
||||||
|
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||||
|
if timesteps is not None:
|
||||||
|
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||||
|
if not accepts_timesteps:
|
||||||
|
raise ValueError(
|
||||||
|
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||||
|
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||||
|
)
|
||||||
|
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||||
|
timesteps = scheduler.timesteps
|
||||||
|
num_inference_steps = len(timesteps)
|
||||||
|
elif sigmas is not None:
|
||||||
|
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||||
|
if not accept_sigmas:
|
||||||
|
raise ValueError(
|
||||||
|
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||||
|
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||||
|
)
|
||||||
|
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||||
|
timesteps = scheduler.timesteps
|
||||||
|
num_inference_steps = len(timesteps)
|
||||||
|
else:
|
||||||
|
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||||
|
timesteps = scheduler.timesteps
|
||||||
|
return timesteps, num_inference_steps
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageControlNetInpaintPipeline(DiffusionPipeline, FromSingleFileMixin):
|
||||||
|
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||||
|
_optional_components = []
|
||||||
|
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||||
|
vae: AutoencoderKL,
|
||||||
|
text_encoder: PreTrainedModel,
|
||||||
|
tokenizer: AutoTokenizer,
|
||||||
|
transformer: ZImageTransformer2DModel,
|
||||||
|
controlnet: ZImageControlNetModel,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
if transformer.in_channels == controlnet.config.control_in_dim:
|
||||||
|
raise ValueError(
|
||||||
|
"ZImageControlNetInpaintPipeline is not compatible with `alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union`, use `alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0`."
|
||||||
|
)
|
||||||
|
controlnet = ZImageControlNetModel.from_transformer(controlnet, transformer)
|
||||||
|
|
||||||
|
self.register_modules(
|
||||||
|
vae=vae,
|
||||||
|
text_encoder=text_encoder,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
scheduler=scheduler,
|
||||||
|
transformer=transformer,
|
||||||
|
controlnet=controlnet,
|
||||||
|
)
|
||||||
|
self.vae_scale_factor = (
|
||||||
|
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||||
|
)
|
||||||
|
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||||
|
self.mask_processor = VaeImageProcessor(
|
||||||
|
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def encode_prompt(
|
||||||
|
self,
|
||||||
|
prompt: Union[str, List[str]],
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
do_classifier_free_guidance: bool = True,
|
||||||
|
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||||
|
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
max_sequence_length: int = 512,
|
||||||
|
):
|
||||||
|
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||||
|
prompt_embeds = self._encode_prompt(
|
||||||
|
prompt=prompt,
|
||||||
|
device=device,
|
||||||
|
prompt_embeds=prompt_embeds,
|
||||||
|
max_sequence_length=max_sequence_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
if negative_prompt is None:
|
||||||
|
negative_prompt = ["" for _ in prompt]
|
||||||
|
else:
|
||||||
|
negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||||
|
assert len(prompt) == len(negative_prompt)
|
||||||
|
negative_prompt_embeds = self._encode_prompt(
|
||||||
|
prompt=negative_prompt,
|
||||||
|
device=device,
|
||||||
|
prompt_embeds=negative_prompt_embeds,
|
||||||
|
max_sequence_length=max_sequence_length,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
negative_prompt_embeds = []
|
||||||
|
return prompt_embeds, negative_prompt_embeds
|
||||||
|
|
||||||
|
def _encode_prompt(
|
||||||
|
self,
|
||||||
|
prompt: Union[str, List[str]],
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
max_sequence_length: int = 512,
|
||||||
|
) -> List[torch.FloatTensor]:
|
||||||
|
device = device or self._execution_device
|
||||||
|
|
||||||
|
if prompt_embeds is not None:
|
||||||
|
return prompt_embeds
|
||||||
|
|
||||||
|
if isinstance(prompt, str):
|
||||||
|
prompt = [prompt]
|
||||||
|
|
||||||
|
for i, prompt_item in enumerate(prompt):
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": prompt_item},
|
||||||
|
]
|
||||||
|
prompt_item = self.tokenizer.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
tokenize=False,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
enable_thinking=True,
|
||||||
|
)
|
||||||
|
prompt[i] = prompt_item
|
||||||
|
|
||||||
|
text_inputs = self.tokenizer(
|
||||||
|
prompt,
|
||||||
|
padding="max_length",
|
||||||
|
max_length=max_sequence_length,
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
text_input_ids = text_inputs.input_ids.to(device)
|
||||||
|
prompt_masks = text_inputs.attention_mask.to(device).bool()
|
||||||
|
|
||||||
|
prompt_embeds = self.text_encoder(
|
||||||
|
input_ids=text_input_ids,
|
||||||
|
attention_mask=prompt_masks,
|
||||||
|
output_hidden_states=True,
|
||||||
|
).hidden_states[-2]
|
||||||
|
|
||||||
|
embeddings_list = []
|
||||||
|
|
||||||
|
for i in range(len(prompt_embeds)):
|
||||||
|
embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
|
||||||
|
|
||||||
|
return embeddings_list
|
||||||
|
|
||||||
|
def prepare_latents(
|
||||||
|
self,
|
||||||
|
batch_size,
|
||||||
|
num_channels_latents,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
generator,
|
||||||
|
latents=None,
|
||||||
|
):
|
||||||
|
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||||
|
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||||
|
|
||||||
|
shape = (batch_size, num_channels_latents, height, width)
|
||||||
|
|
||||||
|
if latents is None:
|
||||||
|
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||||
|
else:
|
||||||
|
if latents.shape != shape:
|
||||||
|
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||||
|
latents = latents.to(device)
|
||||||
|
return latents
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
|
||||||
|
def prepare_image(
|
||||||
|
self,
|
||||||
|
image,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
batch_size,
|
||||||
|
num_images_per_prompt,
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
|
do_classifier_free_guidance=False,
|
||||||
|
guess_mode=False,
|
||||||
|
):
|
||||||
|
if isinstance(image, torch.Tensor):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
image = self.image_processor.preprocess(image, height=height, width=width)
|
||||||
|
|
||||||
|
image_batch_size = image.shape[0]
|
||||||
|
|
||||||
|
if image_batch_size == 1:
|
||||||
|
repeat_by = batch_size
|
||||||
|
else:
|
||||||
|
# image batch size is the same as prompt batch size
|
||||||
|
repeat_by = num_images_per_prompt
|
||||||
|
|
||||||
|
image = image.repeat_interleave(repeat_by, dim=0)
|
||||||
|
|
||||||
|
image = image.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
if do_classifier_free_guidance and not guess_mode:
|
||||||
|
image = torch.cat([image] * 2)
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
@property
|
||||||
|
def guidance_scale(self):
|
||||||
|
return self._guidance_scale
|
||||||
|
|
||||||
|
@property
|
||||||
|
def do_classifier_free_guidance(self):
|
||||||
|
return self._guidance_scale > 1
|
||||||
|
|
||||||
|
@property
|
||||||
|
def joint_attention_kwargs(self):
|
||||||
|
return self._joint_attention_kwargs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_timesteps(self):
|
||||||
|
return self._num_timesteps
|
||||||
|
|
||||||
|
@property
|
||||||
|
def interrupt(self):
|
||||||
|
return self._interrupt
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
prompt: Union[str, List[str]] = None,
|
||||||
|
height: Optional[int] = None,
|
||||||
|
width: Optional[int] = None,
|
||||||
|
num_inference_steps: int = 50,
|
||||||
|
sigmas: Optional[List[float]] = None,
|
||||||
|
guidance_scale: float = 5.0,
|
||||||
|
image: PipelineImageInput = None,
|
||||||
|
mask_image: PipelineImageInput = None,
|
||||||
|
control_image: PipelineImageInput = None,
|
||||||
|
controlnet_conditioning_scale: Union[float, List[float]] = 0.75,
|
||||||
|
cfg_normalization: bool = False,
|
||||||
|
cfg_truncation: float = 1.0,
|
||||||
|
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||||
|
num_images_per_prompt: Optional[int] = 1,
|
||||||
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||||
|
latents: Optional[torch.FloatTensor] = None,
|
||||||
|
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
output_type: Optional[str] = "pil",
|
||||||
|
return_dict: bool = True,
|
||||||
|
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||||
|
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||||
|
max_sequence_length: int = 512,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Function invoked when calling the pipeline for generation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (`str` or `List[str]`, *optional*):
|
||||||
|
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||||
|
instead.
|
||||||
|
height (`int`, *optional*, defaults to 1024):
|
||||||
|
The height in pixels of the generated image.
|
||||||
|
width (`int`, *optional*, defaults to 1024):
|
||||||
|
The width in pixels of the generated image.
|
||||||
|
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||||
|
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||||
|
expense of slower inference.
|
||||||
|
sigmas (`List[float]`, *optional*):
|
||||||
|
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||||
|
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||||
|
will be used.
|
||||||
|
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||||
|
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||||
|
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||||
|
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||||
|
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||||
|
usually at the expense of lower image quality.
|
||||||
|
cfg_normalization (`bool`, *optional*, defaults to False):
|
||||||
|
Whether to apply configuration normalization.
|
||||||
|
cfg_truncation (`float`, *optional*, defaults to 1.0):
|
||||||
|
The truncation value for configuration.
|
||||||
|
negative_prompt (`str` or `List[str]`, *optional*):
|
||||||
|
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||||
|
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||||
|
less than `1`).
|
||||||
|
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||||
|
The number of images to generate per prompt.
|
||||||
|
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||||
|
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||||
|
to make generation deterministic.
|
||||||
|
latents (`torch.FloatTensor`, *optional*):
|
||||||
|
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||||
|
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||||
|
tensor will be generated by sampling using the supplied random `generator`.
|
||||||
|
prompt_embeds (`List[torch.FloatTensor]`, *optional*):
|
||||||
|
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||||
|
provided, text embeddings will be generated from `prompt` input argument.
|
||||||
|
negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
|
||||||
|
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||||
|
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||||
|
argument.
|
||||||
|
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||||
|
The output format of the generate image. Choose between
|
||||||
|
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||||
|
return_dict (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain
|
||||||
|
tuple.
|
||||||
|
joint_attention_kwargs (`dict`, *optional*):
|
||||||
|
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||||
|
`self.processor` in
|
||||||
|
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||||
|
callback_on_step_end (`Callable`, *optional*):
|
||||||
|
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||||
|
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||||
|
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||||
|
`callback_on_step_end_tensor_inputs`.
|
||||||
|
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||||
|
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||||
|
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||||
|
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||||
|
max_sequence_length (`int`, *optional*, defaults to 512):
|
||||||
|
Maximum sequence length to use with the `prompt`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if
|
||||||
|
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
|
||||||
|
generated images.
|
||||||
|
"""
|
||||||
|
height = height or 1024
|
||||||
|
width = width or 1024
|
||||||
|
|
||||||
|
vae_scale = self.vae_scale_factor * 2
|
||||||
|
if height % vae_scale != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Height must be divisible by {vae_scale} (got {height}). "
|
||||||
|
f"Please adjust the height to a multiple of {vae_scale}."
|
||||||
|
)
|
||||||
|
if width % vae_scale != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Width must be divisible by {vae_scale} (got {width}). "
|
||||||
|
f"Please adjust the width to a multiple of {vae_scale}."
|
||||||
|
)
|
||||||
|
|
||||||
|
device = self._execution_device
|
||||||
|
|
||||||
|
self._guidance_scale = guidance_scale
|
||||||
|
self._joint_attention_kwargs = joint_attention_kwargs
|
||||||
|
self._interrupt = False
|
||||||
|
self._cfg_normalization = cfg_normalization
|
||||||
|
self._cfg_truncation = cfg_truncation
|
||||||
|
# 2. Define call parameters
|
||||||
|
if prompt is not None and isinstance(prompt, str):
|
||||||
|
batch_size = 1
|
||||||
|
elif prompt is not None and isinstance(prompt, list):
|
||||||
|
batch_size = len(prompt)
|
||||||
|
else:
|
||||||
|
batch_size = len(prompt_embeds)
|
||||||
|
|
||||||
|
# If prompt_embeds is provided and prompt is None, skip encoding
|
||||||
|
if prompt_embeds is not None and prompt is None:
|
||||||
|
if self.do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||||
|
raise ValueError(
|
||||||
|
"When `prompt_embeds` is provided without `prompt`, "
|
||||||
|
"`negative_prompt_embeds` must also be provided for classifier-free guidance."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
(
|
||||||
|
prompt_embeds,
|
||||||
|
negative_prompt_embeds,
|
||||||
|
) = self.encode_prompt(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||||
|
prompt_embeds=prompt_embeds,
|
||||||
|
negative_prompt_embeds=negative_prompt_embeds,
|
||||||
|
device=device,
|
||||||
|
max_sequence_length=max_sequence_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Prepare latent variables
|
||||||
|
num_channels_latents = self.transformer.in_channels
|
||||||
|
|
||||||
|
control_image = self.prepare_image(
|
||||||
|
image=control_image,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
batch_size=batch_size * num_images_per_prompt,
|
||||||
|
num_images_per_prompt=num_images_per_prompt,
|
||||||
|
device=device,
|
||||||
|
dtype=self.vae.dtype,
|
||||||
|
)
|
||||||
|
height, width = control_image.shape[-2:]
|
||||||
|
control_image = retrieve_latents(self.vae.encode(control_image), generator=generator, sample_mode="argmax")
|
||||||
|
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||||
|
control_image = control_image.unsqueeze(2)
|
||||||
|
|
||||||
|
mask_condition = self.mask_processor.preprocess(mask_image, height=height, width=width)
|
||||||
|
mask_condition = torch.tile(mask_condition, [1, 3, 1, 1]).to(
|
||||||
|
device=control_image.device, dtype=control_image.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
init_image = self.prepare_image(
|
||||||
|
image=image,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
batch_size=batch_size * num_images_per_prompt,
|
||||||
|
num_images_per_prompt=num_images_per_prompt,
|
||||||
|
device=device,
|
||||||
|
dtype=self.vae.dtype,
|
||||||
|
)
|
||||||
|
height, width = init_image.shape[-2:]
|
||||||
|
init_image = init_image * (mask_condition < 0.5)
|
||||||
|
init_image = retrieve_latents(self.vae.encode(init_image), generator=generator, sample_mode="argmax")
|
||||||
|
init_image = (init_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||||
|
init_image = init_image.unsqueeze(2)
|
||||||
|
|
||||||
|
mask_condition = F.interpolate(1 - mask_condition[:, :1], size=init_image.size()[-2:], mode="nearest").to(
|
||||||
|
device=control_image.device, dtype=control_image.dtype
|
||||||
|
)
|
||||||
|
mask_condition = mask_condition.unsqueeze(2)
|
||||||
|
|
||||||
|
control_image = torch.cat([control_image, mask_condition, init_image], dim=1)
|
||||||
|
|
||||||
|
latents = self.prepare_latents(
|
||||||
|
batch_size * num_images_per_prompt,
|
||||||
|
num_channels_latents,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
torch.float32,
|
||||||
|
device,
|
||||||
|
generator,
|
||||||
|
latents,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Repeat prompt_embeds for num_images_per_prompt
|
||||||
|
if num_images_per_prompt > 1:
|
||||||
|
prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
|
||||||
|
if self.do_classifier_free_guidance and negative_prompt_embeds:
|
||||||
|
negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
|
||||||
|
|
||||||
|
actual_batch_size = batch_size * num_images_per_prompt
|
||||||
|
image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
|
||||||
|
|
||||||
|
# 5. Prepare timesteps
|
||||||
|
mu = calculate_shift(
|
||||||
|
image_seq_len,
|
||||||
|
self.scheduler.config.get("base_image_seq_len", 256),
|
||||||
|
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||||
|
self.scheduler.config.get("base_shift", 0.5),
|
||||||
|
self.scheduler.config.get("max_shift", 1.15),
|
||||||
|
)
|
||||||
|
self.scheduler.sigma_min = 0.0
|
||||||
|
scheduler_kwargs = {"mu": mu}
|
||||||
|
timesteps, num_inference_steps = retrieve_timesteps(
|
||||||
|
self.scheduler,
|
||||||
|
num_inference_steps,
|
||||||
|
device,
|
||||||
|
sigmas=sigmas,
|
||||||
|
**scheduler_kwargs,
|
||||||
|
)
|
||||||
|
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||||
|
self._num_timesteps = len(timesteps)
|
||||||
|
|
||||||
|
# 6. Denoising loop
|
||||||
|
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||||
|
for i, t in enumerate(timesteps):
|
||||||
|
if self.interrupt:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||||
|
timestep = t.expand(latents.shape[0])
|
||||||
|
timestep = (1000 - timestep) / 1000
|
||||||
|
# Normalized time for time-aware config (0 at start, 1 at end)
|
||||||
|
t_norm = timestep[0].item()
|
||||||
|
|
||||||
|
# Handle cfg truncation
|
||||||
|
current_guidance_scale = self.guidance_scale
|
||||||
|
if (
|
||||||
|
self.do_classifier_free_guidance
|
||||||
|
and self._cfg_truncation is not None
|
||||||
|
and float(self._cfg_truncation) <= 1
|
||||||
|
):
|
||||||
|
if t_norm > self._cfg_truncation:
|
||||||
|
current_guidance_scale = 0.0
|
||||||
|
|
||||||
|
# Run CFG only if configured AND scale is non-zero
|
||||||
|
apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
|
||||||
|
|
||||||
|
if apply_cfg:
|
||||||
|
latents_typed = latents.to(self.transformer.dtype)
|
||||||
|
latent_model_input = latents_typed.repeat(2, 1, 1, 1)
|
||||||
|
prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
|
||||||
|
timestep_model_input = timestep.repeat(2)
|
||||||
|
else:
|
||||||
|
latent_model_input = latents.to(self.transformer.dtype)
|
||||||
|
prompt_embeds_model_input = prompt_embeds
|
||||||
|
timestep_model_input = timestep
|
||||||
|
|
||||||
|
latent_model_input = latent_model_input.unsqueeze(2)
|
||||||
|
latent_model_input_list = list(latent_model_input.unbind(dim=0))
|
||||||
|
|
||||||
|
controlnet_block_samples = self.controlnet(
|
||||||
|
latent_model_input_list,
|
||||||
|
timestep_model_input,
|
||||||
|
prompt_embeds_model_input,
|
||||||
|
control_image,
|
||||||
|
conditioning_scale=controlnet_conditioning_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
model_out_list = self.transformer(
|
||||||
|
latent_model_input_list,
|
||||||
|
timestep_model_input,
|
||||||
|
prompt_embeds_model_input,
|
||||||
|
controlnet_block_samples=controlnet_block_samples,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
if apply_cfg:
|
||||||
|
# Perform CFG
|
||||||
|
pos_out = model_out_list[:actual_batch_size]
|
||||||
|
neg_out = model_out_list[actual_batch_size:]
|
||||||
|
|
||||||
|
noise_pred = []
|
||||||
|
for j in range(actual_batch_size):
|
||||||
|
pos = pos_out[j].float()
|
||||||
|
neg = neg_out[j].float()
|
||||||
|
|
||||||
|
pred = pos + current_guidance_scale * (pos - neg)
|
||||||
|
|
||||||
|
# Renormalization
|
||||||
|
if self._cfg_normalization and float(self._cfg_normalization) > 0.0:
|
||||||
|
ori_pos_norm = torch.linalg.vector_norm(pos)
|
||||||
|
new_pos_norm = torch.linalg.vector_norm(pred)
|
||||||
|
max_new_norm = ori_pos_norm * float(self._cfg_normalization)
|
||||||
|
if new_pos_norm > max_new_norm:
|
||||||
|
pred = pred * (max_new_norm / new_pos_norm)
|
||||||
|
|
||||||
|
noise_pred.append(pred)
|
||||||
|
|
||||||
|
noise_pred = torch.stack(noise_pred, dim=0)
|
||||||
|
else:
|
||||||
|
noise_pred = torch.stack([t.float() for t in model_out_list], dim=0)
|
||||||
|
|
||||||
|
noise_pred = noise_pred.squeeze(2)
|
||||||
|
noise_pred = -noise_pred
|
||||||
|
|
||||||
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
|
latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
|
||||||
|
assert latents.dtype == torch.float32
|
||||||
|
|
||||||
|
if callback_on_step_end is not None:
|
||||||
|
callback_kwargs = {}
|
||||||
|
for k in callback_on_step_end_tensor_inputs:
|
||||||
|
callback_kwargs[k] = locals()[k]
|
||||||
|
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||||
|
|
||||||
|
latents = callback_outputs.pop("latents", latents)
|
||||||
|
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||||
|
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||||
|
|
||||||
|
# call the callback, if provided
|
||||||
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||||
|
progress_bar.update()
|
||||||
|
|
||||||
|
if output_type == "latent":
|
||||||
|
image = latents
|
||||||
|
|
||||||
|
else:
|
||||||
|
latents = latents.to(self.vae.dtype)
|
||||||
|
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||||
|
|
||||||
|
image = self.vae.decode(latents, return_dict=False)[0]
|
||||||
|
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||||
|
|
||||||
|
# Offload all models
|
||||||
|
self.maybe_free_model_hooks()
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (image,)
|
||||||
|
|
||||||
|
return ZImagePipelineOutput(images=image)
|
||||||
@@ -217,6 +217,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
rescale_betas_zero_snr: bool = False,
|
rescale_betas_zero_snr: bool = False,
|
||||||
use_dynamic_shifting: bool = False,
|
use_dynamic_shifting: bool = False,
|
||||||
time_shift_type: Literal["exponential"] = "exponential",
|
time_shift_type: Literal["exponential"] = "exponential",
|
||||||
|
sigma_min: Optional[float] = None,
|
||||||
|
sigma_max: Optional[float] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||||
@@ -350,7 +352,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
log_sigmas = np.log(sigmas)
|
log_sigmas = np.log(sigmas)
|
||||||
sigmas = np.flip(sigmas).copy()
|
sigmas = np.flip(sigmas).copy()
|
||||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
if self.config.use_flow_sigmas:
|
||||||
|
sigmas = sigmas / (sigmas + 1)
|
||||||
|
timesteps = (sigmas * self.config.num_train_timesteps).copy()
|
||||||
|
else:
|
||||||
|
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
||||||
|
|
||||||
if self.config.final_sigmas_type == "sigma_min":
|
if self.config.final_sigmas_type == "sigma_min":
|
||||||
sigma_last = sigmas[-1]
|
sigma_last = sigmas[-1]
|
||||||
elif self.config.final_sigmas_type == "zero":
|
elif self.config.final_sigmas_type == "zero":
|
||||||
|
|||||||
@@ -1777,6 +1777,21 @@ class WanVACETransformer3DModel(metaclass=DummyObject):
|
|||||||
requires_backends(cls, ["torch"])
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageControlNetModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class ZImageTransformer2DModel(metaclass=DummyObject):
|
class ZImageTransformer2DModel(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
|||||||
@@ -767,6 +767,21 @@ class ConsisIDPipeline(metaclass=DummyObject):
|
|||||||
requires_backends(cls, ["torch", "transformers"])
|
requires_backends(cls, ["torch", "transformers"])
|
||||||
|
|
||||||
|
|
||||||
|
class Cosmos2_5_PredictBasePipeline(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "transformers"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "transformers"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "transformers"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "transformers"])
|
||||||
|
|
||||||
|
|
||||||
class Cosmos2TextToImagePipeline(metaclass=DummyObject):
|
class Cosmos2TextToImagePipeline(metaclass=DummyObject):
|
||||||
_backends = ["torch", "transformers"]
|
_backends = ["torch", "transformers"]
|
||||||
|
|
||||||
@@ -2297,6 +2312,21 @@ class QwenImageInpaintPipeline(metaclass=DummyObject):
|
|||||||
requires_backends(cls, ["torch", "transformers"])
|
requires_backends(cls, ["torch", "transformers"])
|
||||||
|
|
||||||
|
|
||||||
|
class QwenImageLayeredPipeline(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "transformers"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "transformers"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "transformers"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "transformers"])
|
||||||
|
|
||||||
|
|
||||||
class QwenImagePipeline(metaclass=DummyObject):
|
class QwenImagePipeline(metaclass=DummyObject):
|
||||||
_backends = ["torch", "transformers"]
|
_backends = ["torch", "transformers"]
|
||||||
|
|
||||||
@@ -3842,6 +3872,36 @@ class WuerstchenPriorPipeline(metaclass=DummyObject):
|
|||||||
requires_backends(cls, ["torch", "transformers"])
|
requires_backends(cls, ["torch", "transformers"])
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageControlNetInpaintPipeline(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "transformers"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "transformers"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "transformers"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "transformers"])
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageControlNetPipeline(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "transformers"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "transformers"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "transformers"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "transformers"])
|
||||||
|
|
||||||
|
|
||||||
class ZImageImg2ImgPipeline(metaclass=DummyObject):
|
class ZImageImg2ImgPipeline(metaclass=DummyObject):
|
||||||
_backends = ["torch", "transformers"]
|
_backends = ["torch", "transformers"]
|
||||||
|
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ class DummyCosmosSafetyChecker(ModelMixin, ConfigMixin):
|
|||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self._dtype = torch.float32
|
self.register_buffer("_device_tracker", torch.zeros(1, dtype=torch.float32), persistent=False)
|
||||||
|
|
||||||
def check_text_safety(self, prompt: str) -> bool:
|
def check_text_safety(self, prompt: str) -> bool:
|
||||||
return True
|
return True
|
||||||
@@ -35,13 +35,14 @@ class DummyCosmosSafetyChecker(ModelMixin, ConfigMixin):
|
|||||||
def check_video_safety(self, frames: np.ndarray) -> np.ndarray:
|
def check_video_safety(self, frames: np.ndarray) -> np.ndarray:
|
||||||
return frames
|
return frames
|
||||||
|
|
||||||
def to(self, device: Union[str, torch.device] = None, dtype: torch.dtype = None) -> None:
|
def to(self, device: Union[str, torch.device] = None, dtype: torch.dtype = None):
|
||||||
self._dtype = dtype
|
module = super().to(device=device, dtype=dtype)
|
||||||
|
return module
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device(self) -> torch.device:
|
def device(self) -> torch.device:
|
||||||
return None
|
return self._device_tracker.device
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self) -> torch.dtype:
|
def dtype(self) -> torch.dtype:
|
||||||
return self._dtype
|
return self._device_tracker.dtype
|
||||||
|
|||||||
337
tests/pipelines/cosmos/test_cosmos2_5_predict.py
Normal file
337
tests/pipelines/cosmos/test_cosmos2_5_predict.py
Normal file
@@ -0,0 +1,337 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
|
||||||
|
|
||||||
|
from diffusers import (
|
||||||
|
AutoencoderKLWan,
|
||||||
|
Cosmos2_5_PredictBasePipeline,
|
||||||
|
CosmosTransformer3DModel,
|
||||||
|
UniPCMultistepScheduler,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ...testing_utils import enable_full_determinism, torch_device
|
||||||
|
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||||
|
from ..test_pipelines_common import PipelineTesterMixin, to_np
|
||||||
|
from .cosmos_guardrail import DummyCosmosSafetyChecker
|
||||||
|
|
||||||
|
|
||||||
|
enable_full_determinism()
|
||||||
|
|
||||||
|
|
||||||
|
class Cosmos2_5_PredictBaseWrapper(Cosmos2_5_PredictBasePipeline):
|
||||||
|
@staticmethod
|
||||||
|
def from_pretrained(*args, **kwargs):
|
||||||
|
if "safety_checker" not in kwargs or kwargs["safety_checker"] is None:
|
||||||
|
safety_checker = DummyCosmosSafetyChecker()
|
||||||
|
device_map = kwargs.get("device_map", "cpu")
|
||||||
|
torch_dtype = kwargs.get("torch_dtype")
|
||||||
|
if device_map is not None or torch_dtype is not None:
|
||||||
|
safety_checker = safety_checker.to(device_map, dtype=torch_dtype)
|
||||||
|
kwargs["safety_checker"] = safety_checker
|
||||||
|
return Cosmos2_5_PredictBasePipeline.from_pretrained(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class Cosmos2_5_PredictPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||||
|
pipeline_class = Cosmos2_5_PredictBaseWrapper
|
||||||
|
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||||
|
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||||
|
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||||
|
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||||
|
required_optional_params = frozenset(
|
||||||
|
[
|
||||||
|
"num_inference_steps",
|
||||||
|
"generator",
|
||||||
|
"latents",
|
||||||
|
"return_dict",
|
||||||
|
"callback_on_step_end",
|
||||||
|
"callback_on_step_end_tensor_inputs",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
supports_dduf = False
|
||||||
|
test_xformers_attention = False
|
||||||
|
test_layerwise_casting = True
|
||||||
|
test_group_offloading = True
|
||||||
|
|
||||||
|
def get_dummy_components(self):
|
||||||
|
torch.manual_seed(0)
|
||||||
|
transformer = CosmosTransformer3DModel(
|
||||||
|
in_channels=16 + 1,
|
||||||
|
out_channels=16,
|
||||||
|
num_attention_heads=2,
|
||||||
|
attention_head_dim=16,
|
||||||
|
num_layers=2,
|
||||||
|
mlp_ratio=2,
|
||||||
|
text_embed_dim=32,
|
||||||
|
adaln_lora_dim=4,
|
||||||
|
max_size=(4, 32, 32),
|
||||||
|
patch_size=(1, 2, 2),
|
||||||
|
rope_scale=(2.0, 1.0, 1.0),
|
||||||
|
concat_padding_mask=True,
|
||||||
|
extra_pos_embed_type="learnable",
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
vae = AutoencoderKLWan(
|
||||||
|
base_dim=3,
|
||||||
|
z_dim=16,
|
||||||
|
dim_mult=[1, 1, 1, 1],
|
||||||
|
num_res_blocks=1,
|
||||||
|
temperal_downsample=[False, True, True],
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
scheduler = UniPCMultistepScheduler()
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
config = Qwen2_5_VLConfig(
|
||||||
|
text_config={
|
||||||
|
"hidden_size": 16,
|
||||||
|
"intermediate_size": 16,
|
||||||
|
"num_hidden_layers": 2,
|
||||||
|
"num_attention_heads": 2,
|
||||||
|
"num_key_value_heads": 2,
|
||||||
|
"rope_scaling": {
|
||||||
|
"mrope_section": [1, 1, 2],
|
||||||
|
"rope_type": "default",
|
||||||
|
"type": "default",
|
||||||
|
},
|
||||||
|
"rope_theta": 1000000.0,
|
||||||
|
},
|
||||||
|
vision_config={
|
||||||
|
"depth": 2,
|
||||||
|
"hidden_size": 16,
|
||||||
|
"intermediate_size": 16,
|
||||||
|
"num_heads": 2,
|
||||||
|
"out_hidden_size": 16,
|
||||||
|
},
|
||||||
|
hidden_size=16,
|
||||||
|
vocab_size=152064,
|
||||||
|
vision_end_token_id=151653,
|
||||||
|
vision_start_token_id=151652,
|
||||||
|
vision_token_id=151654,
|
||||||
|
)
|
||||||
|
text_encoder = Qwen2_5_VLForConditionalGeneration(config)
|
||||||
|
tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
|
||||||
|
|
||||||
|
components = {
|
||||||
|
"transformer": transformer,
|
||||||
|
"vae": vae,
|
||||||
|
"scheduler": scheduler,
|
||||||
|
"text_encoder": text_encoder,
|
||||||
|
"tokenizer": tokenizer,
|
||||||
|
"safety_checker": DummyCosmosSafetyChecker(),
|
||||||
|
}
|
||||||
|
return components
|
||||||
|
|
||||||
|
def get_dummy_inputs(self, device, seed=0):
|
||||||
|
if str(device).startswith("mps"):
|
||||||
|
generator = torch.manual_seed(seed)
|
||||||
|
else:
|
||||||
|
generator = torch.Generator(device=device).manual_seed(seed)
|
||||||
|
|
||||||
|
inputs = {
|
||||||
|
"prompt": "dance monkey",
|
||||||
|
"negative_prompt": "bad quality",
|
||||||
|
"generator": generator,
|
||||||
|
"num_inference_steps": 2,
|
||||||
|
"guidance_scale": 3.0,
|
||||||
|
"height": 32,
|
||||||
|
"width": 32,
|
||||||
|
"num_frames": 3,
|
||||||
|
"max_sequence_length": 16,
|
||||||
|
"output_type": "pt",
|
||||||
|
}
|
||||||
|
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
def test_components_function(self):
|
||||||
|
init_components = self.get_dummy_components()
|
||||||
|
init_components = {k: v for k, v in init_components.items() if not isinstance(v, (str, int, float))}
|
||||||
|
pipe = self.pipeline_class(**init_components)
|
||||||
|
self.assertTrue(hasattr(pipe, "components"))
|
||||||
|
self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
|
||||||
|
|
||||||
|
def test_inference(self):
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
|
components = self.get_dummy_components()
|
||||||
|
pipe = self.pipeline_class(**components)
|
||||||
|
pipe.to(device)
|
||||||
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
inputs = self.get_dummy_inputs(device)
|
||||||
|
video = pipe(**inputs).frames
|
||||||
|
generated_video = video[0]
|
||||||
|
self.assertEqual(generated_video.shape, (3, 3, 32, 32))
|
||||||
|
self.assertTrue(torch.isfinite(generated_video).all())
|
||||||
|
|
||||||
|
def test_callback_inputs(self):
|
||||||
|
sig = inspect.signature(self.pipeline_class.__call__)
|
||||||
|
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
|
||||||
|
has_callback_step_end = "callback_on_step_end" in sig.parameters
|
||||||
|
|
||||||
|
if not (has_callback_tensor_inputs and has_callback_step_end):
|
||||||
|
return
|
||||||
|
|
||||||
|
components = self.get_dummy_components()
|
||||||
|
pipe = self.pipeline_class(**components)
|
||||||
|
pipe = pipe.to(torch_device)
|
||||||
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
self.assertTrue(
|
||||||
|
hasattr(pipe, "_callback_tensor_inputs"),
|
||||||
|
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
|
||||||
|
)
|
||||||
|
|
||||||
|
def callback_inputs_subset(pipe, i, t, callback_kwargs):
|
||||||
|
for tensor_name in callback_kwargs.keys():
|
||||||
|
assert tensor_name in pipe._callback_tensor_inputs
|
||||||
|
return callback_kwargs
|
||||||
|
|
||||||
|
def callback_inputs_all(pipe, i, t, callback_kwargs):
|
||||||
|
for tensor_name in pipe._callback_tensor_inputs:
|
||||||
|
assert tensor_name in callback_kwargs
|
||||||
|
for tensor_name in callback_kwargs.keys():
|
||||||
|
assert tensor_name in pipe._callback_tensor_inputs
|
||||||
|
return callback_kwargs
|
||||||
|
|
||||||
|
inputs = self.get_dummy_inputs(torch_device)
|
||||||
|
|
||||||
|
inputs["callback_on_step_end"] = callback_inputs_subset
|
||||||
|
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
|
||||||
|
_ = pipe(**inputs)[0]
|
||||||
|
|
||||||
|
inputs["callback_on_step_end"] = callback_inputs_all
|
||||||
|
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
|
||||||
|
_ = pipe(**inputs)[0]
|
||||||
|
|
||||||
|
def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
|
||||||
|
is_last = i == (pipe.num_timesteps - 1)
|
||||||
|
if is_last:
|
||||||
|
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
|
||||||
|
return callback_kwargs
|
||||||
|
|
||||||
|
inputs["callback_on_step_end"] = callback_inputs_change_tensor
|
||||||
|
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
|
||||||
|
output = pipe(**inputs)[0]
|
||||||
|
assert output.abs().sum() < 1e10
|
||||||
|
|
||||||
|
def test_inference_batch_single_identical(self):
|
||||||
|
self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=1e-2)
|
||||||
|
|
||||||
|
def test_attention_slicing_forward_pass(
|
||||||
|
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
|
||||||
|
):
|
||||||
|
if not getattr(self, "test_attention_slicing", True):
|
||||||
|
return
|
||||||
|
|
||||||
|
components = self.get_dummy_components()
|
||||||
|
pipe = self.pipeline_class(**components)
|
||||||
|
for component in pipe.components.values():
|
||||||
|
if hasattr(component, "set_default_attn_processor"):
|
||||||
|
component.set_default_attn_processor()
|
||||||
|
pipe.to(torch_device)
|
||||||
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
generator_device = "cpu"
|
||||||
|
inputs = self.get_dummy_inputs(generator_device)
|
||||||
|
output_without_slicing = pipe(**inputs)[0]
|
||||||
|
|
||||||
|
pipe.enable_attention_slicing(slice_size=1)
|
||||||
|
inputs = self.get_dummy_inputs(generator_device)
|
||||||
|
output_with_slicing1 = pipe(**inputs)[0]
|
||||||
|
|
||||||
|
pipe.enable_attention_slicing(slice_size=2)
|
||||||
|
inputs = self.get_dummy_inputs(generator_device)
|
||||||
|
output_with_slicing2 = pipe(**inputs)[0]
|
||||||
|
|
||||||
|
if test_max_difference:
|
||||||
|
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
|
||||||
|
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
|
||||||
|
self.assertLess(
|
||||||
|
max(max_diff1, max_diff2),
|
||||||
|
expected_max_diff,
|
||||||
|
"Attention slicing should not affect the inference results",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_save_load_optional_components(self, expected_max_difference=1e-4):
|
||||||
|
self.pipeline_class._optional_components.remove("safety_checker")
|
||||||
|
super().test_save_load_optional_components(expected_max_difference=expected_max_difference)
|
||||||
|
self.pipeline_class._optional_components.append("safety_checker")
|
||||||
|
|
||||||
|
def test_serialization_with_variants(self):
|
||||||
|
components = self.get_dummy_components()
|
||||||
|
pipe = self.pipeline_class(**components)
|
||||||
|
model_components = [
|
||||||
|
component_name
|
||||||
|
for component_name, component in pipe.components.items()
|
||||||
|
if isinstance(component, torch.nn.Module)
|
||||||
|
]
|
||||||
|
model_components.remove("safety_checker")
|
||||||
|
variant = "fp16"
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False)
|
||||||
|
|
||||||
|
with open(f"{tmpdir}/model_index.json", "r") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
|
for subfolder in os.listdir(tmpdir):
|
||||||
|
if not os.path.isfile(subfolder) and subfolder in model_components:
|
||||||
|
folder_path = os.path.join(tmpdir, subfolder)
|
||||||
|
is_folder = os.path.isdir(folder_path) and subfolder in config
|
||||||
|
assert is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path))
|
||||||
|
|
||||||
|
def test_torch_dtype_dict(self):
|
||||||
|
components = self.get_dummy_components()
|
||||||
|
if not components:
|
||||||
|
self.skipTest("No dummy components defined.")
|
||||||
|
|
||||||
|
pipe = self.pipeline_class(**components)
|
||||||
|
|
||||||
|
specified_key = next(iter(components.keys()))
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
|
||||||
|
pipe.save_pretrained(tmpdirname, safe_serialization=False)
|
||||||
|
torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16}
|
||||||
|
loaded_pipe = self.pipeline_class.from_pretrained(
|
||||||
|
tmpdirname, safety_checker=DummyCosmosSafetyChecker(), torch_dtype=torch_dtype_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
for name, component in loaded_pipe.components.items():
|
||||||
|
if name == "safety_checker":
|
||||||
|
continue
|
||||||
|
if isinstance(component, torch.nn.Module) and hasattr(component, "dtype"):
|
||||||
|
expected_dtype = torch_dtype_dict.get(name, torch_dtype_dict.get("default", torch.float32))
|
||||||
|
self.assertEqual(
|
||||||
|
component.dtype,
|
||||||
|
expected_dtype,
|
||||||
|
f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
|
||||||
|
)
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
"The pipeline should not be runnable without a safety checker. The test creates a pipeline without passing in "
|
||||||
|
"a safety checker, which makes the pipeline default to the actual Cosmos Guardrail. The Cosmos Guardrail is "
|
||||||
|
"too large and slow to run on CI."
|
||||||
|
)
|
||||||
|
def test_encode_prompt_works_in_isolation(self):
|
||||||
|
pass
|
||||||
@@ -399,3 +399,32 @@ class UniPCMultistepScheduler1DTest(UniPCMultistepSchedulerTest):
|
|||||||
|
|
||||||
def test_exponential_sigmas(self):
|
def test_exponential_sigmas(self):
|
||||||
self.check_over_configs(use_exponential_sigmas=True)
|
self.check_over_configs(use_exponential_sigmas=True)
|
||||||
|
|
||||||
|
def test_flow_and_karras_sigmas(self):
|
||||||
|
self.check_over_configs(use_flow_sigmas=True, use_karras_sigmas=True)
|
||||||
|
|
||||||
|
def test_flow_and_karras_sigmas_values(self):
|
||||||
|
num_train_timesteps = 1000
|
||||||
|
num_inference_steps = 5
|
||||||
|
scheduler = UniPCMultistepScheduler(
|
||||||
|
sigma_min=0.01,
|
||||||
|
sigma_max=200.0,
|
||||||
|
use_flow_sigmas=True,
|
||||||
|
use_karras_sigmas=True,
|
||||||
|
num_train_timesteps=num_train_timesteps,
|
||||||
|
)
|
||||||
|
scheduler.set_timesteps(num_inference_steps=num_inference_steps)
|
||||||
|
|
||||||
|
expected_sigmas = [
|
||||||
|
0.9950248599052429,
|
||||||
|
0.9787454605102539,
|
||||||
|
0.8774884343147278,
|
||||||
|
0.3604971766471863,
|
||||||
|
0.009900986216962337,
|
||||||
|
0.0, # 0 appended as default
|
||||||
|
]
|
||||||
|
expected_sigmas = torch.tensor(expected_sigmas)
|
||||||
|
expected_timesteps = (expected_sigmas * num_train_timesteps).to(torch.int64)
|
||||||
|
expected_timesteps = expected_timesteps[0:-1]
|
||||||
|
self.assertTrue(torch.allclose(scheduler.sigmas, expected_sigmas))
|
||||||
|
self.assertTrue(torch.all(expected_timesteps == scheduler.timesteps))
|
||||||
|
|||||||
Reference in New Issue
Block a user