mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-26 14:24:41 +08:00
Compare commits
11 Commits
enable-cp-
...
sayakpaul-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3e27287cff | ||
|
|
f6b6a7181e | ||
|
|
52766e6a69 | ||
|
|
973a077c6a | ||
|
|
0c4f6c9cff | ||
|
|
262ce19bff | ||
|
|
f7753b1bc8 | ||
|
|
b5309683cb | ||
|
|
55463f7ace | ||
|
|
f9c1e612fb | ||
|
|
87f7d11143 |
@@ -70,6 +70,12 @@ output.save("output.png")
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## Cosmos2_5_PredictBasePipeline
|
||||
|
||||
[[autodoc]] Cosmos2_5_PredictBasePipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## CosmosPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput
|
||||
|
||||
@@ -250,9 +250,6 @@ The code snippets available in [this](https://github.com/huggingface/diffusers/p
|
||||
|
||||
The general rule of thumb to keep in mind when preparing inputs for the VACE pipeline is that the input images, or frames of a video that you want to use for conditioning, should have a corresponding mask that is black in color. The black mask signifies that the model will not generate new content for that area, and only use those parts for conditioning the generation process. For parts/frames that should be generated by the model, the mask should be white in color.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Wan-Animate: Unified Character Animation and Replacement with Holistic Replication
|
||||
|
||||
[Wan-Animate](https://huggingface.co/papers/2509.14055) by the Wan Team.
|
||||
|
||||
@@ -21,8 +21,8 @@ from transformers import (
|
||||
BertModel,
|
||||
BertTokenizer,
|
||||
CLIPImageProcessor,
|
||||
MT5Tokenizer,
|
||||
T5EncoderModel,
|
||||
T5Tokenizer,
|
||||
)
|
||||
|
||||
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
@@ -260,7 +260,7 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
|
||||
The HunyuanDiT model designed by Tencent Hunyuan.
|
||||
text_encoder_2 (`T5EncoderModel`):
|
||||
The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
|
||||
tokenizer_2 (`MT5Tokenizer`):
|
||||
tokenizer_2 (`T5Tokenizer`):
|
||||
The tokenizer for the mT5 embedder.
|
||||
scheduler ([`DDPMScheduler`]):
|
||||
A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
|
||||
@@ -295,7 +295,7 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
requires_safety_checker: bool = True,
|
||||
text_encoder_2=T5EncoderModel,
|
||||
tokenizer_2=MT5Tokenizer,
|
||||
tokenizer_2=T5Tokenizer,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -1,11 +1,94 @@
|
||||
"""
|
||||
# Cosmos 2 Predict
|
||||
|
||||
Download checkpoint
|
||||
```bash
|
||||
hf download nvidia/Cosmos-Predict2-2B-Text2Image
|
||||
```
|
||||
|
||||
convert checkpoint
|
||||
```bash
|
||||
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2-2B-Text2Image/snapshots/acdb5fde992a73ef0355f287977d002cbfd127e0/model.pt
|
||||
|
||||
python scripts/convert_cosmos_to_diffusers.py \
|
||||
--transformer_ckpt_path $transformer_ckpt_path \
|
||||
--transformer_type Cosmos-2.0-Diffusion-2B-Text2Image \
|
||||
--text_encoder_path google-t5/t5-11b \
|
||||
--tokenizer_path google-t5/t5-11b \
|
||||
--vae_type wan2.1 \
|
||||
--output_path converted/cosmos-p2-t2i-2b \
|
||||
--save_pipeline
|
||||
```
|
||||
|
||||
# Cosmos 2.5 Predict
|
||||
|
||||
Download checkpoint
|
||||
```bash
|
||||
hf download nvidia/Cosmos-Predict2.5-2B
|
||||
```
|
||||
|
||||
Convert checkpoint
|
||||
```bash
|
||||
# pre-trained
|
||||
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/pre-trained/d20b7120-df3e-4911-919d-db6e08bad31c_ema_bf16.pt
|
||||
|
||||
python scripts/convert_cosmos_to_diffusers.py \
|
||||
--transformer_type Cosmos-2.5-Predict-Base-2B \
|
||||
--transformer_ckpt_path $transformer_ckpt_path \
|
||||
--vae_type wan2.1 \
|
||||
--output_path converted/2b/d20b7120-df3e-4911-919d-db6e08bad31c \
|
||||
--save_pipeline
|
||||
|
||||
# post-trained
|
||||
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/post-trained/81edfebe-bd6a-4039-8c1d-737df1a790bf_ema_bf16.pt
|
||||
|
||||
python scripts/convert_cosmos_to_diffusers.py \
|
||||
--transformer_type Cosmos-2.5-Predict-Base-2B \
|
||||
--transformer_ckpt_path $transformer_ckpt_path \
|
||||
--vae_type wan2.1 \
|
||||
--output_path converted/2b/81edfebe-bd6a-4039-8c1d-737df1a790bf \
|
||||
--save_pipeline
|
||||
```
|
||||
|
||||
## 14B
|
||||
|
||||
```bash
|
||||
hf download nvidia/Cosmos-Predict2.5-14B
|
||||
```
|
||||
|
||||
```bash
|
||||
# pre-trained
|
||||
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-14B/snapshots/71ebf3e8af30ecfe440bf0481115975fcc052b46/base/pre-trained/54937b8c-29de-4f04-862c-e67b04ec41e8_ema_bf16.pt
|
||||
|
||||
python scripts/convert_cosmos_to_diffusers.py \
|
||||
--transformer_type Cosmos-2.5-Predict-Base-14B \
|
||||
--transformer_ckpt_path $transformer_ckpt_path \
|
||||
--vae_type wan2.1 \
|
||||
--output_path converted/14b/54937b8c-29de-4f04-862c-e67b04ec41e8/ \
|
||||
--save_pipeline
|
||||
|
||||
# post-trained
|
||||
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-14B/snapshots/71ebf3e8af30ecfe440bf0481115975fcc052b46/base/post-trained/e21d2a49-4747-44c8-ba44-9f6f9243715f_ema_bf16.pt
|
||||
|
||||
python scripts/convert_cosmos_to_diffusers.py \
|
||||
--transformer_type Cosmos-2.5-Predict-Base-14B \
|
||||
--transformer_ckpt_path $transformer_ckpt_path \
|
||||
--vae_type wan2.1 \
|
||||
--output_path converted/14b/e21d2a49-4747-44c8-ba44-9f6f9243715f/ \
|
||||
--save_pipeline
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import pathlib
|
||||
import sys
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import T5EncoderModel, T5TokenizerFast
|
||||
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration, T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLCosmos,
|
||||
@@ -17,7 +100,9 @@ from diffusers import (
|
||||
CosmosVideoToWorldPipeline,
|
||||
EDMEulerScheduler,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
UniPCMultistepScheduler,
|
||||
)
|
||||
from diffusers.pipelines.cosmos.pipeline_cosmos2_5_predict import Cosmos2_5_PredictBasePipeline
|
||||
|
||||
|
||||
def remove_keys_(key: str, state_dict: Dict[str, Any]):
|
||||
@@ -233,6 +318,44 @@ TRANSFORMER_CONFIGS = {
|
||||
"concat_padding_mask": True,
|
||||
"extra_pos_embed_type": None,
|
||||
},
|
||||
"Cosmos-2.5-Predict-Base-2B": {
|
||||
"in_channels": 16 + 1,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 16,
|
||||
"attention_head_dim": 128,
|
||||
"num_layers": 28,
|
||||
"mlp_ratio": 4.0,
|
||||
"text_embed_dim": 1024,
|
||||
"adaln_lora_dim": 256,
|
||||
"max_size": (128, 240, 240),
|
||||
"patch_size": (1, 2, 2),
|
||||
"rope_scale": (1.0, 3.0, 3.0),
|
||||
"concat_padding_mask": True,
|
||||
# NOTE: source config has pos_emb_learnable: 'True' - but params are missing
|
||||
"extra_pos_embed_type": None,
|
||||
"use_crossattn_projection": True,
|
||||
"crossattn_proj_in_channels": 100352,
|
||||
"encoder_hidden_states_channels": 1024,
|
||||
},
|
||||
"Cosmos-2.5-Predict-Base-14B": {
|
||||
"in_channels": 16 + 1,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 40,
|
||||
"attention_head_dim": 128,
|
||||
"num_layers": 36,
|
||||
"mlp_ratio": 4.0,
|
||||
"text_embed_dim": 1024,
|
||||
"adaln_lora_dim": 256,
|
||||
"max_size": (128, 240, 240),
|
||||
"patch_size": (1, 2, 2),
|
||||
"rope_scale": (1.0, 3.0, 3.0),
|
||||
"concat_padding_mask": True,
|
||||
# NOTE: source config has pos_emb_learnable: 'True' - but params are missing
|
||||
"extra_pos_embed_type": None,
|
||||
"use_crossattn_projection": True,
|
||||
"crossattn_proj_in_channels": 100352,
|
||||
"encoder_hidden_states_channels": 1024,
|
||||
},
|
||||
}
|
||||
|
||||
VAE_KEYS_RENAME_DICT = {
|
||||
@@ -334,6 +457,9 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
|
||||
elif "Cosmos-2.0" in transformer_type:
|
||||
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
|
||||
elif "Cosmos-2.5" in transformer_type:
|
||||
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
|
||||
else:
|
||||
assert False
|
||||
|
||||
@@ -347,6 +473,7 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
|
||||
new_key = new_key.removeprefix(PREFIX_KEY)
|
||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
print(key, "->", new_key, flush=True)
|
||||
update_state_dict_(original_state_dict, key, new_key)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
@@ -355,6 +482,21 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
expected_keys = set(transformer.state_dict().keys())
|
||||
mapped_keys = set(original_state_dict.keys())
|
||||
missing_keys = expected_keys - mapped_keys
|
||||
unexpected_keys = mapped_keys - expected_keys
|
||||
if missing_keys:
|
||||
print(f"ERROR: missing keys ({len(missing_keys)} from state_dict:", flush=True, file=sys.stderr)
|
||||
for k in missing_keys:
|
||||
print(k)
|
||||
sys.exit(1)
|
||||
if unexpected_keys:
|
||||
print(f"ERROR: unexpected keys ({len(unexpected_keys)}) from state_dict:", flush=True, file=sys.stderr)
|
||||
for k in unexpected_keys:
|
||||
print(k)
|
||||
sys.exit(2)
|
||||
|
||||
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return transformer
|
||||
|
||||
@@ -444,6 +586,34 @@ def save_pipeline_cosmos_2_0(args, transformer, vae):
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
|
||||
def save_pipeline_cosmos2_5(args, transformer, vae):
|
||||
text_encoder_path = args.text_encoder_path or "nvidia/Cosmos-Reason1-7B"
|
||||
tokenizer_path = args.tokenizer_path or "Qwen/Qwen2.5-VL-7B-Instruct"
|
||||
|
||||
text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
text_encoder_path, torch_dtype="auto", device_map="cpu"
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
|
||||
scheduler = UniPCMultistepScheduler(
|
||||
use_karras_sigmas=True,
|
||||
use_flow_sigmas=True,
|
||||
prediction_type="flow_prediction",
|
||||
sigma_max=200.0,
|
||||
sigma_min=0.01,
|
||||
)
|
||||
|
||||
pipe = Cosmos2_5_PredictBasePipeline(
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
safety_checker=lambda *args, **kwargs: None,
|
||||
)
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
|
||||
@@ -451,10 +621,10 @@ def get_args():
|
||||
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae_type", type=str, default=None, choices=["none", *list(VAE_CONFIGS.keys())], help="Type of VAE"
|
||||
"--vae_type", type=str, default="wan2.1", choices=["wan2.1", *list(VAE_CONFIGS.keys())], help="Type of VAE"
|
||||
)
|
||||
parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
|
||||
parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
|
||||
parser.add_argument("--text_encoder_path", type=str, default=None)
|
||||
parser.add_argument("--tokenizer_path", type=str, default=None)
|
||||
parser.add_argument("--save_pipeline", action="store_true")
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
|
||||
@@ -477,8 +647,6 @@ if __name__ == "__main__":
|
||||
if args.save_pipeline:
|
||||
assert args.transformer_ckpt_path is not None
|
||||
assert args.vae_type is not None
|
||||
assert args.text_encoder_path is not None
|
||||
assert args.tokenizer_path is not None
|
||||
|
||||
if args.transformer_ckpt_path is not None:
|
||||
weights_only = "Cosmos-1.0" in args.transformer_type
|
||||
@@ -490,17 +658,26 @@ if __name__ == "__main__":
|
||||
if args.vae_type is not None:
|
||||
if "Cosmos-1.0" in args.transformer_type:
|
||||
vae = convert_vae(args.vae_type)
|
||||
else:
|
||||
elif "Cosmos-2.0" in args.transformer_type or "Cosmos-2.5" in args.transformer_type:
|
||||
vae = AutoencoderKLWan.from_pretrained(
|
||||
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
|
||||
)
|
||||
else:
|
||||
raise AssertionError(f"{args.transformer_type} not supported")
|
||||
|
||||
if not args.save_pipeline:
|
||||
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
if args.save_pipeline:
|
||||
if "Cosmos-1.0" in args.transformer_type:
|
||||
assert args.text_encoder_path is not None
|
||||
assert args.tokenizer_path is not None
|
||||
save_pipeline_cosmos_1_0(args, transformer, vae)
|
||||
elif "Cosmos-2.0" in args.transformer_type:
|
||||
assert args.text_encoder_path is not None
|
||||
assert args.tokenizer_path is not None
|
||||
save_pipeline_cosmos_2_0(args, transformer, vae)
|
||||
elif "Cosmos-2.5" in args.transformer_type:
|
||||
save_pipeline_cosmos2_5(args, transformer, vae)
|
||||
else:
|
||||
assert False
|
||||
raise AssertionError(f"{args.transformer_type} not supported")
|
||||
|
||||
@@ -279,6 +279,7 @@ else:
|
||||
"WanAnimateTransformer3DModel",
|
||||
"WanTransformer3DModel",
|
||||
"WanVACETransformer3DModel",
|
||||
"ZImageControlNetModel",
|
||||
"ZImageTransformer2DModel",
|
||||
"attention_backend",
|
||||
]
|
||||
@@ -462,6 +463,7 @@ else:
|
||||
"CogView4ControlPipeline",
|
||||
"CogView4Pipeline",
|
||||
"ConsisIDPipeline",
|
||||
"Cosmos2_5_PredictBasePipeline",
|
||||
"Cosmos2TextToImagePipeline",
|
||||
"Cosmos2VideoToWorldPipeline",
|
||||
"CosmosTextToWorldPipeline",
|
||||
@@ -564,6 +566,7 @@ else:
|
||||
"QwenImageEditPlusPipeline",
|
||||
"QwenImageImg2ImgPipeline",
|
||||
"QwenImageInpaintPipeline",
|
||||
"QwenImageLayeredPipeline",
|
||||
"QwenImagePipeline",
|
||||
"ReduxImageEncoder",
|
||||
"SanaControlNetPipeline",
|
||||
@@ -669,7 +672,10 @@ else:
|
||||
"WuerstchenCombinedPipeline",
|
||||
"WuerstchenDecoderPipeline",
|
||||
"WuerstchenPriorPipeline",
|
||||
"ZImageControlNetInpaintPipeline",
|
||||
"ZImageControlNetPipeline",
|
||||
"ZImageImg2ImgPipeline",
|
||||
"ZImageOmniPipeline",
|
||||
"ZImagePipeline",
|
||||
]
|
||||
)
|
||||
@@ -1016,6 +1022,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
WanAnimateTransformer3DModel,
|
||||
WanTransformer3DModel,
|
||||
WanVACETransformer3DModel,
|
||||
ZImageControlNetModel,
|
||||
ZImageTransformer2DModel,
|
||||
attention_backend,
|
||||
)
|
||||
@@ -1170,6 +1177,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CogView4ControlPipeline,
|
||||
CogView4Pipeline,
|
||||
ConsisIDPipeline,
|
||||
Cosmos2_5_PredictBasePipeline,
|
||||
Cosmos2TextToImagePipeline,
|
||||
Cosmos2VideoToWorldPipeline,
|
||||
CosmosTextToWorldPipeline,
|
||||
@@ -1272,6 +1280,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
QwenImageEditPlusPipeline,
|
||||
QwenImageImg2ImgPipeline,
|
||||
QwenImageInpaintPipeline,
|
||||
QwenImageLayeredPipeline,
|
||||
QwenImagePipeline,
|
||||
ReduxImageEncoder,
|
||||
SanaControlNetPipeline,
|
||||
@@ -1375,7 +1384,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
WuerstchenCombinedPipeline,
|
||||
WuerstchenDecoderPipeline,
|
||||
WuerstchenPriorPipeline,
|
||||
ZImageControlNetInpaintPipeline,
|
||||
ZImageControlNetPipeline,
|
||||
ZImageImg2ImgPipeline,
|
||||
ZImageOmniPipeline,
|
||||
ZImagePipeline,
|
||||
)
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ if is_torch_available():
|
||||
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
|
||||
from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
|
||||
from .guider_utils import BaseGuidance
|
||||
from .magnitude_aware_guidance import MagnitudeAwareGuidance
|
||||
from .perturbed_attention_guidance import PerturbedAttentionGuidance
|
||||
from .skip_layer_guidance import SkipLayerGuidance
|
||||
from .smoothed_energy_guidance import SmoothedEnergyGuidance
|
||||
|
||||
159
src/diffusers/guiders/magnitude_aware_guidance.py
Normal file
159
src/diffusers/guiders/magnitude_aware_guidance.py
Normal file
@@ -0,0 +1,159 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import register_to_config
|
||||
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..modular_pipelines.modular_pipeline import BlockState
|
||||
|
||||
|
||||
class MagnitudeAwareGuidance(BaseGuidance):
|
||||
"""
|
||||
Magnitude-Aware Mitigation for Boosted Guidance (MAMBO-G): https://huggingface.co/papers/2508.03442
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `10.0`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
alpha (`float`, defaults to `8.0`):
|
||||
The alpha parameter for the magnitude-aware guidance. Higher values cause more aggressive supression of
|
||||
guidance scale when the magnitude of the guidance update is large.
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
start (`float`, defaults to `0.0`):
|
||||
The fraction of the total number of denoising steps after which guidance starts.
|
||||
stop (`float`, defaults to `1.0`):
|
||||
The fraction of the total number of denoising steps after which guidance stops.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 10.0,
|
||||
alpha: float = 8.0,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
enabled: bool = True,
|
||||
):
|
||||
super().__init__(start, stop, enabled)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.alpha = alpha
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
|
||||
pred = None
|
||||
|
||||
if not self._is_mambo_g_enabled():
|
||||
pred = pred_cond
|
||||
else:
|
||||
pred = mambo_guidance(
|
||||
pred_cond,
|
||||
pred_uncond,
|
||||
self.guidance_scale,
|
||||
self.alpha,
|
||||
self.use_original_formulation,
|
||||
)
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
|
||||
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
|
||||
|
||||
@property
|
||||
def is_conditional(self) -> bool:
|
||||
return self._count_prepared == 1
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_mambo_g_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_mambo_g_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self._start * self._num_inference_steps)
|
||||
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||
|
||||
is_close = False
|
||||
if self.use_original_formulation:
|
||||
is_close = math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
is_close = math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
return is_within_range and not is_close
|
||||
|
||||
|
||||
def mambo_guidance(
|
||||
pred_cond: torch.Tensor,
|
||||
pred_uncond: torch.Tensor,
|
||||
guidance_scale: float,
|
||||
alpha: float = 8.0,
|
||||
use_original_formulation: bool = False,
|
||||
):
|
||||
dim = list(range(1, len(pred_cond.shape)))
|
||||
diff = pred_cond - pred_uncond
|
||||
ratio = torch.norm(diff, dim=dim, keepdim=True) / torch.norm(pred_uncond, dim=dim, keepdim=True)
|
||||
guidance_scale_final = (
|
||||
guidance_scale * torch.exp(-alpha * ratio)
|
||||
if use_original_formulation
|
||||
else 1.0 + (guidance_scale - 1.0) * torch.exp(-alpha * ratio)
|
||||
)
|
||||
pred = pred_cond if use_original_formulation else pred_uncond
|
||||
pred = pred + guidance_scale_final * diff
|
||||
|
||||
return pred
|
||||
@@ -49,6 +49,7 @@ from .single_file_utils import (
|
||||
convert_stable_cascade_unet_single_file_to_diffusers,
|
||||
convert_wan_transformer_to_diffusers,
|
||||
convert_wan_vae_to_diffusers,
|
||||
convert_z_image_controlnet_checkpoint_to_diffusers,
|
||||
convert_z_image_transformer_checkpoint_to_diffusers,
|
||||
create_controlnet_diffusers_config_from_ldm,
|
||||
create_unet_diffusers_config_from_ldm,
|
||||
@@ -172,11 +173,18 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"checkpoint_mapping_fn": convert_z_image_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"ZImageControlNetModel": {
|
||||
"checkpoint_mapping_fn": convert_z_image_controlnet_checkpoint_to_diffusers,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
|
||||
return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))
|
||||
model_state_dict_keys = set(model_state_dict.keys())
|
||||
checkpoint_state_dict_keys = set(checkpoint_state_dict.keys())
|
||||
is_subset = model_state_dict_keys.issubset(checkpoint_state_dict_keys)
|
||||
is_match = model_state_dict_keys == checkpoint_state_dict_keys
|
||||
return not (is_subset and is_match)
|
||||
|
||||
|
||||
def _get_single_file_loadable_mapping_class(cls):
|
||||
|
||||
@@ -121,6 +121,8 @@ CHECKPOINT_KEY_NAMES = {
|
||||
"instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
|
||||
"lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
|
||||
"z-image-turbo": "cap_embedder.0.weight",
|
||||
"z-image-turbo-controlnet": "control_all_x_embedder.2-1.weight",
|
||||
"z-image-turbo-controlnet-2.x": "control_layers.14.adaLN_modulation.0.weight",
|
||||
"sana": [
|
||||
"blocks.0.cross_attn.q_linear.weight",
|
||||
"blocks.0.cross_attn.q_linear.bias",
|
||||
@@ -220,6 +222,8 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"cosmos-2.0-v2w-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Video2World"},
|
||||
"cosmos-2.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Video2World"},
|
||||
"z-image-turbo": {"pretrained_model_name_or_path": "Tongyi-MAI/Z-Image-Turbo"},
|
||||
"z-image-turbo-controlnet": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union"},
|
||||
"z-image-turbo-controlnet-2.x": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.1"},
|
||||
}
|
||||
|
||||
# Use to configure model sample size when original config is provided
|
||||
@@ -779,6 +783,12 @@ def infer_diffusers_model_type(checkpoint):
|
||||
else:
|
||||
raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 2.0 model.")
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["z-image-turbo-controlnet-2.x"] in checkpoint:
|
||||
model_type = "z-image-turbo-controlnet-2.x"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["z-image-turbo-controlnet"] in checkpoint:
|
||||
model_type = "z-image-turbo-controlnet"
|
||||
|
||||
else:
|
||||
model_type = "v1"
|
||||
|
||||
@@ -3885,3 +3895,17 @@ def convert_z_image_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
handler_fn_inplace(key, converted_state_dict)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_z_image_controlnet_checkpoint_to_diffusers(checkpoint, config, **kwargs):
|
||||
if config["add_control_noise_refiner"] is None:
|
||||
return checkpoint
|
||||
elif config["add_control_noise_refiner"] == "control_noise_refiner":
|
||||
return checkpoint
|
||||
elif config["add_control_noise_refiner"] == "control_layers":
|
||||
converted_state_dict = {
|
||||
key: checkpoint.pop(key) for key in list(checkpoint.keys()) if not key.startswith("control_noise_refiner.")
|
||||
}
|
||||
return converted_state_dict
|
||||
else:
|
||||
raise ValueError("Unknown Z-Image Turbo ControlNet type.")
|
||||
|
||||
@@ -66,6 +66,7 @@ if is_torch_available():
|
||||
_import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"]
|
||||
_import_structure["controlnets.controlnet_union"] = ["ControlNetUnionModel"]
|
||||
_import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
|
||||
_import_structure["controlnets.controlnet_z_image"] = ["ZImageControlNetModel"]
|
||||
_import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"]
|
||||
_import_structure["controlnets.multicontrolnet_union"] = ["MultiControlNetUnionModel"]
|
||||
_import_structure["embeddings"] = ["ImageProjection"]
|
||||
@@ -181,6 +182,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
SD3MultiControlNetModel,
|
||||
SparseControlNetModel,
|
||||
UNetControlNetXSModel,
|
||||
ZImageControlNetModel,
|
||||
)
|
||||
from .embeddings import ImageProjection
|
||||
from .modeling_utils import ModelMixin
|
||||
|
||||
@@ -394,6 +394,7 @@ class QwenImageEncoder3d(nn.Module):
|
||||
attn_scales=[],
|
||||
temperal_downsample=[True, True, False],
|
||||
dropout=0.0,
|
||||
input_channels=3,
|
||||
non_linearity: str = "silu",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -410,7 +411,7 @@ class QwenImageEncoder3d(nn.Module):
|
||||
scale = 1.0
|
||||
|
||||
# init block
|
||||
self.conv_in = QwenImageCausalConv3d(3, dims[0], 3, padding=1)
|
||||
self.conv_in = QwenImageCausalConv3d(input_channels, dims[0], 3, padding=1)
|
||||
|
||||
# downsample blocks
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
@@ -570,6 +571,7 @@ class QwenImageDecoder3d(nn.Module):
|
||||
attn_scales=[],
|
||||
temperal_upsample=[False, True, True],
|
||||
dropout=0.0,
|
||||
input_channels=3,
|
||||
non_linearity: str = "silu",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -621,7 +623,7 @@ class QwenImageDecoder3d(nn.Module):
|
||||
|
||||
# output blocks
|
||||
self.norm_out = QwenImageRMS_norm(out_dim, images=False)
|
||||
self.conv_out = QwenImageCausalConv3d(out_dim, 3, 3, padding=1)
|
||||
self.conv_out = QwenImageCausalConv3d(out_dim, input_channels, 3, padding=1)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@@ -684,6 +686,7 @@ class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
attn_scales: List[float] = [],
|
||||
temperal_downsample: List[bool] = [False, True, True],
|
||||
dropout: float = 0.0,
|
||||
input_channels: int = 3,
|
||||
latents_mean: List[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921],
|
||||
latents_std: List[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160],
|
||||
) -> None:
|
||||
@@ -695,13 +698,13 @@ class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
self.temperal_upsample = temperal_downsample[::-1]
|
||||
|
||||
self.encoder = QwenImageEncoder3d(
|
||||
base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
|
||||
base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout, input_channels
|
||||
)
|
||||
self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1)
|
||||
|
||||
self.decoder = QwenImageDecoder3d(
|
||||
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
|
||||
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout, input_channels
|
||||
)
|
||||
|
||||
self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
|
||||
|
||||
@@ -19,6 +19,7 @@ if is_torch_available():
|
||||
)
|
||||
from .controlnet_union import ControlNetUnionModel
|
||||
from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel
|
||||
from .controlnet_z_image import ZImageControlNetModel
|
||||
from .multicontrolnet import MultiControlNetModel
|
||||
from .multicontrolnet_union import MultiControlNetUnionModel
|
||||
|
||||
|
||||
845
src/diffusers/models/controlnets/controlnet_z_image.py
Normal file
845
src/diffusers/models/controlnets/controlnet_z_image.py
Normal file
@@ -0,0 +1,845 @@
|
||||
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import List, Literal, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...models.attention_processor import Attention
|
||||
from ...models.normalization import RMSNorm
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..controlnets.controlnet import zero_module
|
||||
from ..modeling_utils import ModelMixin
|
||||
|
||||
|
||||
ADALN_EMBED_DIM = 256
|
||||
SEQ_MULTI_OF = 32
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_z_image.TimestepEmbedder
|
||||
class TimestepEmbedder(nn.Module):
|
||||
def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
if mid_size is None:
|
||||
mid_size = out_size
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, mid_size, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(mid_size, out_size, bias=True),
|
||||
)
|
||||
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
with torch.amp.autocast("cuda", enabled=False):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
|
||||
)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
weight_dtype = self.mlp[0].weight.dtype
|
||||
compute_dtype = getattr(self.mlp[0], "compute_dtype", None)
|
||||
if weight_dtype.is_floating_point:
|
||||
t_freq = t_freq.to(weight_dtype)
|
||||
elif compute_dtype is not None:
|
||||
t_freq = t_freq.to(compute_dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_z_image.ZSingleStreamAttnProcessor
|
||||
class ZSingleStreamAttnProcessor:
|
||||
"""
|
||||
Processor for Z-Image single stream attention that adapts the existing Attention class to match the behavior of the
|
||||
original Z-ImageAttention module.
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
query = query.unflatten(-1, (attn.heads, -1))
|
||||
key = key.unflatten(-1, (attn.heads, -1))
|
||||
value = value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
# Apply Norms
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Apply RoPE
|
||||
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
with torch.amp.autocast("cuda", enabled=False):
|
||||
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
|
||||
freqs_cis = freqs_cis.unsqueeze(2)
|
||||
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
|
||||
return x_out.type_as(x_in) # todo
|
||||
|
||||
if freqs_cis is not None:
|
||||
query = apply_rotary_emb(query, freqs_cis)
|
||||
key = apply_rotary_emb(key, freqs_cis)
|
||||
|
||||
# Cast to correct dtype
|
||||
dtype = query.dtype
|
||||
query, key = query.to(dtype), key.to(dtype)
|
||||
|
||||
# From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
|
||||
if attention_mask is not None and attention_mask.ndim == 2:
|
||||
attention_mask = attention_mask[:, None, None, :]
|
||||
|
||||
# Compute joint attention
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
|
||||
# Reshape back
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
|
||||
output = attn.to_out[0](hidden_states)
|
||||
if len(attn.to_out) > 1: # dropout
|
||||
output = attn.to_out[1](output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_z_image.FeedForward
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim: int, hidden_dim: int):
|
||||
super().__init__()
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def _forward_silu_gating(self, x1, x3):
|
||||
return F.silu(x1) * x3
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_z_image.select_per_token
|
||||
def select_per_token(
|
||||
value_noisy: torch.Tensor,
|
||||
value_clean: torch.Tensor,
|
||||
noise_mask: torch.Tensor,
|
||||
seq_len: int,
|
||||
) -> torch.Tensor:
|
||||
noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1)
|
||||
return torch.where(
|
||||
noise_mask_expanded == 1,
|
||||
value_noisy.unsqueeze(1).expand(-1, seq_len, -1),
|
||||
value_clean.unsqueeze(1).expand(-1, seq_len, -1),
|
||||
)
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
# Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformerBlock
|
||||
class ZImageTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
layer_id: int,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: int,
|
||||
norm_eps: float,
|
||||
qk_norm: bool,
|
||||
modulation=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.head_dim = dim // n_heads
|
||||
|
||||
# Refactored to use diffusers Attention with custom processor
|
||||
# Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm
|
||||
self.attention = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
dim_head=dim // n_heads,
|
||||
heads=n_heads,
|
||||
qk_norm="rms_norm" if qk_norm else None,
|
||||
eps=1e-5,
|
||||
bias=False,
|
||||
out_bias=False,
|
||||
processor=ZSingleStreamAttnProcessor(),
|
||||
)
|
||||
|
||||
self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
|
||||
self.layer_id = layer_id
|
||||
|
||||
self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
|
||||
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
|
||||
|
||||
self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
|
||||
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
|
||||
|
||||
self.modulation = modulation
|
||||
if modulation:
|
||||
self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attn_mask: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
adaln_input: Optional[torch.Tensor] = None,
|
||||
noise_mask: Optional[torch.Tensor] = None,
|
||||
adaln_noisy: Optional[torch.Tensor] = None,
|
||||
adaln_clean: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if self.modulation:
|
||||
seq_len = x.shape[1]
|
||||
|
||||
if noise_mask is not None:
|
||||
# Per-token modulation: different modulation for noisy/clean tokens
|
||||
mod_noisy = self.adaLN_modulation(adaln_noisy)
|
||||
mod_clean = self.adaLN_modulation(adaln_clean)
|
||||
|
||||
scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1)
|
||||
scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1)
|
||||
|
||||
gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh()
|
||||
gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh()
|
||||
|
||||
scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy
|
||||
scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean
|
||||
|
||||
scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len)
|
||||
scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len)
|
||||
gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len)
|
||||
gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len)
|
||||
else:
|
||||
# Global modulation: same modulation for all tokens (avoid double select)
|
||||
mod = self.adaLN_modulation(adaln_input)
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2)
|
||||
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
|
||||
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
|
||||
|
||||
# Attention block
|
||||
attn_out = self.attention(
|
||||
self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis
|
||||
)
|
||||
x = x + gate_msa * self.attention_norm2(attn_out)
|
||||
|
||||
# FFN block
|
||||
x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp))
|
||||
else:
|
||||
# Attention block
|
||||
attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis)
|
||||
x = x + self.attention_norm2(attn_out)
|
||||
|
||||
# FFN block
|
||||
x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_z_image.RopeEmbedder
|
||||
class RopeEmbedder:
|
||||
def __init__(
|
||||
self,
|
||||
theta: float = 256.0,
|
||||
axes_dims: List[int] = (16, 56, 56),
|
||||
axes_lens: List[int] = (64, 128, 128),
|
||||
):
|
||||
self.theta = theta
|
||||
self.axes_dims = axes_dims
|
||||
self.axes_lens = axes_lens
|
||||
assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length"
|
||||
self.freqs_cis = None
|
||||
|
||||
@staticmethod
|
||||
def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0):
|
||||
with torch.device("cpu"):
|
||||
freqs_cis = []
|
||||
for i, (d, e) in enumerate(zip(dim, end)):
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
|
||||
timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
|
||||
freqs = torch.outer(timestep, freqs).float()
|
||||
freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64
|
||||
freqs_cis.append(freqs_cis_i)
|
||||
|
||||
return freqs_cis
|
||||
|
||||
def __call__(self, ids: torch.Tensor):
|
||||
assert ids.ndim == 2
|
||||
assert ids.shape[-1] == len(self.axes_dims)
|
||||
device = ids.device
|
||||
|
||||
if self.freqs_cis is None:
|
||||
self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
|
||||
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
|
||||
else:
|
||||
# Ensure freqs_cis are on the same device as ids
|
||||
if self.freqs_cis[0].device != device:
|
||||
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
|
||||
|
||||
result = []
|
||||
for i in range(len(self.axes_dims)):
|
||||
index = ids[:, i]
|
||||
result.append(self.freqs_cis[i][index])
|
||||
return torch.cat(result, dim=-1)
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class ZImageControlTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
layer_id: int,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: int,
|
||||
norm_eps: float,
|
||||
qk_norm: bool,
|
||||
modulation=True,
|
||||
block_id=0,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.head_dim = dim // n_heads
|
||||
|
||||
# Refactored to use diffusers Attention with custom processor
|
||||
# Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm
|
||||
self.attention = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
dim_head=dim // n_heads,
|
||||
heads=n_heads,
|
||||
qk_norm="rms_norm" if qk_norm else None,
|
||||
eps=1e-5,
|
||||
bias=False,
|
||||
out_bias=False,
|
||||
processor=ZSingleStreamAttnProcessor(),
|
||||
)
|
||||
|
||||
self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
|
||||
self.layer_id = layer_id
|
||||
|
||||
self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
|
||||
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
|
||||
|
||||
self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
|
||||
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
|
||||
|
||||
self.modulation = modulation
|
||||
if modulation:
|
||||
self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True))
|
||||
|
||||
# Control variant start
|
||||
self.block_id = block_id
|
||||
if block_id == 0:
|
||||
self.before_proj = zero_module(nn.Linear(self.dim, self.dim))
|
||||
self.after_proj = zero_module(nn.Linear(self.dim, self.dim))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
c: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
attn_mask: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
adaln_input: Optional[torch.Tensor] = None,
|
||||
):
|
||||
# Control
|
||||
if self.block_id == 0:
|
||||
c = self.before_proj(c) + x
|
||||
all_c = []
|
||||
else:
|
||||
all_c = list(torch.unbind(c))
|
||||
c = all_c.pop(-1)
|
||||
|
||||
# Compared to `ZImageTransformerBlock` x -> c
|
||||
if self.modulation:
|
||||
assert adaln_input is not None
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
|
||||
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
|
||||
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
|
||||
|
||||
# Attention block
|
||||
attn_out = self.attention(
|
||||
self.attention_norm1(c) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis
|
||||
)
|
||||
c = c + gate_msa * self.attention_norm2(attn_out)
|
||||
|
||||
# FFN block
|
||||
c = c + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(c) * scale_mlp))
|
||||
else:
|
||||
# Attention block
|
||||
attn_out = self.attention(self.attention_norm1(c), attention_mask=attn_mask, freqs_cis=freqs_cis)
|
||||
c = c + self.attention_norm2(attn_out)
|
||||
|
||||
# FFN block
|
||||
c = c + self.ffn_norm2(self.feed_forward(self.ffn_norm1(c)))
|
||||
|
||||
# Control
|
||||
c_skip = self.after_proj(c)
|
||||
all_c += [c_skip, c]
|
||||
c = torch.stack(all_c)
|
||||
return c
|
||||
|
||||
|
||||
class ZImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
control_layers_places: List[int] = None,
|
||||
control_refiner_layers_places: List[int] = None,
|
||||
control_in_dim=None,
|
||||
add_control_noise_refiner: Optional[Literal["control_layers", "control_noise_refiner"]] = None,
|
||||
all_patch_size=(2,),
|
||||
all_f_patch_size=(1,),
|
||||
dim=3840,
|
||||
n_refiner_layers=2,
|
||||
n_heads=30,
|
||||
n_kv_heads=30,
|
||||
norm_eps=1e-5,
|
||||
qk_norm=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.control_layers_places = control_layers_places
|
||||
self.control_in_dim = control_in_dim
|
||||
self.control_refiner_layers_places = control_refiner_layers_places
|
||||
self.add_control_noise_refiner = add_control_noise_refiner
|
||||
|
||||
assert 0 in self.control_layers_places
|
||||
|
||||
# control blocks
|
||||
self.control_layers = nn.ModuleList(
|
||||
[
|
||||
ZImageControlTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=i)
|
||||
for i in self.control_layers_places
|
||||
]
|
||||
)
|
||||
|
||||
# control patch embeddings
|
||||
all_x_embedder = {}
|
||||
for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)):
|
||||
x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True)
|
||||
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
|
||||
|
||||
self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
|
||||
if self.add_control_noise_refiner == "control_layers":
|
||||
self.control_noise_refiner = None
|
||||
elif self.add_control_noise_refiner == "control_noise_refiner":
|
||||
self.control_noise_refiner = nn.ModuleList(
|
||||
[
|
||||
ZImageControlTransformerBlock(
|
||||
1000 + layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=True,
|
||||
block_id=layer_id,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.control_noise_refiner = nn.ModuleList(
|
||||
[
|
||||
ZImageTransformerBlock(
|
||||
1000 + layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=True,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.t_scale: Optional[float] = None
|
||||
self.t_embedder: Optional[TimestepEmbedder] = None
|
||||
self.all_x_embedder: Optional[nn.ModuleDict] = None
|
||||
self.cap_embedder: Optional[nn.Sequential] = None
|
||||
self.rope_embedder: Optional[RopeEmbedder] = None
|
||||
self.noise_refiner: Optional[nn.ModuleList] = None
|
||||
self.context_refiner: Optional[nn.ModuleList] = None
|
||||
self.x_pad_token: Optional[nn.Parameter] = None
|
||||
self.cap_pad_token: Optional[nn.Parameter] = None
|
||||
|
||||
@classmethod
|
||||
def from_transformer(cls, controlnet, transformer):
|
||||
controlnet.t_scale = transformer.t_scale
|
||||
controlnet.t_embedder = transformer.t_embedder
|
||||
controlnet.all_x_embedder = transformer.all_x_embedder
|
||||
controlnet.cap_embedder = transformer.cap_embedder
|
||||
controlnet.rope_embedder = transformer.rope_embedder
|
||||
controlnet.noise_refiner = transformer.noise_refiner
|
||||
controlnet.context_refiner = transformer.context_refiner
|
||||
controlnet.x_pad_token = transformer.x_pad_token
|
||||
controlnet.cap_pad_token = transformer.cap_pad_token
|
||||
return controlnet
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel.create_coordinate_grid
|
||||
def create_coordinate_grid(size, start=None, device=None):
|
||||
if start is None:
|
||||
start = (0 for _ in size)
|
||||
axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
|
||||
grids = torch.meshgrid(axes, indexing="ij")
|
||||
return torch.stack(grids, dim=-1)
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel._patchify_image
|
||||
def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int):
|
||||
"""Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim)."""
|
||||
pH, pW, pF = patch_size, patch_size, f_patch_size
|
||||
C, F, H, W = image.size()
|
||||
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
|
||||
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
|
||||
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
|
||||
return image, (F, H, W), (F_tokens, H_tokens, W_tokens)
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel._pad_with_ids
|
||||
def _pad_with_ids(
|
||||
self,
|
||||
feat: torch.Tensor,
|
||||
pos_grid_size: Tuple,
|
||||
pos_start: Tuple,
|
||||
device: torch.device,
|
||||
noise_mask_val: Optional[int] = None,
|
||||
):
|
||||
"""Pad feature to SEQ_MULTI_OF, create position IDs and pad mask."""
|
||||
ori_len = len(feat)
|
||||
pad_len = (-ori_len) % SEQ_MULTI_OF
|
||||
total_len = ori_len + pad_len
|
||||
|
||||
# Pos IDs
|
||||
ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2)
|
||||
if pad_len > 0:
|
||||
pad_pos_ids = (
|
||||
self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
|
||||
.flatten(0, 2)
|
||||
.repeat(pad_len, 1)
|
||||
)
|
||||
pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0)
|
||||
padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0)
|
||||
pad_mask = torch.cat(
|
||||
[
|
||||
torch.zeros(ori_len, dtype=torch.bool, device=device),
|
||||
torch.ones(pad_len, dtype=torch.bool, device=device),
|
||||
]
|
||||
)
|
||||
else:
|
||||
pos_ids = ori_pos_ids
|
||||
padded_feat = feat
|
||||
pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device)
|
||||
|
||||
noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None # token level
|
||||
return padded_feat, pos_ids, pad_mask, total_len, noise_mask
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel.patchify_and_embed
|
||||
def patchify_and_embed(
|
||||
self, all_image: List[torch.Tensor], all_cap_feats: List[torch.Tensor], patch_size: int, f_patch_size: int
|
||||
):
|
||||
"""Patchify for basic mode: single image per batch item."""
|
||||
device = all_image[0].device
|
||||
all_img_out, all_img_size, all_img_pos_ids, all_img_pad_mask = [], [], [], []
|
||||
all_cap_out, all_cap_pos_ids, all_cap_pad_mask = [], [], []
|
||||
|
||||
for image, cap_feat in zip(all_image, all_cap_feats):
|
||||
# Caption
|
||||
cap_out, cap_pos_ids, cap_pad_mask, cap_len, _ = self._pad_with_ids(
|
||||
cap_feat, (len(cap_feat) + (-len(cap_feat)) % SEQ_MULTI_OF, 1, 1), (1, 0, 0), device
|
||||
)
|
||||
all_cap_out.append(cap_out)
|
||||
all_cap_pos_ids.append(cap_pos_ids)
|
||||
all_cap_pad_mask.append(cap_pad_mask)
|
||||
|
||||
# Image
|
||||
img_patches, size, (F_t, H_t, W_t) = self._patchify_image(image, patch_size, f_patch_size)
|
||||
img_out, img_pos_ids, img_pad_mask, _, _ = self._pad_with_ids(
|
||||
img_patches, (F_t, H_t, W_t), (cap_len + 1, 0, 0), device
|
||||
)
|
||||
all_img_out.append(img_out)
|
||||
all_img_size.append(size)
|
||||
all_img_pos_ids.append(img_pos_ids)
|
||||
all_img_pad_mask.append(img_pad_mask)
|
||||
|
||||
return (
|
||||
all_img_out,
|
||||
all_cap_out,
|
||||
all_img_size,
|
||||
all_img_pos_ids,
|
||||
all_cap_pos_ids,
|
||||
all_img_pad_mask,
|
||||
all_cap_pad_mask,
|
||||
)
|
||||
|
||||
def patchify(
|
||||
self,
|
||||
all_image: List[torch.Tensor],
|
||||
patch_size: int,
|
||||
f_patch_size: int,
|
||||
):
|
||||
pH = pW = patch_size
|
||||
pF = f_patch_size
|
||||
all_image_out = []
|
||||
|
||||
for i, image in enumerate(all_image):
|
||||
### Process Image
|
||||
C, F, H, W = image.size()
|
||||
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
|
||||
|
||||
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
|
||||
# "c f pf h ph w pw -> (f h w) (pf ph pw c)"
|
||||
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
|
||||
|
||||
image_ori_len = len(image)
|
||||
image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
|
||||
|
||||
# padded feature
|
||||
image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
|
||||
all_image_out.append(image_padded_feat)
|
||||
|
||||
return all_image_out
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: List[torch.Tensor],
|
||||
t,
|
||||
cap_feats: List[torch.Tensor],
|
||||
control_context: List[torch.Tensor],
|
||||
conditioning_scale: float = 1.0,
|
||||
patch_size=2,
|
||||
f_patch_size=1,
|
||||
):
|
||||
if (
|
||||
self.t_scale is None
|
||||
or self.t_embedder is None
|
||||
or self.all_x_embedder is None
|
||||
or self.cap_embedder is None
|
||||
or self.rope_embedder is None
|
||||
or self.noise_refiner is None
|
||||
or self.context_refiner is None
|
||||
or self.x_pad_token is None
|
||||
or self.cap_pad_token is None
|
||||
):
|
||||
raise ValueError(
|
||||
"Required modules are `None`, use `from_transformer` to share required modules from `transformer`."
|
||||
)
|
||||
|
||||
assert patch_size in self.config.all_patch_size
|
||||
assert f_patch_size in self.config.all_f_patch_size
|
||||
|
||||
bsz = len(x)
|
||||
device = x[0].device
|
||||
t = t * self.t_scale
|
||||
t = self.t_embedder(t)
|
||||
|
||||
(
|
||||
x,
|
||||
cap_feats,
|
||||
x_size,
|
||||
x_pos_ids,
|
||||
cap_pos_ids,
|
||||
x_inner_pad_mask,
|
||||
cap_inner_pad_mask,
|
||||
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
|
||||
|
||||
x_item_seqlens = [len(_) for _ in x]
|
||||
assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
|
||||
x_max_item_seqlen = max(x_item_seqlens)
|
||||
|
||||
control_context = self.patchify(control_context, patch_size, f_patch_size)
|
||||
control_context = torch.cat(control_context, dim=0)
|
||||
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context)
|
||||
|
||||
control_context[torch.cat(x_inner_pad_mask)] = self.x_pad_token
|
||||
control_context = list(control_context.split(x_item_seqlens, dim=0))
|
||||
|
||||
control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0)
|
||||
|
||||
# x embed & refine
|
||||
x = torch.cat(x, dim=0)
|
||||
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
|
||||
|
||||
# Match t_embedder output dtype to x for layerwise casting compatibility
|
||||
adaln_input = t.type_as(x)
|
||||
x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
|
||||
x = list(x.split(x_item_seqlens, dim=0))
|
||||
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0))
|
||||
|
||||
x = pad_sequence(x, batch_first=True, padding_value=0.0)
|
||||
x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
# Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors
|
||||
x_freqs_cis = x_freqs_cis[:, : x.shape[1]]
|
||||
|
||||
x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(x_item_seqlens):
|
||||
x_attn_mask[i, :seq_len] = 1
|
||||
|
||||
if self.add_control_noise_refiner is not None:
|
||||
if self.add_control_noise_refiner == "control_layers":
|
||||
layers = self.control_layers
|
||||
elif self.add_control_noise_refiner == "control_noise_refiner":
|
||||
layers = self.control_noise_refiner
|
||||
else:
|
||||
raise ValueError(f"Unsupported `add_control_noise_refiner` type: {self.add_control_noise_refiner}.")
|
||||
for layer in layers:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
control_context = self._gradient_checkpointing_func(
|
||||
layer, control_context, x, x_attn_mask, x_freqs_cis, adaln_input
|
||||
)
|
||||
else:
|
||||
control_context = layer(control_context, x, x_attn_mask, x_freqs_cis, adaln_input)
|
||||
|
||||
hints = torch.unbind(control_context)[:-1]
|
||||
control_context = torch.unbind(control_context)[-1]
|
||||
noise_refiner_block_samples = {
|
||||
layer_idx: hints[idx] * conditioning_scale
|
||||
for idx, layer_idx in enumerate(self.control_refiner_layers_places)
|
||||
}
|
||||
else:
|
||||
noise_refiner_block_samples = None
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for layer_idx, layer in enumerate(self.noise_refiner):
|
||||
x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input)
|
||||
if noise_refiner_block_samples is not None:
|
||||
if layer_idx in noise_refiner_block_samples:
|
||||
x = x + noise_refiner_block_samples[layer_idx]
|
||||
else:
|
||||
for layer_idx, layer in enumerate(self.noise_refiner):
|
||||
x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)
|
||||
if noise_refiner_block_samples is not None:
|
||||
if layer_idx in noise_refiner_block_samples:
|
||||
x = x + noise_refiner_block_samples[layer_idx]
|
||||
|
||||
# cap embed & refine
|
||||
cap_item_seqlens = [len(_) for _ in cap_feats]
|
||||
cap_max_item_seqlen = max(cap_item_seqlens)
|
||||
|
||||
cap_feats = torch.cat(cap_feats, dim=0)
|
||||
cap_feats = self.cap_embedder(cap_feats)
|
||||
cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
|
||||
cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
|
||||
cap_freqs_cis = list(
|
||||
self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0)
|
||||
)
|
||||
|
||||
cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
|
||||
cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
# Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors
|
||||
cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]]
|
||||
|
||||
cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(cap_item_seqlens):
|
||||
cap_attn_mask[i, :seq_len] = 1
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis)
|
||||
else:
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)
|
||||
|
||||
# unified
|
||||
unified = []
|
||||
unified_freqs_cis = []
|
||||
for i in range(bsz):
|
||||
x_len = x_item_seqlens[i]
|
||||
cap_len = cap_item_seqlens[i]
|
||||
unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
|
||||
unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
|
||||
unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
|
||||
assert unified_item_seqlens == [len(_) for _ in unified]
|
||||
unified_max_item_seqlen = max(unified_item_seqlens)
|
||||
|
||||
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
|
||||
unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(unified_item_seqlens):
|
||||
unified_attn_mask[i, :seq_len] = 1
|
||||
|
||||
## ControlNet start
|
||||
if not self.add_control_noise_refiner:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for layer in self.control_noise_refiner:
|
||||
control_context = self._gradient_checkpointing_func(
|
||||
layer, control_context, x_attn_mask, x_freqs_cis, adaln_input
|
||||
)
|
||||
else:
|
||||
for layer in self.control_noise_refiner:
|
||||
control_context = layer(control_context, x_attn_mask, x_freqs_cis, adaln_input)
|
||||
|
||||
# unified
|
||||
control_context_unified = []
|
||||
for i in range(bsz):
|
||||
x_len = x_item_seqlens[i]
|
||||
cap_len = cap_item_seqlens[i]
|
||||
control_context_unified.append(torch.cat([control_context[i][:x_len], cap_feats[i][:cap_len]]))
|
||||
control_context_unified = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0)
|
||||
|
||||
for layer in self.control_layers:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
control_context_unified = self._gradient_checkpointing_func(
|
||||
layer, control_context_unified, unified, unified_attn_mask, unified_freqs_cis, adaln_input
|
||||
)
|
||||
else:
|
||||
control_context_unified = layer(
|
||||
control_context_unified, unified, unified_attn_mask, unified_freqs_cis, adaln_input
|
||||
)
|
||||
|
||||
hints = torch.unbind(control_context_unified)[:-1]
|
||||
controlnet_block_samples = {
|
||||
layer_idx: hints[idx] * conditioning_scale for idx, layer_idx in enumerate(self.control_layers_places)
|
||||
}
|
||||
return controlnet_block_samples
|
||||
@@ -439,6 +439,9 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
|
||||
concat_padding_mask: bool = True,
|
||||
extra_pos_embed_type: Optional[str] = "learnable",
|
||||
use_crossattn_projection: bool = False,
|
||||
crossattn_proj_in_channels: int = 1024,
|
||||
encoder_hidden_states_channels: int = 1024,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
hidden_size = num_attention_heads * attention_head_dim
|
||||
@@ -485,6 +488,12 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, bias=False
|
||||
)
|
||||
|
||||
if self.config.use_crossattn_projection:
|
||||
self.crossattn_proj = nn.Sequential(
|
||||
nn.Linear(crossattn_proj_in_channels, encoder_hidden_states_channels, bias=True),
|
||||
nn.GELU(),
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
@@ -524,6 +533,7 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
post_patch_num_frames = num_frames // p_t
|
||||
post_patch_height = height // p_h
|
||||
post_patch_width = width // p_w
|
||||
|
||||
hidden_states = self.patch_embed(hidden_states)
|
||||
hidden_states = hidden_states.flatten(1, 3) # [B, T, H, W, C] -> [B, THW, C]
|
||||
|
||||
@@ -546,6 +556,9 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
else:
|
||||
assert False
|
||||
|
||||
if self.config.use_crossattn_projection:
|
||||
encoder_hidden_states = self.crossattn_proj(encoder_hidden_states)
|
||||
|
||||
# 5. Transformer blocks
|
||||
for block in self.transformer_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
@@ -143,17 +143,26 @@ def apply_rotary_emb_qwen(
|
||||
|
||||
|
||||
class QwenTimestepProjEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim):
|
||||
def __init__(self, embedding_dim, use_additional_t_cond=False):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
self.use_additional_t_cond = use_additional_t_cond
|
||||
if use_additional_t_cond:
|
||||
self.addition_t_embedding = nn.Embedding(2, embedding_dim)
|
||||
|
||||
def forward(self, timestep, hidden_states):
|
||||
def forward(self, timestep, hidden_states, addition_t_cond=None):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) # (N, D)
|
||||
|
||||
conditioning = timesteps_emb
|
||||
if self.use_additional_t_cond:
|
||||
if addition_t_cond is None:
|
||||
raise ValueError("When additional_t_cond is True, addition_t_cond must be provided.")
|
||||
addition_t_emb = self.addition_t_embedding(addition_t_cond)
|
||||
addition_t_emb = addition_t_emb.to(dtype=hidden_states.dtype)
|
||||
conditioning = conditioning + addition_t_emb
|
||||
|
||||
return conditioning
|
||||
|
||||
@@ -259,6 +268,120 @@ class QwenEmbedRope(nn.Module):
|
||||
return freqs.clone().contiguous()
|
||||
|
||||
|
||||
class QwenEmbedLayer3DRope(nn.Module):
|
||||
def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
pos_index = torch.arange(4096)
|
||||
neg_index = torch.arange(4096).flip(0) * -1 - 1
|
||||
self.pos_freqs = torch.cat(
|
||||
[
|
||||
self.rope_params(pos_index, self.axes_dim[0], self.theta),
|
||||
self.rope_params(pos_index, self.axes_dim[1], self.theta),
|
||||
self.rope_params(pos_index, self.axes_dim[2], self.theta),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
self.neg_freqs = torch.cat(
|
||||
[
|
||||
self.rope_params(neg_index, self.axes_dim[0], self.theta),
|
||||
self.rope_params(neg_index, self.axes_dim[1], self.theta),
|
||||
self.rope_params(neg_index, self.axes_dim[2], self.theta),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
self.scale_rope = scale_rope
|
||||
|
||||
def rope_params(self, index, dim, theta=10000):
|
||||
"""
|
||||
Args:
|
||||
index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
|
||||
"""
|
||||
assert dim % 2 == 0
|
||||
freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
|
||||
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs
|
||||
|
||||
def forward(self, video_fhw, txt_seq_lens, device):
|
||||
"""
|
||||
Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
|
||||
txt_length: [bs] a list of 1 integers representing the length of the text
|
||||
"""
|
||||
if self.pos_freqs.device != device:
|
||||
self.pos_freqs = self.pos_freqs.to(device)
|
||||
self.neg_freqs = self.neg_freqs.to(device)
|
||||
|
||||
if isinstance(video_fhw, list):
|
||||
video_fhw = video_fhw[0]
|
||||
if not isinstance(video_fhw, list):
|
||||
video_fhw = [video_fhw]
|
||||
|
||||
vid_freqs = []
|
||||
max_vid_index = 0
|
||||
layer_num = len(video_fhw) - 1
|
||||
for idx, fhw in enumerate(video_fhw):
|
||||
frame, height, width = fhw
|
||||
if idx != layer_num:
|
||||
video_freq = self._compute_video_freqs(frame, height, width, idx)
|
||||
else:
|
||||
### For the condition image, we set the layer index to -1
|
||||
video_freq = self._compute_condition_freqs(frame, height, width)
|
||||
video_freq = video_freq.to(device)
|
||||
vid_freqs.append(video_freq)
|
||||
|
||||
if self.scale_rope:
|
||||
max_vid_index = max(height // 2, width // 2, max_vid_index)
|
||||
else:
|
||||
max_vid_index = max(height, width, max_vid_index)
|
||||
|
||||
max_vid_index = max(max_vid_index, layer_num)
|
||||
max_len = max(txt_seq_lens)
|
||||
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
|
||||
vid_freqs = torch.cat(vid_freqs, dim=0)
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _compute_video_freqs(self, frame, height, width, idx=0):
|
||||
seq_lens = frame * height * width
|
||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
|
||||
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
if self.scale_rope:
|
||||
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
|
||||
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
|
||||
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
else:
|
||||
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||
return freqs.clone().contiguous()
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _compute_condition_freqs(self, frame, height, width):
|
||||
seq_lens = frame * height * width
|
||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
|
||||
freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
if self.scale_rope:
|
||||
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
|
||||
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
|
||||
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
else:
|
||||
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||
return freqs.clone().contiguous()
|
||||
|
||||
|
||||
class QwenDoubleStreamAttnProcessor2_0:
|
||||
"""
|
||||
Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor
|
||||
@@ -578,14 +701,21 @@ class QwenImageTransformer2DModel(
|
||||
guidance_embeds: bool = False, # TODO: this should probably be removed
|
||||
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
|
||||
zero_cond_t: bool = False,
|
||||
use_additional_t_cond: bool = False,
|
||||
use_layer3d_rope: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels or in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
|
||||
if not use_layer3d_rope:
|
||||
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
|
||||
else:
|
||||
self.pos_embed = QwenEmbedLayer3DRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
|
||||
|
||||
self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)
|
||||
self.time_text_embed = QwenTimestepProjEmbeddings(
|
||||
embedding_dim=self.inner_dim, use_additional_t_cond=use_additional_t_cond
|
||||
)
|
||||
|
||||
self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
|
||||
|
||||
@@ -621,6 +751,7 @@ class QwenImageTransformer2DModel(
|
||||
guidance: torch.Tensor = None, # TODO: this should probably be removed
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
controlnet_block_samples=None,
|
||||
additional_t_cond=None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
@@ -683,9 +814,9 @@ class QwenImageTransformer2DModel(
|
||||
guidance = guidance.to(hidden_states.dtype) * 1000
|
||||
|
||||
temb = (
|
||||
self.time_text_embed(timestep, hidden_states)
|
||||
self.time_text_embed(timestep, hidden_states, additional_t_cond)
|
||||
if guidance is None
|
||||
else self.time_text_embed(timestep, guidance, hidden_states)
|
||||
else self.time_text_embed(timestep, guidance, hidden_states, additional_t_cond)
|
||||
)
|
||||
|
||||
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -32,6 +32,7 @@ from ..modeling_outputs import Transformer2DModelOutput
|
||||
|
||||
ADALN_EMBED_DIM = 256
|
||||
SEQ_MULTI_OF = 32
|
||||
X_PAD_DIM = 64
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
@@ -152,6 +153,20 @@ class ZSingleStreamAttnProcessor:
|
||||
return output
|
||||
|
||||
|
||||
def select_per_token(
|
||||
value_noisy: torch.Tensor,
|
||||
value_clean: torch.Tensor,
|
||||
noise_mask: torch.Tensor,
|
||||
seq_len: int,
|
||||
) -> torch.Tensor:
|
||||
noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1)
|
||||
return torch.where(
|
||||
noise_mask_expanded == 1,
|
||||
value_noisy.unsqueeze(1).expand(-1, seq_len, -1),
|
||||
value_clean.unsqueeze(1).expand(-1, seq_len, -1),
|
||||
)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim: int, hidden_dim: int):
|
||||
super().__init__()
|
||||
@@ -215,12 +230,37 @@ class ZImageTransformerBlock(nn.Module):
|
||||
attn_mask: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
adaln_input: Optional[torch.Tensor] = None,
|
||||
noise_mask: Optional[torch.Tensor] = None,
|
||||
adaln_noisy: Optional[torch.Tensor] = None,
|
||||
adaln_clean: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if self.modulation:
|
||||
assert adaln_input is not None
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
|
||||
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
|
||||
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
|
||||
seq_len = x.shape[1]
|
||||
|
||||
if noise_mask is not None:
|
||||
# Per-token modulation: different modulation for noisy/clean tokens
|
||||
mod_noisy = self.adaLN_modulation(adaln_noisy)
|
||||
mod_clean = self.adaLN_modulation(adaln_clean)
|
||||
|
||||
scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1)
|
||||
scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1)
|
||||
|
||||
gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh()
|
||||
gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh()
|
||||
|
||||
scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy
|
||||
scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean
|
||||
|
||||
scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len)
|
||||
scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len)
|
||||
gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len)
|
||||
gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len)
|
||||
else:
|
||||
# Global modulation: same modulation for all tokens (avoid double select)
|
||||
mod = self.adaLN_modulation(adaln_input)
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2)
|
||||
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
|
||||
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
|
||||
|
||||
# Attention block
|
||||
attn_out = self.attention(
|
||||
@@ -252,9 +292,21 @@ class FinalLayer(nn.Module):
|
||||
nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True),
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
scale = 1.0 + self.adaLN_modulation(c)
|
||||
x = self.norm_final(x) * scale.unsqueeze(1)
|
||||
def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None):
|
||||
seq_len = x.shape[1]
|
||||
|
||||
if noise_mask is not None:
|
||||
# Per-token modulation
|
||||
scale_noisy = 1.0 + self.adaLN_modulation(c_noisy)
|
||||
scale_clean = 1.0 + self.adaLN_modulation(c_clean)
|
||||
scale = select_per_token(scale_noisy, scale_clean, noise_mask, seq_len)
|
||||
else:
|
||||
# Original global modulation
|
||||
assert c is not None, "Either c or (c_noisy, c_clean) must be provided"
|
||||
scale = 1.0 + self.adaLN_modulation(c)
|
||||
scale = scale.unsqueeze(1)
|
||||
|
||||
x = self.norm_final(x) * scale
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
@@ -325,6 +377,7 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
|
||||
norm_eps=1e-5,
|
||||
qk_norm=True,
|
||||
cap_feat_dim=2560,
|
||||
siglip_feat_dim=None, # Optional: set to enable SigLIP support for Omni
|
||||
rope_theta=256.0,
|
||||
t_scale=1000.0,
|
||||
axes_dims=[32, 48, 48],
|
||||
@@ -386,6 +439,31 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
|
||||
self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
|
||||
self.cap_embedder = nn.Sequential(RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True))
|
||||
|
||||
# Optional SigLIP components (for Omni variant)
|
||||
if siglip_feat_dim is not None:
|
||||
self.siglip_embedder = nn.Sequential(
|
||||
RMSNorm(siglip_feat_dim, eps=norm_eps), nn.Linear(siglip_feat_dim, dim, bias=True)
|
||||
)
|
||||
self.siglip_refiner = nn.ModuleList(
|
||||
[
|
||||
ZImageTransformerBlock(
|
||||
2000 + layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=False,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
self.siglip_pad_token = nn.Parameter(torch.empty((1, dim)))
|
||||
else:
|
||||
self.siglip_embedder = None
|
||||
self.siglip_refiner = None
|
||||
self.siglip_pad_token = None
|
||||
|
||||
self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
|
||||
self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
|
||||
|
||||
@@ -402,252 +480,561 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
|
||||
|
||||
self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
|
||||
|
||||
def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]:
|
||||
def unpatchify(
|
||||
self,
|
||||
x: List[torch.Tensor],
|
||||
size: List[Tuple],
|
||||
patch_size,
|
||||
f_patch_size,
|
||||
x_pos_offsets: Optional[List[Tuple[int, int]]] = None,
|
||||
) -> List[torch.Tensor]:
|
||||
pH = pW = patch_size
|
||||
pF = f_patch_size
|
||||
bsz = len(x)
|
||||
assert len(size) == bsz
|
||||
for i in range(bsz):
|
||||
F, H, W = size[i]
|
||||
ori_len = (F // pF) * (H // pH) * (W // pW)
|
||||
# "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
|
||||
x[i] = (
|
||||
x[i][:ori_len]
|
||||
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
|
||||
.permute(6, 0, 3, 1, 4, 2, 5)
|
||||
.reshape(self.out_channels, F, H, W)
|
||||
)
|
||||
return x
|
||||
|
||||
if x_pos_offsets is not None:
|
||||
# Omni: extract target image from unified sequence (cond_images + target)
|
||||
result = []
|
||||
for i in range(bsz):
|
||||
unified_x = x[i][x_pos_offsets[i][0] : x_pos_offsets[i][1]]
|
||||
cu_len = 0
|
||||
x_item = None
|
||||
for j in range(len(size[i])):
|
||||
if size[i][j] is None:
|
||||
ori_len = 0
|
||||
pad_len = SEQ_MULTI_OF
|
||||
cu_len += pad_len + ori_len
|
||||
else:
|
||||
F, H, W = size[i][j]
|
||||
ori_len = (F // pF) * (H // pH) * (W // pW)
|
||||
pad_len = (-ori_len) % SEQ_MULTI_OF
|
||||
x_item = (
|
||||
unified_x[cu_len : cu_len + ori_len]
|
||||
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
|
||||
.permute(6, 0, 3, 1, 4, 2, 5)
|
||||
.reshape(self.out_channels, F, H, W)
|
||||
)
|
||||
cu_len += ori_len + pad_len
|
||||
result.append(x_item) # Return only the last (target) image
|
||||
return result
|
||||
else:
|
||||
# Original mode: simple unpatchify
|
||||
for i in range(bsz):
|
||||
F, H, W = size[i]
|
||||
ori_len = (F // pF) * (H // pH) * (W // pW)
|
||||
# "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
|
||||
x[i] = (
|
||||
x[i][:ori_len]
|
||||
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
|
||||
.permute(6, 0, 3, 1, 4, 2, 5)
|
||||
.reshape(self.out_channels, F, H, W)
|
||||
)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def create_coordinate_grid(size, start=None, device=None):
|
||||
if start is None:
|
||||
start = (0 for _ in size)
|
||||
|
||||
axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
|
||||
grids = torch.meshgrid(axes, indexing="ij")
|
||||
return torch.stack(grids, dim=-1)
|
||||
|
||||
def patchify_and_embed(
|
||||
def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int):
|
||||
"""Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim)."""
|
||||
pH, pW, pF = patch_size, patch_size, f_patch_size
|
||||
C, F, H, W = image.size()
|
||||
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
|
||||
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
|
||||
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
|
||||
return image, (F, H, W), (F_tokens, H_tokens, W_tokens)
|
||||
|
||||
def _pad_with_ids(
|
||||
self,
|
||||
all_image: List[torch.Tensor],
|
||||
all_cap_feats: List[torch.Tensor],
|
||||
patch_size: int,
|
||||
f_patch_size: int,
|
||||
feat: torch.Tensor,
|
||||
pos_grid_size: Tuple,
|
||||
pos_start: Tuple,
|
||||
device: torch.device,
|
||||
noise_mask_val: Optional[int] = None,
|
||||
):
|
||||
pH = pW = patch_size
|
||||
pF = f_patch_size
|
||||
"""Pad feature to SEQ_MULTI_OF, create position IDs and pad mask."""
|
||||
ori_len = len(feat)
|
||||
pad_len = (-ori_len) % SEQ_MULTI_OF
|
||||
total_len = ori_len + pad_len
|
||||
|
||||
# Pos IDs
|
||||
ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2)
|
||||
if pad_len > 0:
|
||||
pad_pos_ids = (
|
||||
self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
|
||||
.flatten(0, 2)
|
||||
.repeat(pad_len, 1)
|
||||
)
|
||||
pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0)
|
||||
padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0)
|
||||
pad_mask = torch.cat(
|
||||
[
|
||||
torch.zeros(ori_len, dtype=torch.bool, device=device),
|
||||
torch.ones(pad_len, dtype=torch.bool, device=device),
|
||||
]
|
||||
)
|
||||
else:
|
||||
pos_ids = ori_pos_ids
|
||||
padded_feat = feat
|
||||
pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device)
|
||||
|
||||
noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None # token level
|
||||
return padded_feat, pos_ids, pad_mask, total_len, noise_mask
|
||||
|
||||
def patchify_and_embed(
|
||||
self, all_image: List[torch.Tensor], all_cap_feats: List[torch.Tensor], patch_size: int, f_patch_size: int
|
||||
):
|
||||
"""Patchify for basic mode: single image per batch item."""
|
||||
device = all_image[0].device
|
||||
all_img_out, all_img_size, all_img_pos_ids, all_img_pad_mask = [], [], [], []
|
||||
all_cap_out, all_cap_pos_ids, all_cap_pad_mask = [], [], []
|
||||
|
||||
all_image_out = []
|
||||
all_image_size = []
|
||||
all_image_pos_ids = []
|
||||
all_image_pad_mask = []
|
||||
all_cap_pos_ids = []
|
||||
all_cap_pad_mask = []
|
||||
all_cap_feats_out = []
|
||||
|
||||
for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)):
|
||||
### Process Caption
|
||||
cap_ori_len = len(cap_feat)
|
||||
cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF
|
||||
# padded position ids
|
||||
cap_padded_pos_ids = self.create_coordinate_grid(
|
||||
size=(cap_ori_len + cap_padding_len, 1, 1),
|
||||
start=(1, 0, 0),
|
||||
device=device,
|
||||
).flatten(0, 2)
|
||||
all_cap_pos_ids.append(cap_padded_pos_ids)
|
||||
# pad mask
|
||||
cap_pad_mask = torch.cat(
|
||||
[
|
||||
torch.zeros((cap_ori_len,), dtype=torch.bool, device=device),
|
||||
torch.ones((cap_padding_len,), dtype=torch.bool, device=device),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
all_cap_pad_mask.append(
|
||||
cap_pad_mask if cap_padding_len > 0 else torch.zeros((cap_ori_len,), dtype=torch.bool, device=device)
|
||||
for image, cap_feat in zip(all_image, all_cap_feats):
|
||||
# Caption
|
||||
cap_out, cap_pos_ids, cap_pad_mask, cap_len, _ = self._pad_with_ids(
|
||||
cap_feat, (len(cap_feat) + (-len(cap_feat)) % SEQ_MULTI_OF, 1, 1), (1, 0, 0), device
|
||||
)
|
||||
all_cap_out.append(cap_out)
|
||||
all_cap_pos_ids.append(cap_pos_ids)
|
||||
all_cap_pad_mask.append(cap_pad_mask)
|
||||
|
||||
# padded feature
|
||||
cap_padded_feat = torch.cat([cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], dim=0)
|
||||
all_cap_feats_out.append(cap_padded_feat)
|
||||
|
||||
### Process Image
|
||||
C, F, H, W = image.size()
|
||||
all_image_size.append((F, H, W))
|
||||
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
|
||||
|
||||
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
|
||||
# "c f pf h ph w pw -> (f h w) (pf ph pw c)"
|
||||
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
|
||||
|
||||
image_ori_len = len(image)
|
||||
image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
|
||||
|
||||
image_ori_pos_ids = self.create_coordinate_grid(
|
||||
size=(F_tokens, H_tokens, W_tokens),
|
||||
start=(cap_ori_len + cap_padding_len + 1, 0, 0),
|
||||
device=device,
|
||||
).flatten(0, 2)
|
||||
image_padded_pos_ids = torch.cat(
|
||||
[
|
||||
image_ori_pos_ids,
|
||||
self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
|
||||
.flatten(0, 2)
|
||||
.repeat(image_padding_len, 1),
|
||||
],
|
||||
dim=0,
|
||||
# Image
|
||||
img_patches, size, (F_t, H_t, W_t) = self._patchify_image(image, patch_size, f_patch_size)
|
||||
img_out, img_pos_ids, img_pad_mask, _, _ = self._pad_with_ids(
|
||||
img_patches, (F_t, H_t, W_t), (cap_len + 1, 0, 0), device
|
||||
)
|
||||
all_image_pos_ids.append(image_padded_pos_ids if image_padding_len > 0 else image_ori_pos_ids)
|
||||
# pad mask
|
||||
image_pad_mask = torch.cat(
|
||||
[
|
||||
torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
|
||||
torch.ones((image_padding_len,), dtype=torch.bool, device=device),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
all_image_pad_mask.append(
|
||||
image_pad_mask
|
||||
if image_padding_len > 0
|
||||
else torch.zeros((image_ori_len,), dtype=torch.bool, device=device)
|
||||
)
|
||||
# padded feature
|
||||
image_padded_feat = torch.cat(
|
||||
[image, image[-1:].repeat(image_padding_len, 1)],
|
||||
dim=0,
|
||||
)
|
||||
all_image_out.append(image_padded_feat if image_padding_len > 0 else image)
|
||||
all_img_out.append(img_out)
|
||||
all_img_size.append(size)
|
||||
all_img_pos_ids.append(img_pos_ids)
|
||||
all_img_pad_mask.append(img_pad_mask)
|
||||
|
||||
return (
|
||||
all_image_out,
|
||||
all_cap_feats_out,
|
||||
all_image_size,
|
||||
all_image_pos_ids,
|
||||
all_img_out,
|
||||
all_cap_out,
|
||||
all_img_size,
|
||||
all_img_pos_ids,
|
||||
all_cap_pos_ids,
|
||||
all_image_pad_mask,
|
||||
all_img_pad_mask,
|
||||
all_cap_pad_mask,
|
||||
)
|
||||
|
||||
def forward(
|
||||
def patchify_and_embed_omni(
|
||||
self,
|
||||
x: List[torch.Tensor],
|
||||
t,
|
||||
cap_feats: List[torch.Tensor],
|
||||
patch_size=2,
|
||||
f_patch_size=1,
|
||||
return_dict: bool = True,
|
||||
all_x: List[List[torch.Tensor]],
|
||||
all_cap_feats: List[List[torch.Tensor]],
|
||||
all_siglip_feats: List[List[torch.Tensor]],
|
||||
patch_size: int,
|
||||
f_patch_size: int,
|
||||
images_noise_mask: List[List[int]],
|
||||
):
|
||||
assert patch_size in self.all_patch_size
|
||||
assert f_patch_size in self.all_f_patch_size
|
||||
"""Patchify for omni mode: multiple images per batch item with noise masks."""
|
||||
bsz = len(all_x)
|
||||
device = all_x[0][-1].device
|
||||
dtype = all_x[0][-1].dtype
|
||||
|
||||
bsz = len(x)
|
||||
device = x[0].device
|
||||
t = t * self.t_scale
|
||||
t = self.t_embedder(t)
|
||||
all_x_out, all_x_size, all_x_pos_ids, all_x_pad_mask, all_x_len, all_x_noise_mask = [], [], [], [], [], []
|
||||
all_cap_out, all_cap_pos_ids, all_cap_pad_mask, all_cap_len, all_cap_noise_mask = [], [], [], [], []
|
||||
all_sig_out, all_sig_pos_ids, all_sig_pad_mask, all_sig_len, all_sig_noise_mask = [], [], [], [], []
|
||||
|
||||
(
|
||||
x,
|
||||
cap_feats,
|
||||
x_size,
|
||||
x_pos_ids,
|
||||
cap_pos_ids,
|
||||
x_inner_pad_mask,
|
||||
cap_inner_pad_mask,
|
||||
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
|
||||
for i in range(bsz):
|
||||
num_images = len(all_x[i])
|
||||
cap_feats_list, cap_pos_list, cap_mask_list, cap_lens, cap_noise = [], [], [], [], []
|
||||
cap_end_pos = []
|
||||
cap_cu_len = 1
|
||||
|
||||
# x embed & refine
|
||||
x_item_seqlens = [len(_) for _ in x]
|
||||
assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
|
||||
x_max_item_seqlen = max(x_item_seqlens)
|
||||
# Process captions
|
||||
for j, cap_item in enumerate(all_cap_feats[i]):
|
||||
noise_val = images_noise_mask[i][j] if j < len(images_noise_mask[i]) else 1
|
||||
cap_out, cap_pos, cap_mask, cap_len, cap_nm = self._pad_with_ids(
|
||||
cap_item,
|
||||
(len(cap_item) + (-len(cap_item)) % SEQ_MULTI_OF, 1, 1),
|
||||
(cap_cu_len, 0, 0),
|
||||
device,
|
||||
noise_val,
|
||||
)
|
||||
cap_feats_list.append(cap_out)
|
||||
cap_pos_list.append(cap_pos)
|
||||
cap_mask_list.append(cap_mask)
|
||||
cap_lens.append(cap_len)
|
||||
cap_noise.extend(cap_nm)
|
||||
cap_cu_len += len(cap_item)
|
||||
cap_end_pos.append(cap_cu_len)
|
||||
cap_cu_len += 2 # for image vae and siglip tokens
|
||||
|
||||
x = torch.cat(x, dim=0)
|
||||
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
|
||||
all_cap_out.append(torch.cat(cap_feats_list, dim=0))
|
||||
all_cap_pos_ids.append(torch.cat(cap_pos_list, dim=0))
|
||||
all_cap_pad_mask.append(torch.cat(cap_mask_list, dim=0))
|
||||
all_cap_len.append(cap_lens)
|
||||
all_cap_noise_mask.append(cap_noise)
|
||||
|
||||
# Match t_embedder output dtype to x for layerwise casting compatibility
|
||||
adaln_input = t.type_as(x)
|
||||
x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
|
||||
x = list(x.split(x_item_seqlens, dim=0))
|
||||
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0))
|
||||
# Process images
|
||||
x_feats_list, x_pos_list, x_mask_list, x_lens, x_size, x_noise = [], [], [], [], [], []
|
||||
for j, x_item in enumerate(all_x[i]):
|
||||
noise_val = images_noise_mask[i][j]
|
||||
if x_item is not None:
|
||||
x_patches, size, (F_t, H_t, W_t) = self._patchify_image(x_item, patch_size, f_patch_size)
|
||||
x_out, x_pos, x_mask, x_len, x_nm = self._pad_with_ids(
|
||||
x_patches, (F_t, H_t, W_t), (cap_end_pos[j], 0, 0), device, noise_val
|
||||
)
|
||||
x_size.append(size)
|
||||
else:
|
||||
x_len = SEQ_MULTI_OF
|
||||
x_out = torch.zeros((x_len, X_PAD_DIM), dtype=dtype, device=device)
|
||||
x_pos = self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(x_len, 1)
|
||||
x_mask = torch.ones(x_len, dtype=torch.bool, device=device)
|
||||
x_nm = [noise_val] * x_len
|
||||
x_size.append(None)
|
||||
x_feats_list.append(x_out)
|
||||
x_pos_list.append(x_pos)
|
||||
x_mask_list.append(x_mask)
|
||||
x_lens.append(x_len)
|
||||
x_noise.extend(x_nm)
|
||||
|
||||
x = pad_sequence(x, batch_first=True, padding_value=0.0)
|
||||
x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
# Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors
|
||||
x_freqs_cis = x_freqs_cis[:, : x.shape[1]]
|
||||
all_x_out.append(torch.cat(x_feats_list, dim=0))
|
||||
all_x_pos_ids.append(torch.cat(x_pos_list, dim=0))
|
||||
all_x_pad_mask.append(torch.cat(x_mask_list, dim=0))
|
||||
all_x_size.append(x_size)
|
||||
all_x_len.append(x_lens)
|
||||
all_x_noise_mask.append(x_noise)
|
||||
|
||||
x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(x_item_seqlens):
|
||||
x_attn_mask[i, :seq_len] = 1
|
||||
# Process siglip
|
||||
if all_siglip_feats[i] is None:
|
||||
all_sig_len.append([0] * num_images)
|
||||
all_sig_out.append(None)
|
||||
else:
|
||||
sig_feats_list, sig_pos_list, sig_mask_list, sig_lens, sig_noise = [], [], [], [], []
|
||||
for j, sig_item in enumerate(all_siglip_feats[i]):
|
||||
noise_val = images_noise_mask[i][j]
|
||||
if sig_item is not None:
|
||||
sig_H, sig_W, sig_C = sig_item.size()
|
||||
sig_flat = sig_item.permute(2, 0, 1).reshape(sig_H * sig_W, sig_C)
|
||||
sig_out, sig_pos, sig_mask, sig_len, sig_nm = self._pad_with_ids(
|
||||
sig_flat, (1, sig_H, sig_W), (cap_end_pos[j] + 1, 0, 0), device, noise_val
|
||||
)
|
||||
# Scale position IDs to match x resolution
|
||||
if x_size[j] is not None:
|
||||
sig_pos = sig_pos.float()
|
||||
sig_pos[..., 1] = sig_pos[..., 1] / max(sig_H - 1, 1) * (x_size[j][1] - 1)
|
||||
sig_pos[..., 2] = sig_pos[..., 2] / max(sig_W - 1, 1) * (x_size[j][2] - 1)
|
||||
sig_pos = sig_pos.to(torch.int32)
|
||||
else:
|
||||
sig_len = SEQ_MULTI_OF
|
||||
sig_out = torch.zeros((sig_len, self.config.siglip_feat_dim), dtype=dtype, device=device)
|
||||
sig_pos = (
|
||||
self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(sig_len, 1)
|
||||
)
|
||||
sig_mask = torch.ones(sig_len, dtype=torch.bool, device=device)
|
||||
sig_nm = [noise_val] * sig_len
|
||||
sig_feats_list.append(sig_out)
|
||||
sig_pos_list.append(sig_pos)
|
||||
sig_mask_list.append(sig_mask)
|
||||
sig_lens.append(sig_len)
|
||||
sig_noise.extend(sig_nm)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for layer in self.noise_refiner:
|
||||
x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input)
|
||||
else:
|
||||
for layer in self.noise_refiner:
|
||||
x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)
|
||||
all_sig_out.append(torch.cat(sig_feats_list, dim=0))
|
||||
all_sig_pos_ids.append(torch.cat(sig_pos_list, dim=0))
|
||||
all_sig_pad_mask.append(torch.cat(sig_mask_list, dim=0))
|
||||
all_sig_len.append(sig_lens)
|
||||
all_sig_noise_mask.append(sig_noise)
|
||||
|
||||
# cap embed & refine
|
||||
cap_item_seqlens = [len(_) for _ in cap_feats]
|
||||
cap_max_item_seqlen = max(cap_item_seqlens)
|
||||
# Compute x position offsets
|
||||
all_x_pos_offsets = [(sum(all_cap_len[i]), sum(all_cap_len[i]) + sum(all_x_len[i])) for i in range(bsz)]
|
||||
|
||||
cap_feats = torch.cat(cap_feats, dim=0)
|
||||
cap_feats = self.cap_embedder(cap_feats)
|
||||
cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
|
||||
cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
|
||||
cap_freqs_cis = list(
|
||||
self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0)
|
||||
return (
|
||||
all_x_out,
|
||||
all_cap_out,
|
||||
all_sig_out,
|
||||
all_x_size,
|
||||
all_x_pos_ids,
|
||||
all_cap_pos_ids,
|
||||
all_sig_pos_ids,
|
||||
all_x_pad_mask,
|
||||
all_cap_pad_mask,
|
||||
all_sig_pad_mask,
|
||||
all_x_pos_offsets,
|
||||
all_x_noise_mask,
|
||||
all_cap_noise_mask,
|
||||
all_sig_noise_mask,
|
||||
)
|
||||
|
||||
cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
|
||||
cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
# Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors
|
||||
cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]]
|
||||
def _prepare_sequence(
|
||||
self,
|
||||
feats: List[torch.Tensor],
|
||||
pos_ids: List[torch.Tensor],
|
||||
inner_pad_mask: List[torch.Tensor],
|
||||
pad_token: torch.nn.Parameter,
|
||||
noise_mask: Optional[List[List[int]]] = None,
|
||||
device: torch.device = None,
|
||||
):
|
||||
"""Prepare sequence: apply pad token, RoPE embed, pad to batch, create attention mask."""
|
||||
item_seqlens = [len(f) for f in feats]
|
||||
max_seqlen = max(item_seqlens)
|
||||
bsz = len(feats)
|
||||
|
||||
cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(cap_item_seqlens):
|
||||
cap_attn_mask[i, :seq_len] = 1
|
||||
# Pad token
|
||||
feats_cat = torch.cat(feats, dim=0)
|
||||
feats_cat[torch.cat(inner_pad_mask)] = pad_token
|
||||
feats = list(feats_cat.split(item_seqlens, dim=0))
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis)
|
||||
else:
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)
|
||||
# RoPE
|
||||
freqs_cis = list(self.rope_embedder(torch.cat(pos_ids, dim=0)).split([len(p) for p in pos_ids], dim=0))
|
||||
|
||||
# unified
|
||||
# Pad to batch
|
||||
feats = pad_sequence(feats, batch_first=True, padding_value=0.0)
|
||||
freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]]
|
||||
|
||||
# Attention mask
|
||||
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(item_seqlens):
|
||||
attn_mask[i, :seq_len] = 1
|
||||
|
||||
# Noise mask
|
||||
noise_mask_tensor = None
|
||||
if noise_mask is not None:
|
||||
noise_mask_tensor = pad_sequence(
|
||||
[torch.tensor(m, dtype=torch.long, device=device) for m in noise_mask],
|
||||
batch_first=True,
|
||||
padding_value=0,
|
||||
)[:, : feats.shape[1]]
|
||||
|
||||
return feats, freqs_cis, attn_mask, item_seqlens, noise_mask_tensor
|
||||
|
||||
def _build_unified_sequence(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_freqs: torch.Tensor,
|
||||
x_seqlens: List[int],
|
||||
x_noise_mask: Optional[List[List[int]]],
|
||||
cap: torch.Tensor,
|
||||
cap_freqs: torch.Tensor,
|
||||
cap_seqlens: List[int],
|
||||
cap_noise_mask: Optional[List[List[int]]],
|
||||
siglip: Optional[torch.Tensor],
|
||||
siglip_freqs: Optional[torch.Tensor],
|
||||
siglip_seqlens: Optional[List[int]],
|
||||
siglip_noise_mask: Optional[List[List[int]]],
|
||||
omni_mode: bool,
|
||||
device: torch.device,
|
||||
):
|
||||
"""Build unified sequence: x, cap, and optionally siglip.
|
||||
Basic mode order: [x, cap]; Omni mode order: [cap, x, siglip]
|
||||
"""
|
||||
bsz = len(x_seqlens)
|
||||
unified = []
|
||||
unified_freqs_cis = []
|
||||
unified_freqs = []
|
||||
unified_noise_mask = []
|
||||
|
||||
for i in range(bsz):
|
||||
x_len = x_item_seqlens[i]
|
||||
cap_len = cap_item_seqlens[i]
|
||||
unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
|
||||
unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
|
||||
unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
|
||||
assert unified_item_seqlens == [len(_) for _ in unified]
|
||||
unified_max_item_seqlen = max(unified_item_seqlens)
|
||||
x_len, cap_len = x_seqlens[i], cap_seqlens[i]
|
||||
|
||||
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
|
||||
unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(unified_item_seqlens):
|
||||
unified_attn_mask[i, :seq_len] = 1
|
||||
if omni_mode:
|
||||
# Omni: [cap, x, siglip]
|
||||
if siglip is not None and siglip_seqlens is not None:
|
||||
sig_len = siglip_seqlens[i]
|
||||
unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len], siglip[i][:sig_len]]))
|
||||
unified_freqs.append(
|
||||
torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len], siglip_freqs[i][:sig_len]])
|
||||
)
|
||||
unified_noise_mask.append(
|
||||
torch.tensor(
|
||||
cap_noise_mask[i] + x_noise_mask[i] + siglip_noise_mask[i], dtype=torch.long, device=device
|
||||
)
|
||||
)
|
||||
else:
|
||||
unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len]]))
|
||||
unified_freqs.append(torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len]]))
|
||||
unified_noise_mask.append(
|
||||
torch.tensor(cap_noise_mask[i] + x_noise_mask[i], dtype=torch.long, device=device)
|
||||
)
|
||||
else:
|
||||
# Basic: [x, cap]
|
||||
unified.append(torch.cat([x[i][:x_len], cap[i][:cap_len]]))
|
||||
unified_freqs.append(torch.cat([x_freqs[i][:x_len], cap_freqs[i][:cap_len]]))
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for layer in self.layers:
|
||||
unified = self._gradient_checkpointing_func(
|
||||
layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input
|
||||
)
|
||||
# Compute unified seqlens
|
||||
if omni_mode:
|
||||
if siglip is not None and siglip_seqlens is not None:
|
||||
unified_seqlens = [a + b + c for a, b, c in zip(cap_seqlens, x_seqlens, siglip_seqlens)]
|
||||
else:
|
||||
unified_seqlens = [a + b for a, b in zip(cap_seqlens, x_seqlens)]
|
||||
else:
|
||||
for layer in self.layers:
|
||||
unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input)
|
||||
unified_seqlens = [a + b for a, b in zip(x_seqlens, cap_seqlens)]
|
||||
|
||||
unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
|
||||
unified = list(unified.unbind(dim=0))
|
||||
x = self.unpatchify(unified, x_size, patch_size, f_patch_size)
|
||||
max_seqlen = max(unified_seqlens)
|
||||
|
||||
if not return_dict:
|
||||
return (x,)
|
||||
# Pad to batch
|
||||
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
|
||||
unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0)
|
||||
|
||||
return Transformer2DModelOutput(sample=x)
|
||||
# Attention mask
|
||||
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(unified_seqlens):
|
||||
attn_mask[i, :seq_len] = 1
|
||||
|
||||
# Noise mask
|
||||
noise_mask_tensor = None
|
||||
if omni_mode:
|
||||
noise_mask_tensor = pad_sequence(unified_noise_mask, batch_first=True, padding_value=0)[
|
||||
:, : unified.shape[1]
|
||||
]
|
||||
|
||||
return unified, unified_freqs, attn_mask, noise_mask_tensor
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Union[List[torch.Tensor], List[List[torch.Tensor]]],
|
||||
t,
|
||||
cap_feats: Union[List[torch.Tensor], List[List[torch.Tensor]]],
|
||||
return_dict: bool = True,
|
||||
controlnet_block_samples: Optional[Dict[int, torch.Tensor]] = None,
|
||||
siglip_feats: Optional[List[List[torch.Tensor]]] = None,
|
||||
image_noise_mask: Optional[List[List[int]]] = None,
|
||||
patch_size: int = 2,
|
||||
f_patch_size: int = 1,
|
||||
):
|
||||
"""
|
||||
Flow: patchify -> t_embed -> x_embed -> x_refine -> cap_embed -> cap_refine
|
||||
-> [siglip_embed -> siglip_refine] -> build_unified -> main_layers -> final_layer -> unpatchify
|
||||
"""
|
||||
assert patch_size in self.all_patch_size and f_patch_size in self.all_f_patch_size
|
||||
omni_mode = isinstance(x[0], list)
|
||||
device = x[0][-1].device if omni_mode else x[0].device
|
||||
|
||||
if omni_mode:
|
||||
# Dual embeddings: noisy (t) and clean (t=1)
|
||||
t_noisy = self.t_embedder(t * self.t_scale).type_as(x[0][-1])
|
||||
t_clean = self.t_embedder(torch.ones_like(t) * self.t_scale).type_as(x[0][-1])
|
||||
adaln_input = None
|
||||
else:
|
||||
# Single embedding for all tokens
|
||||
adaln_input = self.t_embedder(t * self.t_scale).type_as(x[0])
|
||||
t_noisy = t_clean = None
|
||||
|
||||
# Patchify
|
||||
if omni_mode:
|
||||
(
|
||||
x,
|
||||
cap_feats,
|
||||
siglip_feats,
|
||||
x_size,
|
||||
x_pos_ids,
|
||||
cap_pos_ids,
|
||||
siglip_pos_ids,
|
||||
x_pad_mask,
|
||||
cap_pad_mask,
|
||||
siglip_pad_mask,
|
||||
x_pos_offsets,
|
||||
x_noise_mask,
|
||||
cap_noise_mask,
|
||||
siglip_noise_mask,
|
||||
) = self.patchify_and_embed_omni(x, cap_feats, siglip_feats, patch_size, f_patch_size, image_noise_mask)
|
||||
else:
|
||||
(
|
||||
x,
|
||||
cap_feats,
|
||||
x_size,
|
||||
x_pos_ids,
|
||||
cap_pos_ids,
|
||||
x_pad_mask,
|
||||
cap_pad_mask,
|
||||
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
|
||||
x_pos_offsets = x_noise_mask = cap_noise_mask = siglip_noise_mask = None
|
||||
|
||||
# X embed & refine
|
||||
x_seqlens = [len(xi) for xi in x]
|
||||
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](torch.cat(x, dim=0)) # embed
|
||||
x, x_freqs, x_mask, _, x_noise_tensor = self._prepare_sequence(
|
||||
list(x.split(x_seqlens, dim=0)), x_pos_ids, x_pad_mask, self.x_pad_token, x_noise_mask, device
|
||||
)
|
||||
|
||||
for layer in self.noise_refiner:
|
||||
x = (
|
||||
self._gradient_checkpointing_func(
|
||||
layer, x, x_mask, x_freqs, adaln_input, x_noise_tensor, t_noisy, t_clean
|
||||
)
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing
|
||||
else layer(x, x_mask, x_freqs, adaln_input, x_noise_tensor, t_noisy, t_clean)
|
||||
)
|
||||
|
||||
# Cap embed & refine
|
||||
cap_seqlens = [len(ci) for ci in cap_feats]
|
||||
cap_feats = self.cap_embedder(torch.cat(cap_feats, dim=0)) # embed
|
||||
cap_feats, cap_freqs, cap_mask, _, _ = self._prepare_sequence(
|
||||
list(cap_feats.split(cap_seqlens, dim=0)), cap_pos_ids, cap_pad_mask, self.cap_pad_token, None, device
|
||||
)
|
||||
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = (
|
||||
self._gradient_checkpointing_func(layer, cap_feats, cap_mask, cap_freqs)
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing
|
||||
else layer(cap_feats, cap_mask, cap_freqs)
|
||||
)
|
||||
|
||||
# Siglip embed & refine
|
||||
siglip_seqlens = siglip_freqs = None
|
||||
if omni_mode and siglip_feats[0] is not None and self.siglip_embedder is not None:
|
||||
siglip_seqlens = [len(si) for si in siglip_feats]
|
||||
siglip_feats = self.siglip_embedder(torch.cat(siglip_feats, dim=0)) # embed
|
||||
siglip_feats, siglip_freqs, siglip_mask, _, _ = self._prepare_sequence(
|
||||
list(siglip_feats.split(siglip_seqlens, dim=0)),
|
||||
siglip_pos_ids,
|
||||
siglip_pad_mask,
|
||||
self.siglip_pad_token,
|
||||
None,
|
||||
device,
|
||||
)
|
||||
|
||||
for layer in self.siglip_refiner:
|
||||
siglip_feats = (
|
||||
self._gradient_checkpointing_func(layer, siglip_feats, siglip_mask, siglip_freqs)
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing
|
||||
else layer(siglip_feats, siglip_mask, siglip_freqs)
|
||||
)
|
||||
|
||||
# Unified sequence
|
||||
unified, unified_freqs, unified_mask, unified_noise_tensor = self._build_unified_sequence(
|
||||
x,
|
||||
x_freqs,
|
||||
x_seqlens,
|
||||
x_noise_mask,
|
||||
cap_feats,
|
||||
cap_freqs,
|
||||
cap_seqlens,
|
||||
cap_noise_mask,
|
||||
siglip_feats,
|
||||
siglip_freqs,
|
||||
siglip_seqlens,
|
||||
siglip_noise_mask,
|
||||
omni_mode,
|
||||
device,
|
||||
)
|
||||
|
||||
# Main transformer layers
|
||||
for layer_idx, layer in enumerate(self.layers):
|
||||
unified = (
|
||||
self._gradient_checkpointing_func(
|
||||
layer, unified, unified_mask, unified_freqs, adaln_input, unified_noise_tensor, t_noisy, t_clean
|
||||
)
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing
|
||||
else layer(unified, unified_mask, unified_freqs, adaln_input, unified_noise_tensor, t_noisy, t_clean)
|
||||
)
|
||||
if controlnet_block_samples is not None and layer_idx in controlnet_block_samples:
|
||||
unified = unified + controlnet_block_samples[layer_idx]
|
||||
|
||||
unified = (
|
||||
self.all_final_layer[f"{patch_size}-{f_patch_size}"](
|
||||
unified, noise_mask=unified_noise_tensor, c_noisy=t_noisy, c_clean=t_clean
|
||||
)
|
||||
if omni_mode
|
||||
else self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, c=adaln_input)
|
||||
)
|
||||
|
||||
# Unpatchify
|
||||
x = self.unpatchify(list(unified.unbind(dim=0)), x_size, patch_size, f_patch_size, x_pos_offsets)
|
||||
|
||||
return (x,) if not return_dict else Transformer2DModelOutput(sample=x)
|
||||
|
||||
@@ -360,7 +360,7 @@ class FluxKontextCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", FluxTextEncoderStep()),
|
||||
("image_encoder", FluxAutoVaeEncoderStep()),
|
||||
("vae_encoder", FluxAutoVaeEncoderStep()),
|
||||
("denoise", FluxCoreDenoiseStep()),
|
||||
("decode", FluxDecodeStep()),
|
||||
]
|
||||
@@ -369,7 +369,7 @@ AUTO_BLOCKS = InsertableDict(
|
||||
AUTO_BLOCKS_KONTEXT = InsertableDict(
|
||||
[
|
||||
("text_encoder", FluxTextEncoderStep()),
|
||||
("image_encoder", FluxKontextAutoVaeEncoderStep()),
|
||||
("vae_encoder", FluxKontextAutoVaeEncoderStep()),
|
||||
("denoise", FluxKontextCoreDenoiseStep()),
|
||||
("decode", FluxDecodeStep()),
|
||||
]
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -501,15 +501,19 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
|
||||
@property
|
||||
def input_names(self) -> List[str]:
|
||||
return [input_param.name for input_param in self.inputs]
|
||||
return [input_param.name for input_param in self.inputs if input_param.name is not None]
|
||||
|
||||
@property
|
||||
def intermediate_output_names(self) -> List[str]:
|
||||
return [output_param.name for output_param in self.intermediate_outputs]
|
||||
return [output_param.name for output_param in self.intermediate_outputs if output_param.name is not None]
|
||||
|
||||
@property
|
||||
def output_names(self) -> List[str]:
|
||||
return [output_param.name for output_param in self.outputs]
|
||||
return [output_param.name for output_param in self.outputs if output_param.name is not None]
|
||||
|
||||
@property
|
||||
def component_names(self) -> List[str]:
|
||||
return [component.name for component in self.expected_components]
|
||||
|
||||
@property
|
||||
def doc(self):
|
||||
@@ -1525,10 +1529,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
if blocks is None:
|
||||
if modular_config_dict is not None:
|
||||
blocks_class_name = modular_config_dict.get("_blocks_class_name")
|
||||
elif config_dict is not None:
|
||||
blocks_class_name = self.get_default_blocks_name(config_dict)
|
||||
else:
|
||||
blocks_class_name = None
|
||||
blocks_class_name = self.get_default_blocks_name(config_dict)
|
||||
if blocks_class_name is not None:
|
||||
diffusers_module = importlib.import_module("diffusers")
|
||||
blocks_class = getattr(diffusers_module, blocks_class_name)
|
||||
@@ -1625,7 +1627,10 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
return None, config_dict
|
||||
|
||||
except EnvironmentError as e:
|
||||
logger.debug(f" model_index.json not found in the repo: {e}")
|
||||
raise EnvironmentError(
|
||||
f"Failed to load config from '{pretrained_model_name_or_path}'. "
|
||||
f"Could not find or load 'modular_model_index.json' or 'model_index.json'."
|
||||
) from e
|
||||
|
||||
return None, None
|
||||
|
||||
@@ -2550,7 +2555,11 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
kwargs_type = expected_input_param.kwargs_type
|
||||
if name in passed_kwargs:
|
||||
state.set(name, passed_kwargs.pop(name), kwargs_type)
|
||||
elif name not in state.values:
|
||||
elif kwargs_type is not None and kwargs_type in passed_kwargs:
|
||||
kwargs_dict = passed_kwargs.pop(kwargs_type)
|
||||
for k, v in kwargs_dict.items():
|
||||
state.set(k, v, kwargs_type)
|
||||
elif name is not None and name not in state.values:
|
||||
state.set(name, default, kwargs_type)
|
||||
|
||||
# Warn about unexpected inputs
|
||||
|
||||
@@ -30,6 +30,47 @@ from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class QwenImageAfterDenoiseStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that unpack the latents from 3D tensor (batch_size, sequence_length, channels) into 5D tensor (batch_size, channels, 1, height, width)"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
components = [
|
||||
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
|
||||
]
|
||||
|
||||
return components
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="height", required=True),
|
||||
InputParam(name="width", required=True),
|
||||
InputParam(
|
||||
name="latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents to decode, can be generated in the denoise step",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
vae_scale_factor = components.vae_scale_factor
|
||||
block_state.latents = components.pachifier.unpack_latents(
|
||||
block_state.latents, block_state.height, block_state.width, vae_scale_factor=vae_scale_factor
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageDecoderStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@@ -41,7 +82,6 @@ class QwenImageDecoderStep(ModularPipelineBlocks):
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
components = [
|
||||
ComponentSpec("vae", AutoencoderKLQwenImage),
|
||||
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
|
||||
]
|
||||
|
||||
return components
|
||||
@@ -49,8 +89,6 @@ class QwenImageDecoderStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="height", required=True),
|
||||
InputParam(name="width", required=True),
|
||||
InputParam(
|
||||
name="latents",
|
||||
required=True,
|
||||
@@ -74,10 +112,12 @@ class QwenImageDecoderStep(ModularPipelineBlocks):
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# YiYi Notes: remove support for output_type = "latents', we can just skip decode/encode step in modular
|
||||
vae_scale_factor = components.vae_scale_factor
|
||||
block_state.latents = components.pachifier.unpack_latents(
|
||||
block_state.latents, block_state.height, block_state.width, vae_scale_factor=vae_scale_factor
|
||||
)
|
||||
if block_state.latents.ndim == 4:
|
||||
block_state.latents = block_state.latents.unsqueeze(dim=1)
|
||||
elif block_state.latents.ndim != 5:
|
||||
raise ValueError(
|
||||
f"expect latents to be a 4D or 5D tensor but got: {block_state.latents.shape}. Please make sure the latents are unpacked before decode step."
|
||||
)
|
||||
block_state.latents = block_state.latents.to(components.vae.dtype)
|
||||
|
||||
latents_mean = (
|
||||
|
||||
@@ -26,7 +26,12 @@ from .before_denoise import (
|
||||
QwenImageSetTimestepsStep,
|
||||
QwenImageSetTimestepsWithStrengthStep,
|
||||
)
|
||||
from .decoders import QwenImageDecoderStep, QwenImageInpaintProcessImagesOutputStep, QwenImageProcessImagesOutputStep
|
||||
from .decoders import (
|
||||
QwenImageAfterDenoiseStep,
|
||||
QwenImageDecoderStep,
|
||||
QwenImageInpaintProcessImagesOutputStep,
|
||||
QwenImageProcessImagesOutputStep,
|
||||
)
|
||||
from .denoise import (
|
||||
QwenImageControlNetDenoiseStep,
|
||||
QwenImageDenoiseStep,
|
||||
@@ -92,6 +97,7 @@ TEXT2IMAGE_BLOCKS = InsertableDict(
|
||||
("set_timesteps", QwenImageSetTimestepsStep()),
|
||||
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
|
||||
("denoise", QwenImageDenoiseStep()),
|
||||
("after_denoise", QwenImageAfterDenoiseStep()),
|
||||
("decode", QwenImageDecodeStep()),
|
||||
]
|
||||
)
|
||||
@@ -205,6 +211,7 @@ INPAINT_BLOCKS = InsertableDict(
|
||||
("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
|
||||
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
|
||||
("denoise", QwenImageInpaintDenoiseStep()),
|
||||
("after_denoise", QwenImageAfterDenoiseStep()),
|
||||
("decode", QwenImageInpaintDecodeStep()),
|
||||
]
|
||||
)
|
||||
@@ -264,6 +271,7 @@ IMAGE2IMAGE_BLOCKS = InsertableDict(
|
||||
("prepare_img2img_latents", QwenImagePrepareLatentsWithStrengthStep()),
|
||||
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
|
||||
("denoise", QwenImageDenoiseStep()),
|
||||
("after_denoise", QwenImageAfterDenoiseStep()),
|
||||
("decode", QwenImageDecodeStep()),
|
||||
]
|
||||
)
|
||||
@@ -529,8 +537,16 @@ class QwenImageCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
QwenImageAutoBeforeDenoiseStep,
|
||||
QwenImageOptionalControlNetBeforeDenoiseStep,
|
||||
QwenImageAutoDenoiseStep,
|
||||
QwenImageAfterDenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"controlnet_input",
|
||||
"before_denoise",
|
||||
"controlnet_before_denoise",
|
||||
"denoise",
|
||||
"after_denoise",
|
||||
]
|
||||
block_names = ["input", "controlnet_input", "before_denoise", "controlnet_before_denoise", "denoise"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
@@ -653,6 +669,7 @@ EDIT_BLOCKS = InsertableDict(
|
||||
("set_timesteps", QwenImageSetTimestepsStep()),
|
||||
("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
|
||||
("denoise", QwenImageEditDenoiseStep()),
|
||||
("after_denoise", QwenImageAfterDenoiseStep()),
|
||||
("decode", QwenImageDecodeStep()),
|
||||
]
|
||||
)
|
||||
@@ -702,6 +719,7 @@ EDIT_INPAINT_BLOCKS = InsertableDict(
|
||||
("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
|
||||
("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
|
||||
("denoise", QwenImageEditInpaintDenoiseStep()),
|
||||
("after_denoise", QwenImageAfterDenoiseStep()),
|
||||
("decode", QwenImageInpaintDecodeStep()),
|
||||
]
|
||||
)
|
||||
@@ -841,8 +859,9 @@ class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
QwenImageEditAutoInputStep,
|
||||
QwenImageEditAutoBeforeDenoiseStep,
|
||||
QwenImageEditAutoDenoiseStep,
|
||||
QwenImageAfterDenoiseStep,
|
||||
]
|
||||
block_names = ["input", "before_denoise", "denoise"]
|
||||
block_names = ["input", "before_denoise", "denoise", "after_denoise"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
@@ -954,6 +973,7 @@ EDIT_PLUS_BLOCKS = InsertableDict(
|
||||
("set_timesteps", QwenImageSetTimestepsStep()),
|
||||
("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()),
|
||||
("denoise", QwenImageEditDenoiseStep()),
|
||||
("after_denoise", QwenImageAfterDenoiseStep()),
|
||||
("decode", QwenImageDecodeStep()),
|
||||
]
|
||||
)
|
||||
@@ -1037,8 +1057,9 @@ class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
QwenImageEditPlusAutoInputStep,
|
||||
QwenImageEditPlusAutoBeforeDenoiseStep,
|
||||
QwenImageEditAutoDenoiseStep,
|
||||
QwenImageAfterDenoiseStep,
|
||||
]
|
||||
block_names = ["input", "before_denoise", "denoise"]
|
||||
block_names = ["input", "before_denoise", "denoise", "after_denoise"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# mellon nodes
|
||||
QwenImage_NODE_TYPES_PARAMS_MAP = {
|
||||
"controlnet": {
|
||||
"inputs": [
|
||||
"control_image",
|
||||
"controlnet_conditioning_scale",
|
||||
"control_guidance_start",
|
||||
"control_guidance_end",
|
||||
"height",
|
||||
"width",
|
||||
],
|
||||
"model_inputs": [
|
||||
"controlnet",
|
||||
"vae",
|
||||
],
|
||||
"outputs": [
|
||||
"controlnet_out",
|
||||
],
|
||||
"block_names": ["controlnet_vae_encoder"],
|
||||
},
|
||||
"denoise": {
|
||||
"inputs": [
|
||||
"embeddings",
|
||||
"width",
|
||||
"height",
|
||||
"seed",
|
||||
"num_inference_steps",
|
||||
"guidance_scale",
|
||||
"image_latents",
|
||||
"strength",
|
||||
"controlnet",
|
||||
],
|
||||
"model_inputs": [
|
||||
"unet",
|
||||
"guider",
|
||||
"scheduler",
|
||||
],
|
||||
"outputs": [
|
||||
"latents",
|
||||
"latents_preview",
|
||||
],
|
||||
"block_names": ["denoise"],
|
||||
},
|
||||
"vae_encoder": {
|
||||
"inputs": [
|
||||
"image",
|
||||
"width",
|
||||
"height",
|
||||
],
|
||||
"model_inputs": [
|
||||
"vae",
|
||||
],
|
||||
"outputs": [
|
||||
"image_latents",
|
||||
],
|
||||
},
|
||||
"text_encoder": {
|
||||
"inputs": [
|
||||
"prompt",
|
||||
"negative_prompt",
|
||||
],
|
||||
"model_inputs": [
|
||||
"text_encoders",
|
||||
],
|
||||
"outputs": [
|
||||
"embeddings",
|
||||
],
|
||||
},
|
||||
"decoder": {
|
||||
"inputs": [
|
||||
"latents",
|
||||
],
|
||||
"model_inputs": [
|
||||
"vae",
|
||||
],
|
||||
"outputs": [
|
||||
"images",
|
||||
],
|
||||
},
|
||||
}
|
||||
@@ -1,99 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
SDXL_NODE_TYPES_PARAMS_MAP = {
|
||||
"controlnet": {
|
||||
"inputs": [
|
||||
"control_image",
|
||||
"controlnet_conditioning_scale",
|
||||
"control_guidance_start",
|
||||
"control_guidance_end",
|
||||
"height",
|
||||
"width",
|
||||
],
|
||||
"model_inputs": [
|
||||
"controlnet",
|
||||
],
|
||||
"outputs": [
|
||||
"controlnet_out",
|
||||
],
|
||||
"block_names": [None],
|
||||
},
|
||||
"denoise": {
|
||||
"inputs": [
|
||||
"embeddings",
|
||||
"width",
|
||||
"height",
|
||||
"seed",
|
||||
"num_inference_steps",
|
||||
"guidance_scale",
|
||||
"image_latents",
|
||||
"strength",
|
||||
# custom adapters coming in as inputs
|
||||
"controlnet",
|
||||
# ip_adapter is optional and custom; include if available
|
||||
"ip_adapter",
|
||||
],
|
||||
"model_inputs": [
|
||||
"unet",
|
||||
"guider",
|
||||
"scheduler",
|
||||
],
|
||||
"outputs": [
|
||||
"latents",
|
||||
"latents_preview",
|
||||
],
|
||||
"block_names": ["denoise"],
|
||||
},
|
||||
"vae_encoder": {
|
||||
"inputs": [
|
||||
"image",
|
||||
"width",
|
||||
"height",
|
||||
],
|
||||
"model_inputs": [
|
||||
"vae",
|
||||
],
|
||||
"outputs": [
|
||||
"image_latents",
|
||||
],
|
||||
"block_names": ["vae_encoder"],
|
||||
},
|
||||
"text_encoder": {
|
||||
"inputs": [
|
||||
"prompt",
|
||||
"negative_prompt",
|
||||
],
|
||||
"model_inputs": [
|
||||
"text_encoders",
|
||||
],
|
||||
"outputs": [
|
||||
"embeddings",
|
||||
],
|
||||
"block_names": ["text_encoder"],
|
||||
},
|
||||
"decoder": {
|
||||
"inputs": [
|
||||
"latents",
|
||||
],
|
||||
"model_inputs": [
|
||||
"vae",
|
||||
],
|
||||
"outputs": [
|
||||
"images",
|
||||
],
|
||||
"block_names": ["decode"],
|
||||
},
|
||||
}
|
||||
@@ -129,6 +129,10 @@ class ZImageLoopDenoiser(ModularPipelineBlocks):
|
||||
type_hint=int,
|
||||
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||
),
|
||||
InputParam(
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
|
||||
),
|
||||
]
|
||||
guider_input_names = []
|
||||
uncond_guider_input_names = []
|
||||
|
||||
@@ -119,7 +119,7 @@ class ZImageAutoDenoiseStep(AutoPipelineBlocks):
|
||||
|
||||
class ZImageAutoVaeImageEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [ZImageVaeImageEncoderStep]
|
||||
block_names = ["vae_image_encoder"]
|
||||
block_names = ["vae_encoder"]
|
||||
block_trigger_inputs = ["image"]
|
||||
|
||||
@property
|
||||
@@ -137,7 +137,7 @@ class ZImageAutoBlocks(SequentialPipelineBlocks):
|
||||
ZImageAutoDenoiseStep,
|
||||
ZImageVaeDecoderStep,
|
||||
]
|
||||
block_names = ["text_encoder", "vae_image_encoder", "denoise", "decode"]
|
||||
block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
@@ -162,7 +162,7 @@ TEXT2IMAGE_BLOCKS = InsertableDict(
|
||||
IMAGE2IMAGE_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", ZImageTextEncoderStep),
|
||||
("vae_image_encoder", ZImageVaeImageEncoderStep),
|
||||
("vae_encoder", ZImageVaeImageEncoderStep),
|
||||
("input", ZImageTextInputStep),
|
||||
("additional_inputs", ZImageAdditionalInputsStep(image_latent_inputs=["image_latents"])),
|
||||
("prepare_latents", ZImagePrepareLatentsStep),
|
||||
@@ -178,7 +178,7 @@ IMAGE2IMAGE_BLOCKS = InsertableDict(
|
||||
AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", ZImageTextEncoderStep),
|
||||
("vae_image_encoder", ZImageAutoVaeImageEncoderStep),
|
||||
("vae_encoder", ZImageAutoVaeImageEncoderStep),
|
||||
("denoise", ZImageAutoDenoiseStep),
|
||||
("decode", ZImageVaeDecoderStep),
|
||||
]
|
||||
|
||||
@@ -165,6 +165,7 @@ else:
|
||||
_import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"]
|
||||
_import_structure["consisid"] = ["ConsisIDPipeline"]
|
||||
_import_structure["cosmos"] = [
|
||||
"Cosmos2_5_PredictBasePipeline",
|
||||
"Cosmos2TextToImagePipeline",
|
||||
"CosmosTextToWorldPipeline",
|
||||
"CosmosVideoToWorldPipeline",
|
||||
@@ -405,7 +406,13 @@ else:
|
||||
"Kandinsky5T2IPipeline",
|
||||
"Kandinsky5I2IPipeline",
|
||||
]
|
||||
_import_structure["z_image"] = ["ZImageImg2ImgPipeline", "ZImagePipeline"]
|
||||
_import_structure["z_image"] = [
|
||||
"ZImageImg2ImgPipeline",
|
||||
"ZImagePipeline",
|
||||
"ZImageControlNetPipeline",
|
||||
"ZImageControlNetInpaintPipeline",
|
||||
"ZImageOmniPipeline",
|
||||
]
|
||||
_import_structure["skyreels_v2"] = [
|
||||
"SkyReelsV2DiffusionForcingPipeline",
|
||||
"SkyReelsV2DiffusionForcingImageToVideoPipeline",
|
||||
@@ -422,6 +429,7 @@ else:
|
||||
"QwenImageEditInpaintPipeline",
|
||||
"QwenImageControlNetInpaintPipeline",
|
||||
"QwenImageControlNetPipeline",
|
||||
"QwenImageLayeredPipeline",
|
||||
]
|
||||
_import_structure["chronoedit"] = ["ChronoEditPipeline"]
|
||||
try:
|
||||
@@ -616,6 +624,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionXLControlNetXSPipeline,
|
||||
)
|
||||
from .cosmos import (
|
||||
Cosmos2_5_PredictBasePipeline,
|
||||
Cosmos2TextToImagePipeline,
|
||||
Cosmos2VideoToWorldPipeline,
|
||||
CosmosTextToWorldPipeline,
|
||||
@@ -764,6 +773,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
QwenImageEditPlusPipeline,
|
||||
QwenImageImg2ImgPipeline,
|
||||
QwenImageInpaintPipeline,
|
||||
QwenImageLayeredPipeline,
|
||||
QwenImagePipeline,
|
||||
)
|
||||
from .sana import (
|
||||
@@ -843,7 +853,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
WuerstchenDecoderPipeline,
|
||||
WuerstchenPriorPipeline,
|
||||
)
|
||||
from .z_image import ZImageImg2ImgPipeline, ZImagePipeline
|
||||
from .z_image import (
|
||||
ZImageControlNetInpaintPipeline,
|
||||
ZImageControlNetPipeline,
|
||||
ZImageImg2ImgPipeline,
|
||||
ZImageOmniPipeline,
|
||||
ZImagePipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
|
||||
@@ -73,6 +73,7 @@ from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
|
||||
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
|
||||
from .lumina import LuminaPipeline
|
||||
from .lumina2 import Lumina2Pipeline
|
||||
from .ovis_image import OvisImagePipeline
|
||||
from .pag import (
|
||||
HunyuanDiTPAGPipeline,
|
||||
PixArtSigmaPAGPipeline,
|
||||
@@ -119,7 +120,13 @@ from .stable_diffusion_xl import (
|
||||
)
|
||||
from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline
|
||||
from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline
|
||||
from .z_image import ZImageImg2ImgPipeline, ZImagePipeline
|
||||
from .z_image import (
|
||||
ZImageControlNetInpaintPipeline,
|
||||
ZImageControlNetPipeline,
|
||||
ZImageImg2ImgPipeline,
|
||||
ZImageOmniPipeline,
|
||||
ZImagePipeline,
|
||||
)
|
||||
|
||||
|
||||
AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
@@ -164,6 +171,10 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("qwenimage", QwenImagePipeline),
|
||||
("qwenimage-controlnet", QwenImageControlNetPipeline),
|
||||
("z-image", ZImagePipeline),
|
||||
("z-image-controlnet", ZImageControlNetPipeline),
|
||||
("z-image-controlnet-inpaint", ZImageControlNetInpaintPipeline),
|
||||
("z-image-omni", ZImageOmniPipeline),
|
||||
("ovis", OvisImagePipeline),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel
|
||||
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, T5EncoderModel, T5Tokenizer
|
||||
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
|
||||
@@ -185,7 +185,7 @@ class HunyuanDiTControlNetPipeline(DiffusionPipeline):
|
||||
The HunyuanDiT model designed by Tencent Hunyuan.
|
||||
text_encoder_2 (`T5EncoderModel`):
|
||||
The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
|
||||
tokenizer_2 (`MT5Tokenizer`):
|
||||
tokenizer_2 (`T5Tokenizer`):
|
||||
The tokenizer for the mT5 embedder.
|
||||
scheduler ([`DDPMScheduler`]):
|
||||
A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
|
||||
@@ -229,7 +229,7 @@ class HunyuanDiTControlNetPipeline(DiffusionPipeline):
|
||||
HunyuanDiT2DMultiControlNetModel,
|
||||
],
|
||||
text_encoder_2: Optional[T5EncoderModel] = None,
|
||||
tokenizer_2: Optional[MT5Tokenizer] = None,
|
||||
tokenizer_2: Optional[T5Tokenizer] = None,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -22,6 +22,9 @@ except OptionalDependencyNotAvailable:
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_cosmos2_5_predict"] = [
|
||||
"Cosmos2_5_PredictBasePipeline",
|
||||
]
|
||||
_import_structure["pipeline_cosmos2_text2image"] = ["Cosmos2TextToImagePipeline"]
|
||||
_import_structure["pipeline_cosmos2_video2world"] = ["Cosmos2VideoToWorldPipeline"]
|
||||
_import_structure["pipeline_cosmos_text2world"] = ["CosmosTextToWorldPipeline"]
|
||||
@@ -35,6 +38,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_cosmos2_5_predict import (
|
||||
Cosmos2_5_PredictBasePipeline,
|
||||
)
|
||||
from .pipeline_cosmos2_text2image import Cosmos2TextToImagePipeline
|
||||
from .pipeline_cosmos2_video2world import Cosmos2VideoToWorldPipeline
|
||||
from .pipeline_cosmos_text2world import CosmosTextToWorldPipeline
|
||||
|
||||
847
src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py
Normal file
847
src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py
Normal file
@@ -0,0 +1,847 @@
|
||||
# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms
|
||||
import torchvision.transforms.functional
|
||||
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...models import AutoencoderKLWan, CosmosTransformer3DModel
|
||||
from ...schedulers import UniPCMultistepScheduler
|
||||
from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import CosmosPipelineOutput
|
||||
|
||||
|
||||
if is_cosmos_guardrail_available():
|
||||
from cosmos_guardrail import CosmosSafetyChecker
|
||||
else:
|
||||
|
||||
class CosmosSafetyChecker:
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise ImportError(
|
||||
"`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`."
|
||||
)
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import Cosmos2_5_PredictBasePipeline
|
||||
>>> from diffusers.utils import export_to_video, load_image, load_video
|
||||
|
||||
>>> model_id = "nvidia/Cosmos-Predict2.5-2B"
|
||||
>>> pipe = Cosmos2_5_PredictBasePipeline.from_pretrained(
|
||||
... model_id, revision="diffusers/base/pre-trianed", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe = pipe.to("cuda")
|
||||
|
||||
>>> # Common negative prompt reused across modes.
|
||||
>>> negative_prompt = (
|
||||
... "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
|
||||
... "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
|
||||
... "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky "
|
||||
... "movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
|
||||
... "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
|
||||
... "Overall, the video is of poor quality."
|
||||
... )
|
||||
|
||||
>>> # Text2World: generate a 93-frame world video from text only.
|
||||
>>> prompt = (
|
||||
... "As the red light shifts to green, the red bus at the intersection begins to move forward, its headlights "
|
||||
... "cutting through the falling snow. The snowy tire tracks deepen as the vehicle inches ahead, casting fresh "
|
||||
... "lines onto the slushy road. Around it, streetlights glow warmer, illuminating the drifting flakes and wet "
|
||||
... "reflections on the asphalt. Other cars behind start to edge forward, their beams joining the scene. "
|
||||
... "The stillness of the urban street transitions into motion as the quiet snowfall is punctuated by the slow "
|
||||
... "advance of traffic through the frosty city corridor."
|
||||
... )
|
||||
>>> video = pipe(
|
||||
... image=None,
|
||||
... video=None,
|
||||
... prompt=prompt,
|
||||
... negative_prompt=negative_prompt,
|
||||
... num_frames=93,
|
||||
... generator=torch.Generator().manual_seed(1),
|
||||
... ).frames[0]
|
||||
>>> export_to_video(video, "text2world.mp4", fps=16)
|
||||
|
||||
>>> # Image2World: condition on a single image and generate a 93-frame world video.
|
||||
>>> prompt = (
|
||||
... "A high-definition video captures the precision of robotic welding in an industrial setting. "
|
||||
... "The first frame showcases a robotic arm, equipped with a welding torch, positioned over a large metal structure. "
|
||||
... "The welding process is in full swing, with bright sparks and intense light illuminating the scene, creating a vivid "
|
||||
... "display of blue and white hues. A significant amount of smoke billows around the welding area, partially obscuring "
|
||||
... "the view but emphasizing the heat and activity. The background reveals parts of the workshop environment, including a "
|
||||
... "ventilation system and various pieces of machinery, indicating a busy and functional industrial workspace. As the video "
|
||||
... "progresses, the robotic arm maintains its steady position, continuing the welding process and moving to its left. "
|
||||
... "The welding torch consistently emits sparks and light, and the smoke continues to rise, diffusing slightly as it moves upward. "
|
||||
... "The metal surface beneath the torch shows ongoing signs of heating and melting. The scene retains its industrial ambiance, with "
|
||||
... "the welding sparks and smoke dominating the visual field, underscoring the ongoing nature of the welding operation."
|
||||
... )
|
||||
>>> image = load_image(
|
||||
... "https://media.githubusercontent.com/media/nvidia-cosmos/cosmos-predict2.5/refs/heads/main/assets/base/robot_welding.jpg"
|
||||
... )
|
||||
>>> video = pipe(
|
||||
... image=image,
|
||||
... video=None,
|
||||
... prompt=prompt,
|
||||
... negative_prompt=negative_prompt,
|
||||
... num_frames=93,
|
||||
... generator=torch.Generator().manual_seed(1),
|
||||
... ).frames[0]
|
||||
>>> export_to_video(video, "image2world.mp4", fps=16)
|
||||
|
||||
>>> # Video2World: condition on an input clip and predict a 93-frame world video.
|
||||
>>> prompt = (
|
||||
... "The video opens with an aerial view of a large-scale sand mining construction operation, showcasing extensive piles "
|
||||
... "of brown sand meticulously arranged in parallel rows. A central water channel, fed by a water pipe, flows through the "
|
||||
... "middle of these sand heaps, creating ripples and movement as it cascades down. The surrounding area features dense green "
|
||||
... "vegetation on the left, contrasting with the sandy terrain, while a body of water is visible in the background on the right. "
|
||||
... "As the video progresses, a piece of heavy machinery, likely a bulldozer, enters the frame from the right, moving slowly along "
|
||||
... "the edge of the sand piles. This machinery's presence indicates ongoing construction work in the operation. The final frame "
|
||||
... "captures the same scene, with the water continuing its flow and the bulldozer still in motion, maintaining the dynamic yet "
|
||||
... "steady pace of the construction activity."
|
||||
... )
|
||||
>>> input_video = load_video(
|
||||
... "https://github.com/nvidia-cosmos/cosmos-predict2.5/raw/refs/heads/main/assets/base/sand_mining.mp4"
|
||||
... )
|
||||
>>> video = pipe(
|
||||
... image=None,
|
||||
... video=input_video,
|
||||
... prompt=prompt,
|
||||
... negative_prompt=negative_prompt,
|
||||
... num_frames=93,
|
||||
... generator=torch.Generator().manual_seed(1),
|
||||
... ).frames[0]
|
||||
>>> export_to_video(video, "video2world.mp4", fps=16)
|
||||
|
||||
>>> # To produce an image instead of a world (video) clip, set num_frames=1 and
|
||||
>>> # save the first frame: pipe(..., num_frames=1).frames[0][0].
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class Cosmos2_5_PredictBasePipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for [Cosmos Predict2.5](https://github.com/nvidia-cosmos/cosmos-predict2.5) base model.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
Args:
|
||||
text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
|
||||
Frozen text-encoder. Cosmos Predict2.5 uses the [Qwen2.5
|
||||
VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) encoder.
|
||||
tokenizer (`AutoTokenizer`):
|
||||
Tokenizer associated with the Qwen2.5 VL encoder.
|
||||
transformer ([`CosmosTransformer3DModel`]):
|
||||
Conditional Transformer to denoise the encoded image latents.
|
||||
scheduler ([`UniPCMultistepScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLWan`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
# We mark safety_checker as optional here to get around some test failures, but it is not really optional
|
||||
_optional_components = ["safety_checker"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_encoder: Qwen2_5_VLForConditionalGeneration,
|
||||
tokenizer: AutoTokenizer,
|
||||
transformer: CosmosTransformer3DModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: UniPCMultistepScheduler,
|
||||
safety_checker: CosmosSafetyChecker = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if safety_checker is None:
|
||||
safety_checker = CosmosSafetyChecker()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
)
|
||||
|
||||
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
|
||||
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).float()
|
||||
if getattr(self.vae.config, "latents_mean", None) is not None
|
||||
else None
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).float()
|
||||
if getattr(self.vae.config, "latents_std", None) is not None
|
||||
else None
|
||||
)
|
||||
self.latents_mean = latents_mean
|
||||
self.latents_std = latents_std
|
||||
|
||||
if self.latents_mean is None or self.latents_std is None:
|
||||
raise ValueError("VAE configuration must define both `latents_mean` and `latents_std`.")
|
||||
|
||||
def _get_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
max_sequence_length: int = 512,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
input_ids_batch = []
|
||||
|
||||
for sample_idx in range(len(prompt)):
|
||||
conversations = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are a helpful assistant who will provide prompts to an image generator.",
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt[sample_idx],
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
input_ids = self.tokenizer.apply_chat_template(
|
||||
conversations,
|
||||
tokenize=True,
|
||||
add_generation_prompt=False,
|
||||
add_vision_id=False,
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
)
|
||||
input_ids = torch.LongTensor(input_ids)
|
||||
input_ids_batch.append(input_ids)
|
||||
|
||||
input_ids_batch = torch.stack(input_ids_batch, dim=0)
|
||||
|
||||
outputs = self.text_encoder(
|
||||
input_ids_batch.to(device),
|
||||
output_hidden_states=True,
|
||||
)
|
||||
hidden_states = outputs.hidden_states
|
||||
|
||||
normalized_hidden_states = []
|
||||
for layer_idx in range(1, len(hidden_states)):
|
||||
normalized_state = (hidden_states[layer_idx] - hidden_states[layer_idx].mean(dim=-1, keepdim=True)) / (
|
||||
hidden_states[layer_idx].std(dim=-1, keepdim=True) + 1e-8
|
||||
)
|
||||
normalized_hidden_states.append(normalized_state)
|
||||
|
||||
prompt_embeds = torch.cat(normalized_hidden_states, dim=-1)
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Modified from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 512,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_prompt_embeds(
|
||||
prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = self._get_prompt_embeds(
|
||||
prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = negative_prompt_embeds.shape
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Modified from diffusers.pipelines.cosmos.pipeline_cosmos2_video2world.Cosmos2VideoToWorldPipeline.prepare_latents and
|
||||
# diffusers.pipelines.cosmos.pipeline_cosmos2_video2world.Cosmos2TextToImagePipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self,
|
||||
video: Optional[torch.Tensor],
|
||||
batch_size: int,
|
||||
num_channels_latents: int = 16,
|
||||
height: int = 704,
|
||||
width: int = 1280,
|
||||
num_frames_in: int = 93,
|
||||
num_frames_out: int = 93,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
B = batch_size
|
||||
C = num_channels_latents
|
||||
T = (num_frames_out - 1) // self.vae_scale_factor_temporal + 1
|
||||
H = height // self.vae_scale_factor_spatial
|
||||
W = width // self.vae_scale_factor_spatial
|
||||
shape = (B, C, T, H, W)
|
||||
|
||||
if num_frames_in == 0:
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
cond_mask = torch.zeros((B, 1, T, H, W), dtype=latents.dtype, device=latents.device)
|
||||
cond_indicator = torch.zeros((B, 1, T, 1, 1), dtype=latents.dtype, device=latents.device)
|
||||
|
||||
cond_latents = torch.zeros_like(latents)
|
||||
|
||||
return (
|
||||
latents,
|
||||
cond_latents,
|
||||
cond_mask,
|
||||
cond_indicator,
|
||||
)
|
||||
else:
|
||||
if video is None:
|
||||
raise ValueError("`video` must be provided when `num_frames_in` is greater than 0.")
|
||||
needs_preprocessing = not (isinstance(video, torch.Tensor) and video.ndim == 5 and video.shape[1] == 3)
|
||||
if needs_preprocessing:
|
||||
video = self.video_processor.preprocess_video(video, height, width)
|
||||
video = video.to(device=device, dtype=self.vae.dtype)
|
||||
if isinstance(generator, list):
|
||||
cond_latents = [
|
||||
retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator=generator[i])
|
||||
for i in range(batch_size)
|
||||
]
|
||||
else:
|
||||
cond_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
|
||||
|
||||
cond_latents = torch.cat(cond_latents, dim=0).to(dtype)
|
||||
|
||||
latents_mean = self.latents_mean.to(device=device, dtype=dtype)
|
||||
latents_std = self.latents_std.to(device=device, dtype=dtype)
|
||||
cond_latents = (cond_latents - latents_mean) / latents_std
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
|
||||
padding_shape = (B, 1, T, H, W)
|
||||
ones_padding = latents.new_ones(padding_shape)
|
||||
zeros_padding = latents.new_zeros(padding_shape)
|
||||
|
||||
num_cond_latent_frames = (num_frames_in - 1) // self.vae_scale_factor_temporal + 1
|
||||
cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
|
||||
cond_indicator[:, :, 0:num_cond_latent_frames] = 1.0
|
||||
cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
|
||||
|
||||
return (
|
||||
latents,
|
||||
cond_latents,
|
||||
cond_mask,
|
||||
cond_indicator,
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image: PipelineImageInput | None = None,
|
||||
video: List[PipelineImageInput] | None = None,
|
||||
prompt: Union[str, List[str]] | None = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: int = 704,
|
||||
width: int = 1280,
|
||||
num_frames: int = 93,
|
||||
num_inference_steps: int = 36,
|
||||
guidance_scale: float = 7.0,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
conditional_frame_timestep: float = 0.1,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation. Supports three modes:
|
||||
|
||||
- **Text2World**: `image=None`, `video=None`, `prompt` provided. Generates a world clip.
|
||||
- **Image2World**: `image` provided, `video=None`, `prompt` provided. Conditions on a single frame.
|
||||
- **Video2World**: `video` provided, `image=None`, `prompt` provided. Conditions on an input clip.
|
||||
|
||||
Set `num_frames=93` (default) to produce a world video, or `num_frames=1` to produce a single image frame (the
|
||||
above in "*2Image mode").
|
||||
|
||||
Outputs follow `output_type` (e.g., `"pil"` returns a list of `num_frames` PIL images per prompt).
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, *optional*):
|
||||
Optional single image for Image2World conditioning. Must be `None` when `video` is provided.
|
||||
video (`List[PIL.Image.Image]`, `np.ndarray`, `torch.Tensor`, *optional*):
|
||||
Optional input video for Video2World conditioning. Must be `None` when `image` is provided.
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide generation. Required unless `prompt_embeds` is supplied.
|
||||
height (`int`, defaults to `704`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, defaults to `1280`):
|
||||
The width in pixels of the generated image.
|
||||
num_frames (`int`, defaults to `93`):
|
||||
Number of output frames. Use `93` for world (video) generation; set to `1` to return a single frame.
|
||||
num_inference_steps (`int`, defaults to `35`):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, defaults to `7.0`):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple.
|
||||
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
||||
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
||||
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
||||
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
||||
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, defaults to `512`):
|
||||
The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If
|
||||
the prompt is shorter than this length, it will be padded.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~CosmosPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where
|
||||
the first element is a list with the generated images and the second element is a list of `bool`s
|
||||
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
if self.safety_checker is None:
|
||||
raise ValueError(
|
||||
f"You have disabled the safety checker for {self.__class__}. This is in violation of the "
|
||||
"[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). "
|
||||
f"Please ensure that you are compliant with the license agreement."
|
||||
)
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# Check inputs. Raise error if not correct
|
||||
self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
if self.safety_checker is not None:
|
||||
self.safety_checker.to(device)
|
||||
if prompt is not None:
|
||||
prompt_list = [prompt] if isinstance(prompt, str) else prompt
|
||||
for p in prompt_list:
|
||||
if not self.safety_checker.check_text_safety(p):
|
||||
raise ValueError(
|
||||
f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the "
|
||||
f"prompt abides by the NVIDIA Open Model License Agreement."
|
||||
)
|
||||
|
||||
# Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# Encode input prompt
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
vae_dtype = self.vae.dtype
|
||||
transformer_dtype = self.transformer.dtype
|
||||
|
||||
num_frames_in = None
|
||||
if image is not None:
|
||||
if batch_size != 1:
|
||||
raise ValueError(f"batch_size must be 1 for image input (given {batch_size})")
|
||||
|
||||
image = torchvision.transforms.functional.to_tensor(image).unsqueeze(0)
|
||||
video = torch.cat([image, torch.zeros_like(image).repeat(num_frames - 1, 1, 1, 1)], dim=0)
|
||||
video = video.unsqueeze(0)
|
||||
num_frames_in = 1
|
||||
elif video is None:
|
||||
video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=torch.uint8)
|
||||
num_frames_in = 0
|
||||
else:
|
||||
num_frames_in = len(video)
|
||||
|
||||
if batch_size != 1:
|
||||
raise ValueError(f"batch_size must be 1 for video input (given {batch_size})")
|
||||
|
||||
assert video is not None
|
||||
video = self.video_processor.preprocess_video(video, height, width)
|
||||
|
||||
# pad with last frame (for video2world)
|
||||
num_frames_out = num_frames
|
||||
if video.shape[2] < num_frames_out:
|
||||
n_pad_frames = num_frames_out - num_frames_in
|
||||
last_frame = video[0, :, -1:, :, :] # [C, T==1, H, W]
|
||||
pad_frames = last_frame.repeat(1, 1, n_pad_frames, 1, 1) # [B, C, T, H, W]
|
||||
video = torch.cat((video, pad_frames), dim=2)
|
||||
|
||||
assert num_frames_in <= num_frames_out, f"expected ({num_frames_in=}) <= ({num_frames_out=})"
|
||||
|
||||
video = video.to(device=device, dtype=vae_dtype)
|
||||
|
||||
num_channels_latents = self.transformer.config.in_channels - 1
|
||||
latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents(
|
||||
video=video,
|
||||
batch_size=batch_size * num_videos_per_prompt,
|
||||
num_channels_latents=num_channels_latents,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames_in=num_frames_in,
|
||||
num_frames_out=num_frames,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
)
|
||||
cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep
|
||||
cond_mask = cond_mask.to(transformer_dtype)
|
||||
|
||||
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
|
||||
|
||||
# Denoising loop
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
self._num_timesteps = len(timesteps)
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
|
||||
gt_velocity = (latents - cond_latent) * cond_mask
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t.cpu().item()
|
||||
|
||||
# NOTE: assumes sigma(t) \in [0, 1]
|
||||
sigma_t = (
|
||||
torch.tensor(self.scheduler.sigmas[i].item())
|
||||
.unsqueeze(0)
|
||||
.to(device=device, dtype=transformer_dtype)
|
||||
)
|
||||
|
||||
in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents
|
||||
in_latents = in_latents.to(transformer_dtype)
|
||||
in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=in_latents,
|
||||
condition_mask=cond_mask,
|
||||
timestep=in_timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
padding_mask=padding_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
# NOTE: replace velocity (noise_pred) with gt_velocity for conditioning inputs only
|
||||
noise_pred = gt_velocity + noise_pred * (1 - cond_mask)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_neg = self.transformer(
|
||||
hidden_states=in_latents,
|
||||
condition_mask=cond_mask,
|
||||
timestep=in_timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
padding_mask=padding_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
# NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only
|
||||
noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask)
|
||||
noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg)
|
||||
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
latents_mean = self.latents_mean.to(latents.device, latents.dtype)
|
||||
latents_std = self.latents_std.to(latents.device, latents.dtype)
|
||||
latents = latents * latents_std + latents_mean
|
||||
video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
|
||||
video = self._match_num_frames(video, num_frames)
|
||||
|
||||
assert self.safety_checker is not None
|
||||
self.safety_checker.to(device)
|
||||
video = self.video_processor.postprocess_video(video, output_type="np")
|
||||
video = (video * 255).astype(np.uint8)
|
||||
video_batch = []
|
||||
for vid in video:
|
||||
vid = self.safety_checker.check_video_safety(vid)
|
||||
video_batch.append(vid)
|
||||
video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1
|
||||
video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return CosmosPipelineOutput(frames=video)
|
||||
|
||||
def _match_num_frames(self, video: torch.Tensor, target_num_frames: int) -> torch.Tensor:
|
||||
if target_num_frames <= 0 or video.shape[2] == target_num_frames:
|
||||
return video
|
||||
|
||||
frames_per_latent = max(self.vae_scale_factor_temporal, 1)
|
||||
video = torch.repeat_interleave(video, repeats=frames_per_latent, dim=2)
|
||||
|
||||
current_frames = video.shape[2]
|
||||
if current_frames < target_num_frames:
|
||||
pad = video[:, :, -1:, :, :].repeat(1, 1, target_num_frames - current_frames, 1, 1)
|
||||
video = torch.cat([video, pad], dim=2)
|
||||
elif current_frames > target_num_frames:
|
||||
video = video[:, :, :target_num_frames]
|
||||
|
||||
return video
|
||||
@@ -17,7 +17,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel
|
||||
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, T5EncoderModel, T5Tokenizer
|
||||
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
|
||||
@@ -169,7 +169,7 @@ class HunyuanDiTPipeline(DiffusionPipeline):
|
||||
The HunyuanDiT model designed by Tencent Hunyuan.
|
||||
text_encoder_2 (`T5EncoderModel`):
|
||||
The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
|
||||
tokenizer_2 (`MT5Tokenizer`):
|
||||
tokenizer_2 (`T5Tokenizer`):
|
||||
The tokenizer for the mT5 embedder.
|
||||
scheduler ([`DDPMScheduler`]):
|
||||
A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
|
||||
@@ -204,7 +204,7 @@ class HunyuanDiTPipeline(DiffusionPipeline):
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
requires_safety_checker: bool = True,
|
||||
text_encoder_2: Optional[T5EncoderModel] = None,
|
||||
tokenizer_2: Optional[MT5Tokenizer] = None,
|
||||
tokenizer_2: Optional[T5Tokenizer] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel
|
||||
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, T5EncoderModel, T5Tokenizer
|
||||
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
|
||||
@@ -173,7 +173,7 @@ class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin):
|
||||
The HunyuanDiT model designed by Tencent Hunyuan.
|
||||
text_encoder_2 (`T5EncoderModel`):
|
||||
The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
|
||||
tokenizer_2 (`MT5Tokenizer`):
|
||||
tokenizer_2 (`T5Tokenizer`):
|
||||
The tokenizer for the mT5 embedder.
|
||||
scheduler ([`DDPMScheduler`]):
|
||||
A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
|
||||
@@ -208,7 +208,7 @@ class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin):
|
||||
feature_extractor: Optional[CLIPImageProcessor] = None,
|
||||
requires_safety_checker: bool = True,
|
||||
text_encoder_2: Optional[T5EncoderModel] = None,
|
||||
tokenizer_2: Optional[MT5Tokenizer] = None,
|
||||
tokenizer_2: Optional[T5Tokenizer] = None,
|
||||
pag_applied_layers: Union[str, List[str]] = "blocks.1", # "blocks.16.attn1", "blocks.16", "16", 16
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -31,6 +31,7 @@ else:
|
||||
_import_structure["pipeline_qwenimage_edit_plus"] = ["QwenImageEditPlusPipeline"]
|
||||
_import_structure["pipeline_qwenimage_img2img"] = ["QwenImageImg2ImgPipeline"]
|
||||
_import_structure["pipeline_qwenimage_inpaint"] = ["QwenImageInpaintPipeline"]
|
||||
_import_structure["pipeline_qwenimage_layered"] = ["QwenImageLayeredPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
@@ -47,6 +48,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
|
||||
from .pipeline_qwenimage_img2img import QwenImageImg2ImgPipeline
|
||||
from .pipeline_qwenimage_inpaint import QwenImageInpaintPipeline
|
||||
from .pipeline_qwenimage_layered import QwenImageLayeredPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
905
src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py
Normal file
905
src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py
Normal file
@@ -0,0 +1,905 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import math
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import QwenImageLoraLoaderMixin
|
||||
from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import QwenImagePipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from PIL import Image
|
||||
>>> from diffusers import QwenImageLayeredPipeline
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
>>> pipe = QwenImageLayeredPipeline.from_pretrained("Qwen/Qwen-Image-Layered", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
>>> image = load_image(
|
||||
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
|
||||
... ).convert("RGBA")
|
||||
>>> prompt = ""
|
||||
>>> # Depending on the variant being used, the pipeline call will slightly vary.
|
||||
>>> # Refer to the pipeline documentation for more details.
|
||||
>>> images = pipe(
|
||||
... image,
|
||||
... prompt,
|
||||
... num_inference_steps=50,
|
||||
... true_cfg_scale=4.0,
|
||||
... layers=4,
|
||||
... resolution=640,
|
||||
... cfg_normalize=False,
|
||||
... use_en_prompt=True,
|
||||
... ).images[0]
|
||||
>>> for i, image in enumerate(images):
|
||||
... image.save(f"{i}.out.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit_plus.calculate_dimensions
|
||||
def calculate_dimensions(target_area, ratio):
|
||||
width = math.sqrt(target_area * ratio)
|
||||
height = width / ratio
|
||||
|
||||
width = round(width / 32) * 32
|
||||
height = round(height / 32) * 32
|
||||
|
||||
return width, height
|
||||
|
||||
|
||||
class QwenImageLayeredPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
r"""
|
||||
The Qwen-Image-Layered pipeline for image decomposing.
|
||||
|
||||
Args:
|
||||
transformer ([`QwenImageTransformer2DModel`]):
|
||||
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
|
||||
[Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
|
||||
[Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
|
||||
tokenizer (`QwenTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKLQwenImage,
|
||||
text_encoder: Qwen2_5_VLForConditionalGeneration,
|
||||
tokenizer: Qwen2Tokenizer,
|
||||
processor: Qwen2VLProcessor,
|
||||
transformer: QwenImageTransformer2DModel,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
|
||||
self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16
|
||||
# QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
||||
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
self.vl_processor = processor
|
||||
self.tokenizer_max_length = 1024
|
||||
|
||||
self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
self.prompt_template_encode_start_idx = 34
|
||||
self.image_caption_prompt_cn = """<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n# 图像标注器\n你是一个专业的图像标注器。请基于输入图像,撰写图注:\n1.
|
||||
使用自然、描述性的语言撰写图注,不要使用结构化形式或富文本形式。\n2. 通过加入以下内容,丰富图注细节:\n - 对象的属性:如数量、颜色、形状、大小、位置、材质、状态、动作等\n -
|
||||
对象间的视觉关系:如空间关系、功能关系、动作关系、从属关系、比较关系、因果关系等\n - 环境细节:例如天气、光照、颜色、纹理、气氛等\n - 文字内容:识别图像中清晰可见的文字,不做翻译和解释,用引号在图注中强调\n3.
|
||||
保持真实性与准确性:\n - 不要使用笼统的描述\n -
|
||||
描述图像中所有可见的信息,但不要加入没有在图像中出现的内容\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>assistant\n"""
|
||||
self.image_caption_prompt_en = """<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n# Image Annotator\nYou are a professional
|
||||
image annotator. Please write an image caption based on the input image:\n1. Write the caption using natural,
|
||||
descriptive language without structured formats or rich text.\n2. Enrich caption details by including: \n - Object
|
||||
attributes, such as quantity, color, shape, size, material, state, position, actions, and so on\n - Vision Relations
|
||||
between objects, such as spatial relations, functional relations, possessive relations, attachment relations, action
|
||||
relations, comparative relations, causal relations, and so on\n - Environmental details, such as weather, lighting,
|
||||
colors, textures, atmosphere, and so on\n - Identify the text clearly visible in the image, without translation or
|
||||
explanation, and highlight it in the caption with quotation marks\n3. Maintain authenticity and accuracy:\n - Avoid
|
||||
generalizations\n - Describe all visible information in the image, while do not add information not explicitly shown in
|
||||
the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>assistant\n"""
|
||||
self.default_sample_size = 128
|
||||
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
|
||||
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
|
||||
bool_mask = mask.bool()
|
||||
valid_lengths = bool_mask.sum(dim=1)
|
||||
selected = hidden_states[bool_mask]
|
||||
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
|
||||
|
||||
return split_result
|
||||
|
||||
def _get_qwen_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
template = self.prompt_template_encode
|
||||
drop_idx = self.prompt_template_encode_start_idx
|
||||
txt = [template.format(e) for e in prompt]
|
||||
txt_tokens = self.tokenizer(
|
||||
txt,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
).to(device)
|
||||
encoder_hidden_states = self.text_encoder(
|
||||
input_ids=txt_tokens.input_ids,
|
||||
attention_mask=txt_tokens.attention_mask,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
hidden_states = encoder_hidden_states.hidden_states[-1]
|
||||
split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
|
||||
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
||||
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
|
||||
max_seq_len = max([e.size(0) for e in split_hidden_states])
|
||||
prompt_embeds = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
|
||||
)
|
||||
encoder_attention_mask = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
|
||||
)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
return prompt_embeds, encoder_attention_mask
|
||||
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 1024,
|
||||
):
|
||||
r"""
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
|
||||
|
||||
prompt_embeds = prompt_embeds[:, :max_sequence_length]
|
||||
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
def get_image_caption(self, prompt_image, use_en_prompt=True, device=None):
|
||||
if use_en_prompt:
|
||||
prompt = self.image_caption_prompt_en
|
||||
else:
|
||||
prompt = self.image_caption_prompt_cn
|
||||
model_inputs = self.vl_processor(
|
||||
text=prompt,
|
||||
images=prompt_image,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
).to(device)
|
||||
generated_ids = self.text_encoder.generate(**model_inputs, max_new_tokens=512)
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids) :] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
|
||||
]
|
||||
output_text = self.vl_processor.batch_decode(
|
||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)[0]
|
||||
return output_text.strip()
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prompt_embeds_mask=None,
|
||||
negative_prompt_embeds_mask=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
max_sequence_length=None,
|
||||
):
|
||||
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
|
||||
logger.warning(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and prompt_embeds_mask is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
)
|
||||
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
|
||||
raise ValueError(
|
||||
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
if max_sequence_length is not None and max_sequence_length > 1024:
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
|
||||
|
||||
@staticmethod
|
||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width, layers):
|
||||
latents = latents.view(batch_size, layers, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 1, 3, 5, 2, 4, 6)
|
||||
latents = latents.reshape(batch_size, layers * (height // 2) * (width // 2), num_channels_latents * 4)
|
||||
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def _unpack_latents(latents, height, width, layers, vae_scale_factor):
|
||||
batch_size, num_patches, channels = latents.shape
|
||||
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
height = 2 * (int(height) // (vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (vae_scale_factor * 2))
|
||||
|
||||
latents = latents.view(batch_size, layers + 1, height // 2, width // 2, channels // 4, 2, 2)
|
||||
latents = latents.permute(0, 1, 4, 2, 5, 3, 6)
|
||||
|
||||
latents = latents.reshape(batch_size, layers + 1, channels // (2 * 2), height, width)
|
||||
latents = latents.permute(0, 2, 1, 3, 4) # (b, c, f, h, w)
|
||||
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._encode_vae_image
|
||||
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
|
||||
for i in range(image.shape[0])
|
||||
]
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
else:
|
||||
image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.latent_channels, 1, 1, 1)
|
||||
.to(image_latents.device, image_latents.dtype)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(self.vae.config.latents_std)
|
||||
.view(1, self.latent_channels, 1, 1, 1)
|
||||
.to(image_latents.device, image_latents.dtype)
|
||||
)
|
||||
image_latents = (image_latents - latents_mean) / latents_std
|
||||
|
||||
return image_latents
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
image,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
layers,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
layers + 1,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
) ### the generated first image is combined image
|
||||
|
||||
image_latents = None
|
||||
if image is not None:
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
if image.shape[1] != self.latent_channels:
|
||||
image_latents = self._encode_vae_image(image=image, generator=generator)
|
||||
else:
|
||||
image_latents = image
|
||||
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
||||
# expand init_latents for batch_size
|
||||
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
||||
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
||||
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
else:
|
||||
image_latents = torch.cat([image_latents], dim=0)
|
||||
|
||||
image_latent_height, image_latent_width = image_latents.shape[3:]
|
||||
image_latents = image_latents.permute(0, 2, 1, 3, 4) # (b, c, f, h, w) -> (b, f, c, h, w)
|
||||
image_latents = self._pack_latents(
|
||||
image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width, 1
|
||||
)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width, layers + 1)
|
||||
else:
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
|
||||
return latents, image_latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image: Optional[PipelineImageInput] = None,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Union[str, List[str]] = None,
|
||||
true_cfg_scale: float = 4.0,
|
||||
layers: Optional[int] = 4,
|
||||
num_inference_steps: int = 50,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
guidance_scale: Optional[float] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
resolution: int = 640,
|
||||
cfg_normalize: bool = False,
|
||||
use_en_prompt: bool = False,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
||||
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
|
||||
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
|
||||
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
|
||||
latents as `image`, but if passing latents directly it is not encoded again.
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
|
||||
not greater than `1`).
|
||||
true_cfg_scale (`float`, *optional*, defaults to 1.0):
|
||||
true_cfg_scale (`float`, *optional*, defaults to 1.0): Guidance scale as defined in [Classifier-Free
|
||||
Diffusion Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of
|
||||
equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is
|
||||
enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale
|
||||
encourages to generate images that are closely linked to the text `prompt`, usually at the expense of
|
||||
lower image quality.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to None):
|
||||
A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
|
||||
where the guidance scale is applied during inference through noise prediction rescaling, guidance
|
||||
distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
|
||||
scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images
|
||||
that are closely linked to the text `prompt`, usually at the expense of lower image quality. This
|
||||
parameter in the pipeline is there to support future guidance-distilled models when they come up. It is
|
||||
ignored when not using guidance distilled models. To enable traditional classifier-free guidance,
|
||||
please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should
|
||||
enable classifier-free guidance computations).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
||||
resolution (`int`, *optional*, defaults to 640):
|
||||
using different bucket in (640, 1024) to determin the condition and output resolution
|
||||
cfg_normalize (`bool`, *optional*, defaults to `False`)
|
||||
whether enable cfg normalization.
|
||||
use_en_prompt (`bool`, *optional*, defaults to `False`)
|
||||
automatic caption language if user does not provide caption
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
|
||||
[`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
image_size = image[0].size if isinstance(image, list) else image.size
|
||||
assert resolution in [640, 1024], f"resolution must be either 640 or 1024, but got {resolution}"
|
||||
calculated_width, calculated_height = calculate_dimensions(
|
||||
resolution * resolution, image_size[0] / image_size[1]
|
||||
)
|
||||
height = calculated_height
|
||||
width = calculated_width
|
||||
|
||||
multiple_of = self.vae_scale_factor * 2
|
||||
width = width // multiple_of * multiple_of
|
||||
height = height // multiple_of * multiple_of
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
height,
|
||||
width,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_embeds_mask=prompt_embeds_mask,
|
||||
negative_prompt_embeds_mask=negative_prompt_embeds_mask,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
device = self._execution_device
|
||||
# 2. Preprocess image
|
||||
if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
|
||||
image = self.image_processor.resize(image, calculated_height, calculated_width)
|
||||
prompt_image = image
|
||||
image = self.image_processor.preprocess(image, calculated_height, calculated_width)
|
||||
image = image.unsqueeze(2)
|
||||
image = image.to(dtype=self.text_encoder.dtype)
|
||||
|
||||
if prompt is None or prompt == "" or prompt == " ":
|
||||
prompt = self.get_image_caption(prompt_image, use_en_prompt=use_en_prompt, device=device)
|
||||
|
||||
# 3. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
has_neg_prompt = negative_prompt is not None or (
|
||||
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
|
||||
)
|
||||
|
||||
if true_cfg_scale > 1 and not has_neg_prompt:
|
||||
logger.warning(
|
||||
f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
|
||||
)
|
||||
elif true_cfg_scale <= 1 and has_neg_prompt:
|
||||
logger.warning(
|
||||
" negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
|
||||
)
|
||||
|
||||
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
|
||||
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_embeds_mask=prompt_embeds_mask,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
if do_true_cfg:
|
||||
negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
|
||||
prompt=negative_prompt,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
prompt_embeds_mask=negative_prompt_embeds_mask,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
latents, image_latents = self.prepare_latents(
|
||||
image,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
layers,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
img_shapes = [
|
||||
[
|
||||
*[
|
||||
(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)
|
||||
for _ in range(layers + 1)
|
||||
],
|
||||
(1, calculated_height // self.vae_scale_factor // 2, calculated_width // self.vae_scale_factor // 2),
|
||||
]
|
||||
] * batch_size
|
||||
|
||||
# 5. Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
|
||||
image_seq_len = latents.shape[1]
|
||||
base_seqlen = 256 * 256 / 16 / 16
|
||||
mu = (image_latents.shape[1] / base_seqlen) ** 0.5
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# handle guidance
|
||||
if self.transformer.config.guidance_embeds and guidance_scale is None:
|
||||
raise ValueError("guidance_scale is required for guidance-distilled model.")
|
||||
elif self.transformer.config.guidance_embeds:
|
||||
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
||||
guidance = guidance.expand(latents.shape[0])
|
||||
elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
|
||||
logger.warning(
|
||||
f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
|
||||
)
|
||||
guidance = None
|
||||
elif not self.transformer.config.guidance_embeds and guidance_scale is None:
|
||||
guidance = None
|
||||
|
||||
if self.attention_kwargs is None:
|
||||
self._attention_kwargs = {}
|
||||
|
||||
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
|
||||
negative_txt_seq_lens = (
|
||||
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
|
||||
)
|
||||
is_rgb = torch.tensor([0] * batch_size).to(device=device, dtype=torch.long)
|
||||
# 6. Denoising loop
|
||||
self.scheduler.set_begin_index(0)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
|
||||
latent_model_input = latents
|
||||
if image_latents is not None:
|
||||
latent_model_input = torch.cat([latents, image_latents], dim=1)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
with self.transformer.cache_context("cond"):
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep / 1000,
|
||||
guidance=guidance,
|
||||
encoder_hidden_states_mask=prompt_embeds_mask,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=txt_seq_lens,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
additional_t_cond=is_rgb,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred[:, : latents.size(1)]
|
||||
|
||||
if do_true_cfg:
|
||||
with self.transformer.cache_context("uncond"):
|
||||
neg_noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep / 1000,
|
||||
guidance=guidance,
|
||||
encoder_hidden_states_mask=negative_prompt_embeds_mask,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=negative_txt_seq_lens,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
additional_t_cond=is_rgb,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
|
||||
comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
|
||||
|
||||
if cfg_normalize:
|
||||
cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
|
||||
noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
|
||||
noise_pred = comb_pred * (cond_norm / noise_norm)
|
||||
else:
|
||||
noise_pred = comb_pred
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
else:
|
||||
latents = self._unpack_latents(latents, height, width, layers, self.vae_scale_factor)
|
||||
latents = latents.to(self.vae.dtype)
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = latents / latents_std + latents_mean
|
||||
|
||||
b, c, f, h, w = latents.shape
|
||||
|
||||
latents = latents[:, :, 1:] # remove the first frame as it is the orgin input
|
||||
|
||||
latents = latents.permute(0, 2, 1, 3, 4).view(-1, c, 1, h, w)
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0] # (b f) c 1 h w
|
||||
|
||||
image = image.squeeze(2)
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
images = []
|
||||
for bidx in range(b):
|
||||
images.append(image[bidx * f : (bidx + 1) * f])
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (images,)
|
||||
|
||||
return QwenImagePipelineOutput(images=images)
|
||||
@@ -23,7 +23,10 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["pipeline_output"] = ["ZImagePipelineOutput"]
|
||||
_import_structure["pipeline_z_image"] = ["ZImagePipeline"]
|
||||
_import_structure["pipeline_z_image_controlnet"] = ["ZImageControlNetPipeline"]
|
||||
_import_structure["pipeline_z_image_controlnet_inpaint"] = ["ZImageControlNetInpaintPipeline"]
|
||||
_import_structure["pipeline_z_image_img2img"] = ["ZImageImg2ImgPipeline"]
|
||||
_import_structure["pipeline_z_image_omni"] = ["ZImageOmniPipeline"]
|
||||
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
@@ -36,8 +39,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .pipeline_output import ZImagePipelineOutput
|
||||
from .pipeline_z_image import ZImagePipeline
|
||||
from .pipeline_z_image_controlnet import ZImageControlNetPipeline
|
||||
from .pipeline_z_image_controlnet_inpaint import ZImageControlNetInpaintPipeline
|
||||
from .pipeline_z_image_img2img import ZImageImg2ImgPipeline
|
||||
|
||||
from .pipeline_z_image_omni import ZImageOmniPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
725
src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py
Normal file
725
src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py
Normal file
@@ -0,0 +1,725 @@
|
||||
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, PreTrainedModel
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.controlnets import ZImageControlNetModel
|
||||
from ...models.transformers import ZImageTransformer2DModel
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from .pipeline_output import ZImagePipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import ZImageControlNetPipeline
|
||||
>>> from diffusers import ZImageControlNetModel
|
||||
>>> from diffusers.utils import load_image
|
||||
>>> from huggingface_hub import hf_hub_download
|
||||
|
||||
>>> controlnet = ZImageControlNetModel.from_single_file(
|
||||
... hf_hub_download(
|
||||
... "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union",
|
||||
... filename="Z-Image-Turbo-Fun-Controlnet-Union.safetensors",
|
||||
... ),
|
||||
... torch_dtype=torch.bfloat16,
|
||||
... )
|
||||
|
||||
>>> # 2.1
|
||||
>>> # controlnet = ZImageControlNetModel.from_single_file(
|
||||
>>> # hf_hub_download(
|
||||
>>> # "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
|
||||
>>> # filename="Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors",
|
||||
>>> # ),
|
||||
>>> # torch_dtype=torch.bfloat16,
|
||||
>>> # )
|
||||
|
||||
>>> # 2.0 - `config` is required
|
||||
>>> # controlnet = ZImageControlNetModel.from_single_file(
|
||||
>>> # hf_hub_download(
|
||||
>>> # "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
|
||||
>>> # filename="Z-Image-Turbo-Fun-Controlnet-Union-2.0.safetensors",
|
||||
>>> # ),
|
||||
>>> # torch_dtype=torch.bfloat16,
|
||||
>>> # config="hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
|
||||
>>> # )
|
||||
|
||||
>>> pipe = ZImageControlNetPipeline.from_pretrained(
|
||||
... "Tongyi-MAI/Z-Image-Turbo", controlnet=controlnet, torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch.
|
||||
>>> # (1) Use flash attention 2
|
||||
>>> # pipe.transformer.set_attention_backend("flash")
|
||||
>>> # (2) Use flash attention 3
|
||||
>>> # pipe.transformer.set_attention_backend("_flash_3")
|
||||
|
||||
>>> control_image = load_image(
|
||||
... "https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union/resolve/main/asset/pose.jpg?download=true"
|
||||
... )
|
||||
>>> prompt = "一位年轻女子站在阳光明媚的海岸线上,白裙在轻拂的海风中微微飘动。她拥有一头鲜艳的紫色长发,在风中轻盈舞动,发间系着一个精致的黑色蝴蝶结,与身后柔和的蔚蓝天空形成鲜明对比。她面容清秀,眉目精致,透着一股甜美的青春气息;神情柔和,略带羞涩,目光静静地凝望着远方的地平线,双手自然交叠于身前,仿佛沉浸在思绪之中。在她身后,是辽阔无垠、波光粼粼的大海,阳光洒在海面上,映出温暖的金色光晕。"
|
||||
>>> image = pipe(
|
||||
... prompt,
|
||||
... control_image=control_image,
|
||||
... controlnet_conditioning_scale=0.75,
|
||||
... height=1728,
|
||||
... width=992,
|
||||
... num_inference_steps=9,
|
||||
... guidance_scale=0.0,
|
||||
... generator=torch.Generator("cuda").manual_seed(43),
|
||||
... ).images[0]
|
||||
>>> image.save("zimage.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class ZImageControlNetPipeline(DiffusionPipeline, FromSingleFileMixin):
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: PreTrainedModel,
|
||||
tokenizer: AutoTokenizer,
|
||||
transformer: ZImageTransformer2DModel,
|
||||
controlnet: ZImageControlNetModel,
|
||||
):
|
||||
super().__init__()
|
||||
controlnet = ZImageControlNetModel.from_transformer(controlnet, transformer)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
transformer=transformer,
|
||||
controlnet=controlnet,
|
||||
)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
prompt_embeds=prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
if negative_prompt is None:
|
||||
negative_prompt = ["" for _ in prompt]
|
||||
else:
|
||||
negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
assert len(prompt) == len(negative_prompt)
|
||||
negative_prompt_embeds = self._encode_prompt(
|
||||
prompt=negative_prompt,
|
||||
device=device,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
else:
|
||||
negative_prompt_embeds = []
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
max_sequence_length: int = 512,
|
||||
) -> List[torch.FloatTensor]:
|
||||
device = device or self._execution_device
|
||||
|
||||
if prompt_embeds is not None:
|
||||
return prompt_embeds
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
|
||||
for i, prompt_item in enumerate(prompt):
|
||||
messages = [
|
||||
{"role": "user", "content": prompt_item},
|
||||
]
|
||||
prompt_item = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=True,
|
||||
)
|
||||
prompt[i] = prompt_item
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids.to(device)
|
||||
prompt_masks = text_inputs.attention_mask.to(device).bool()
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
input_ids=text_input_ids,
|
||||
attention_mask=prompt_masks,
|
||||
output_hidden_states=True,
|
||||
).hidden_states[-2]
|
||||
|
||||
embeddings_list = []
|
||||
|
||||
for i in range(len(prompt_embeds)):
|
||||
embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
|
||||
|
||||
return embeddings_list
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
latents = latents.to(device)
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
|
||||
def prepare_image(
|
||||
self,
|
||||
image,
|
||||
width,
|
||||
height,
|
||||
batch_size,
|
||||
num_images_per_prompt,
|
||||
device,
|
||||
dtype,
|
||||
do_classifier_free_guidance=False,
|
||||
guess_mode=False,
|
||||
):
|
||||
if isinstance(image, torch.Tensor):
|
||||
pass
|
||||
else:
|
||||
image = self.image_processor.preprocess(image, height=height, width=width)
|
||||
|
||||
image_batch_size = image.shape[0]
|
||||
|
||||
if image_batch_size == 1:
|
||||
repeat_by = batch_size
|
||||
else:
|
||||
# image batch size is the same as prompt batch size
|
||||
repeat_by = num_images_per_prompt
|
||||
|
||||
image = image.repeat_interleave(repeat_by, dim=0)
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
|
||||
if do_classifier_free_guidance and not guess_mode:
|
||||
image = torch.cat([image] * 2)
|
||||
|
||||
return image
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def joint_attention_kwargs(self):
|
||||
return self._joint_attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
guidance_scale: float = 5.0,
|
||||
control_image: PipelineImageInput = None,
|
||||
controlnet_conditioning_scale: Union[float, List[float]] = 0.75,
|
||||
cfg_normalization: bool = False,
|
||||
cfg_truncation: float = 1.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, *optional*, defaults to 1024):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 1024):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
cfg_normalization (`bool`, *optional*, defaults to False):
|
||||
Whether to apply configuration normalization.
|
||||
cfg_truncation (`float`, *optional*, defaults to 1.0):
|
||||
The truncation value for configuration.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`List[torch.FloatTensor]`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain
|
||||
tuple.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, *optional*, defaults to 512):
|
||||
Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if
|
||||
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
|
||||
generated images.
|
||||
"""
|
||||
height = height or 1024
|
||||
width = width or 1024
|
||||
|
||||
vae_scale = self.vae_scale_factor * 2
|
||||
if height % vae_scale != 0:
|
||||
raise ValueError(
|
||||
f"Height must be divisible by {vae_scale} (got {height}). "
|
||||
f"Please adjust the height to a multiple of {vae_scale}."
|
||||
)
|
||||
if width % vae_scale != 0:
|
||||
raise ValueError(
|
||||
f"Width must be divisible by {vae_scale} (got {width}). "
|
||||
f"Please adjust the width to a multiple of {vae_scale}."
|
||||
)
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._joint_attention_kwargs = joint_attention_kwargs
|
||||
self._interrupt = False
|
||||
self._cfg_normalization = cfg_normalization
|
||||
self._cfg_truncation = cfg_truncation
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = len(prompt_embeds)
|
||||
|
||||
# If prompt_embeds is provided and prompt is None, skip encoding
|
||||
if prompt_embeds is not None and prompt is None:
|
||||
if self.do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"When `prompt_embeds` is provided without `prompt`, "
|
||||
"`negative_prompt_embeds` must also be provided for classifier-free guidance."
|
||||
)
|
||||
else:
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.transformer.in_channels
|
||||
|
||||
control_image = self.prepare_image(
|
||||
image=control_image,
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=self.vae.dtype,
|
||||
)
|
||||
height, width = control_image.shape[-2:]
|
||||
control_image = retrieve_latents(self.vae.encode(control_image), generator=generator, sample_mode="argmax")
|
||||
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
control_image = control_image.unsqueeze(2)
|
||||
|
||||
if num_channels_latents != self.controlnet.config.control_in_dim:
|
||||
# For model version 2.0
|
||||
control_image = torch.cat(
|
||||
[
|
||||
control_image,
|
||||
torch.zeros(
|
||||
control_image.shape[0],
|
||||
self.controlnet.config.control_in_dim - num_channels_latents,
|
||||
*control_image.shape[2:],
|
||||
).to(device=control_image.device, dtype=control_image.dtype),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# Repeat prompt_embeds for num_images_per_prompt
|
||||
if num_images_per_prompt > 1:
|
||||
prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
|
||||
if self.do_classifier_free_guidance and negative_prompt_embeds:
|
||||
negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
|
||||
|
||||
actual_batch_size = batch_size * num_images_per_prompt
|
||||
image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.5),
|
||||
self.scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
self.scheduler.sigma_min = 0.0
|
||||
scheduler_kwargs = {"mu": mu}
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
sigmas=sigmas,
|
||||
**scheduler_kwargs,
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0])
|
||||
timestep = (1000 - timestep) / 1000
|
||||
# Normalized time for time-aware config (0 at start, 1 at end)
|
||||
t_norm = timestep[0].item()
|
||||
|
||||
# Handle cfg truncation
|
||||
current_guidance_scale = self.guidance_scale
|
||||
if (
|
||||
self.do_classifier_free_guidance
|
||||
and self._cfg_truncation is not None
|
||||
and float(self._cfg_truncation) <= 1
|
||||
):
|
||||
if t_norm > self._cfg_truncation:
|
||||
current_guidance_scale = 0.0
|
||||
|
||||
# Run CFG only if configured AND scale is non-zero
|
||||
apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
|
||||
|
||||
if apply_cfg:
|
||||
latents_typed = latents.to(self.transformer.dtype)
|
||||
latent_model_input = latents_typed.repeat(2, 1, 1, 1)
|
||||
prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
|
||||
timestep_model_input = timestep.repeat(2)
|
||||
else:
|
||||
latent_model_input = latents.to(self.transformer.dtype)
|
||||
prompt_embeds_model_input = prompt_embeds
|
||||
timestep_model_input = timestep
|
||||
|
||||
latent_model_input = latent_model_input.unsqueeze(2)
|
||||
latent_model_input_list = list(latent_model_input.unbind(dim=0))
|
||||
|
||||
controlnet_block_samples = self.controlnet(
|
||||
latent_model_input_list,
|
||||
timestep_model_input,
|
||||
prompt_embeds_model_input,
|
||||
control_image,
|
||||
conditioning_scale=controlnet_conditioning_scale,
|
||||
)
|
||||
|
||||
model_out_list = self.transformer(
|
||||
latent_model_input_list,
|
||||
timestep_model_input,
|
||||
prompt_embeds_model_input,
|
||||
controlnet_block_samples=controlnet_block_samples,
|
||||
)[0]
|
||||
|
||||
if apply_cfg:
|
||||
# Perform CFG
|
||||
pos_out = model_out_list[:actual_batch_size]
|
||||
neg_out = model_out_list[actual_batch_size:]
|
||||
|
||||
noise_pred = []
|
||||
for j in range(actual_batch_size):
|
||||
pos = pos_out[j].float()
|
||||
neg = neg_out[j].float()
|
||||
|
||||
pred = pos + current_guidance_scale * (pos - neg)
|
||||
|
||||
# Renormalization
|
||||
if self._cfg_normalization and float(self._cfg_normalization) > 0.0:
|
||||
ori_pos_norm = torch.linalg.vector_norm(pos)
|
||||
new_pos_norm = torch.linalg.vector_norm(pred)
|
||||
max_new_norm = ori_pos_norm * float(self._cfg_normalization)
|
||||
if new_pos_norm > max_new_norm:
|
||||
pred = pred * (max_new_norm / new_pos_norm)
|
||||
|
||||
noise_pred.append(pred)
|
||||
|
||||
noise_pred = torch.stack(noise_pred, dim=0)
|
||||
else:
|
||||
noise_pred = torch.stack([t.float() for t in model_out_list], dim=0)
|
||||
|
||||
noise_pred = noise_pred.squeeze(2)
|
||||
noise_pred = -noise_pred
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
|
||||
assert latents.dtype == torch.float32
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
|
||||
else:
|
||||
latents = latents.to(self.vae.dtype)
|
||||
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ZImagePipelineOutput(images=image)
|
||||
@@ -0,0 +1,747 @@
|
||||
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import AutoTokenizer, PreTrainedModel
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.controlnets import ZImageControlNetModel
|
||||
from ...models.transformers import ZImageTransformer2DModel
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from .pipeline_output import ZImagePipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import ZImageControlNetInpaintPipeline
|
||||
>>> from diffusers import ZImageControlNetModel
|
||||
>>> from diffusers.utils import load_image
|
||||
>>> from huggingface_hub import hf_hub_download
|
||||
|
||||
>>> controlnet = ZImageControlNetModel.from_single_file(
|
||||
... hf_hub_download(
|
||||
... "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
|
||||
... filename="Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors",
|
||||
... ),
|
||||
... torch_dtype=torch.bfloat16,
|
||||
... )
|
||||
|
||||
>>> # 2.0 - `config` is required
|
||||
>>> # controlnet = ZImageControlNetModel.from_single_file(
|
||||
>>> # hf_hub_download(
|
||||
>>> # "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
|
||||
>>> # filename="Z-Image-Turbo-Fun-Controlnet-Union-2.0.safetensors",
|
||||
>>> # ),
|
||||
>>> # torch_dtype=torch.bfloat16,
|
||||
>>> # config="hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
|
||||
>>> # )
|
||||
|
||||
>>> pipe = ZImageControlNetInpaintPipeline.from_pretrained(
|
||||
... "Tongyi-MAI/Z-Image-Turbo", controlnet=controlnet, torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch.
|
||||
>>> # (1) Use flash attention 2
|
||||
>>> # pipe.transformer.set_attention_backend("flash")
|
||||
>>> # (2) Use flash attention 3
|
||||
>>> # pipe.transformer.set_attention_backend("_flash_3")
|
||||
|
||||
>>> image = load_image(
|
||||
... "https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0/resolve/main/asset/inpaint.jpg?download=true"
|
||||
... )
|
||||
>>> mask_image = load_image(
|
||||
... "https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0/resolve/main/asset/mask.jpg?download=true"
|
||||
... )
|
||||
>>> control_image = load_image(
|
||||
... "https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0/resolve/main/asset/pose.jpg?download=true"
|
||||
... )
|
||||
>>> prompt = "一位年轻女子站在阳光明媚的海岸线上,画面为全身竖构图,身体微微侧向右侧,左手自然下垂,右臂弯曲扶在腰间,她的手指清晰可见,站姿放松而略带羞涩。她身穿轻盈的白色连衣裙,裙摆在海风中轻轻飘动,布料半透、质感柔软。女子拥有一头鲜艳的及腰紫色长发,被海风吹起,在身侧轻盈飞舞,发间系着一个精致的黑色蝴蝶结,与发色形成对比。她面容清秀,眉目精致,肤色白皙细腻,表情温柔略显羞涩,微微低头,眼神静静望向远处的海平线,流露出甜美的青春气息与若有所思的神情。背景是辽阔无垠的海洋与蔚蓝天空,阳光从侧前方洒下,海面波光粼粼,泛着温暖的金色光晕,天空清澈明亮,云朵稀薄,整体色调清新唯美。"
|
||||
>>> image = pipe(
|
||||
... prompt,
|
||||
... image=image,
|
||||
... mask_image=mask_image,
|
||||
... control_image=control_image,
|
||||
... controlnet_conditioning_scale=0.75,
|
||||
... height=1728,
|
||||
... width=992,
|
||||
... num_inference_steps=25,
|
||||
... guidance_scale=0.0,
|
||||
... generator=torch.Generator("cuda").manual_seed(43),
|
||||
... ).images[0]
|
||||
>>> image.save("zimage-inpaint.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class ZImageControlNetInpaintPipeline(DiffusionPipeline, FromSingleFileMixin):
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: PreTrainedModel,
|
||||
tokenizer: AutoTokenizer,
|
||||
transformer: ZImageTransformer2DModel,
|
||||
controlnet: ZImageControlNetModel,
|
||||
):
|
||||
super().__init__()
|
||||
if transformer.in_channels == controlnet.config.control_in_dim:
|
||||
raise ValueError(
|
||||
"ZImageControlNetInpaintPipeline is not compatible with `alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union`, use `alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0`."
|
||||
)
|
||||
controlnet = ZImageControlNetModel.from_transformer(controlnet, transformer)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
transformer=transformer,
|
||||
controlnet=controlnet,
|
||||
)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
self.mask_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
|
||||
)
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
prompt_embeds=prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
if negative_prompt is None:
|
||||
negative_prompt = ["" for _ in prompt]
|
||||
else:
|
||||
negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
assert len(prompt) == len(negative_prompt)
|
||||
negative_prompt_embeds = self._encode_prompt(
|
||||
prompt=negative_prompt,
|
||||
device=device,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
else:
|
||||
negative_prompt_embeds = []
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
max_sequence_length: int = 512,
|
||||
) -> List[torch.FloatTensor]:
|
||||
device = device or self._execution_device
|
||||
|
||||
if prompt_embeds is not None:
|
||||
return prompt_embeds
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
|
||||
for i, prompt_item in enumerate(prompt):
|
||||
messages = [
|
||||
{"role": "user", "content": prompt_item},
|
||||
]
|
||||
prompt_item = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=True,
|
||||
)
|
||||
prompt[i] = prompt_item
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids.to(device)
|
||||
prompt_masks = text_inputs.attention_mask.to(device).bool()
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
input_ids=text_input_ids,
|
||||
attention_mask=prompt_masks,
|
||||
output_hidden_states=True,
|
||||
).hidden_states[-2]
|
||||
|
||||
embeddings_list = []
|
||||
|
||||
for i in range(len(prompt_embeds)):
|
||||
embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
|
||||
|
||||
return embeddings_list
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
latents = latents.to(device)
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
|
||||
def prepare_image(
|
||||
self,
|
||||
image,
|
||||
width,
|
||||
height,
|
||||
batch_size,
|
||||
num_images_per_prompt,
|
||||
device,
|
||||
dtype,
|
||||
do_classifier_free_guidance=False,
|
||||
guess_mode=False,
|
||||
):
|
||||
if isinstance(image, torch.Tensor):
|
||||
pass
|
||||
else:
|
||||
image = self.image_processor.preprocess(image, height=height, width=width)
|
||||
|
||||
image_batch_size = image.shape[0]
|
||||
|
||||
if image_batch_size == 1:
|
||||
repeat_by = batch_size
|
||||
else:
|
||||
# image batch size is the same as prompt batch size
|
||||
repeat_by = num_images_per_prompt
|
||||
|
||||
image = image.repeat_interleave(repeat_by, dim=0)
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
|
||||
if do_classifier_free_guidance and not guess_mode:
|
||||
image = torch.cat([image] * 2)
|
||||
|
||||
return image
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def joint_attention_kwargs(self):
|
||||
return self._joint_attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
guidance_scale: float = 5.0,
|
||||
image: PipelineImageInput = None,
|
||||
mask_image: PipelineImageInput = None,
|
||||
control_image: PipelineImageInput = None,
|
||||
controlnet_conditioning_scale: Union[float, List[float]] = 0.75,
|
||||
cfg_normalization: bool = False,
|
||||
cfg_truncation: float = 1.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, *optional*, defaults to 1024):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 1024):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
cfg_normalization (`bool`, *optional*, defaults to False):
|
||||
Whether to apply configuration normalization.
|
||||
cfg_truncation (`float`, *optional*, defaults to 1.0):
|
||||
The truncation value for configuration.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`List[torch.FloatTensor]`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain
|
||||
tuple.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, *optional*, defaults to 512):
|
||||
Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if
|
||||
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
|
||||
generated images.
|
||||
"""
|
||||
height = height or 1024
|
||||
width = width or 1024
|
||||
|
||||
vae_scale = self.vae_scale_factor * 2
|
||||
if height % vae_scale != 0:
|
||||
raise ValueError(
|
||||
f"Height must be divisible by {vae_scale} (got {height}). "
|
||||
f"Please adjust the height to a multiple of {vae_scale}."
|
||||
)
|
||||
if width % vae_scale != 0:
|
||||
raise ValueError(
|
||||
f"Width must be divisible by {vae_scale} (got {width}). "
|
||||
f"Please adjust the width to a multiple of {vae_scale}."
|
||||
)
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._joint_attention_kwargs = joint_attention_kwargs
|
||||
self._interrupt = False
|
||||
self._cfg_normalization = cfg_normalization
|
||||
self._cfg_truncation = cfg_truncation
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = len(prompt_embeds)
|
||||
|
||||
# If prompt_embeds is provided and prompt is None, skip encoding
|
||||
if prompt_embeds is not None and prompt is None:
|
||||
if self.do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"When `prompt_embeds` is provided without `prompt`, "
|
||||
"`negative_prompt_embeds` must also be provided for classifier-free guidance."
|
||||
)
|
||||
else:
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.transformer.in_channels
|
||||
|
||||
control_image = self.prepare_image(
|
||||
image=control_image,
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=self.vae.dtype,
|
||||
)
|
||||
height, width = control_image.shape[-2:]
|
||||
control_image = retrieve_latents(self.vae.encode(control_image), generator=generator, sample_mode="argmax")
|
||||
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
control_image = control_image.unsqueeze(2)
|
||||
|
||||
mask_condition = self.mask_processor.preprocess(mask_image, height=height, width=width)
|
||||
mask_condition = torch.tile(mask_condition, [1, 3, 1, 1]).to(
|
||||
device=control_image.device, dtype=control_image.dtype
|
||||
)
|
||||
|
||||
init_image = self.prepare_image(
|
||||
image=image,
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=self.vae.dtype,
|
||||
)
|
||||
height, width = init_image.shape[-2:]
|
||||
init_image = init_image * (mask_condition < 0.5)
|
||||
init_image = retrieve_latents(self.vae.encode(init_image), generator=generator, sample_mode="argmax")
|
||||
init_image = (init_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
init_image = init_image.unsqueeze(2)
|
||||
|
||||
mask_condition = F.interpolate(1 - mask_condition[:, :1], size=init_image.size()[-2:], mode="nearest").to(
|
||||
device=control_image.device, dtype=control_image.dtype
|
||||
)
|
||||
mask_condition = mask_condition.unsqueeze(2)
|
||||
|
||||
control_image = torch.cat([control_image, mask_condition, init_image], dim=1)
|
||||
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# Repeat prompt_embeds for num_images_per_prompt
|
||||
if num_images_per_prompt > 1:
|
||||
prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
|
||||
if self.do_classifier_free_guidance and negative_prompt_embeds:
|
||||
negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
|
||||
|
||||
actual_batch_size = batch_size * num_images_per_prompt
|
||||
image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.5),
|
||||
self.scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
self.scheduler.sigma_min = 0.0
|
||||
scheduler_kwargs = {"mu": mu}
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
sigmas=sigmas,
|
||||
**scheduler_kwargs,
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0])
|
||||
timestep = (1000 - timestep) / 1000
|
||||
# Normalized time for time-aware config (0 at start, 1 at end)
|
||||
t_norm = timestep[0].item()
|
||||
|
||||
# Handle cfg truncation
|
||||
current_guidance_scale = self.guidance_scale
|
||||
if (
|
||||
self.do_classifier_free_guidance
|
||||
and self._cfg_truncation is not None
|
||||
and float(self._cfg_truncation) <= 1
|
||||
):
|
||||
if t_norm > self._cfg_truncation:
|
||||
current_guidance_scale = 0.0
|
||||
|
||||
# Run CFG only if configured AND scale is non-zero
|
||||
apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
|
||||
|
||||
if apply_cfg:
|
||||
latents_typed = latents.to(self.transformer.dtype)
|
||||
latent_model_input = latents_typed.repeat(2, 1, 1, 1)
|
||||
prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
|
||||
timestep_model_input = timestep.repeat(2)
|
||||
else:
|
||||
latent_model_input = latents.to(self.transformer.dtype)
|
||||
prompt_embeds_model_input = prompt_embeds
|
||||
timestep_model_input = timestep
|
||||
|
||||
latent_model_input = latent_model_input.unsqueeze(2)
|
||||
latent_model_input_list = list(latent_model_input.unbind(dim=0))
|
||||
|
||||
controlnet_block_samples = self.controlnet(
|
||||
latent_model_input_list,
|
||||
timestep_model_input,
|
||||
prompt_embeds_model_input,
|
||||
control_image,
|
||||
conditioning_scale=controlnet_conditioning_scale,
|
||||
)
|
||||
|
||||
model_out_list = self.transformer(
|
||||
latent_model_input_list,
|
||||
timestep_model_input,
|
||||
prompt_embeds_model_input,
|
||||
controlnet_block_samples=controlnet_block_samples,
|
||||
)[0]
|
||||
|
||||
if apply_cfg:
|
||||
# Perform CFG
|
||||
pos_out = model_out_list[:actual_batch_size]
|
||||
neg_out = model_out_list[actual_batch_size:]
|
||||
|
||||
noise_pred = []
|
||||
for j in range(actual_batch_size):
|
||||
pos = pos_out[j].float()
|
||||
neg = neg_out[j].float()
|
||||
|
||||
pred = pos + current_guidance_scale * (pos - neg)
|
||||
|
||||
# Renormalization
|
||||
if self._cfg_normalization and float(self._cfg_normalization) > 0.0:
|
||||
ori_pos_norm = torch.linalg.vector_norm(pos)
|
||||
new_pos_norm = torch.linalg.vector_norm(pred)
|
||||
max_new_norm = ori_pos_norm * float(self._cfg_normalization)
|
||||
if new_pos_norm > max_new_norm:
|
||||
pred = pred * (max_new_norm / new_pos_norm)
|
||||
|
||||
noise_pred.append(pred)
|
||||
|
||||
noise_pred = torch.stack(noise_pred, dim=0)
|
||||
else:
|
||||
noise_pred = torch.stack([t.float() for t in model_out_list], dim=0)
|
||||
|
||||
noise_pred = noise_pred.squeeze(2)
|
||||
noise_pred = -noise_pred
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
|
||||
assert latents.dtype == torch.float32
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
|
||||
else:
|
||||
latents = latents.to(self.vae.dtype)
|
||||
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ZImagePipelineOutput(images=image)
|
||||
742
src/diffusers/pipelines/z_image/pipeline_z_image_omni.py
Normal file
742
src/diffusers/pipelines/z_image/pipeline_z_image_omni.py
Normal file
@@ -0,0 +1,742 @@
|
||||
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import AutoTokenizer, PreTrainedModel, Siglip2ImageProcessorFast, Siglip2VisionModel
|
||||
|
||||
from ...loaders import FromSingleFileMixin, ZImageLoraLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.transformers import ZImageTransformer2DModel
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..flux2.image_processor import Flux2ImageProcessor
|
||||
from .pipeline_output import ZImagePipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import ZImageOmniPipeline
|
||||
|
||||
>>> pipe = ZImageOmniPipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch.
|
||||
>>> # (1) Use flash attention 2
|
||||
>>> # pipe.transformer.set_attention_backend("flash")
|
||||
>>> # (2) Use flash attention 3
|
||||
>>> # pipe.transformer.set_attention_backend("_flash_3")
|
||||
|
||||
>>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。"
|
||||
>>> image = pipe(
|
||||
... prompt,
|
||||
... height=1024,
|
||||
... width=1024,
|
||||
... num_inference_steps=9,
|
||||
... guidance_scale=0.0,
|
||||
... generator=torch.Generator("cuda").manual_seed(42),
|
||||
... ).images[0]
|
||||
>>> image.save("zimage.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class ZImageOmniPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin):
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: PreTrainedModel,
|
||||
tokenizer: AutoTokenizer,
|
||||
transformer: ZImageTransformer2DModel,
|
||||
siglip: Siglip2VisionModel,
|
||||
siglip_processor: Siglip2ImageProcessorFast,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
transformer=transformer,
|
||||
siglip=siglip,
|
||||
siglip_processor=siglip_processor,
|
||||
)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
# self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
max_sequence_length: int = 512,
|
||||
num_condition_images: int = 0,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
prompt_embeds=prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
num_condition_images=num_condition_images,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
if negative_prompt is None:
|
||||
negative_prompt = ["" for _ in prompt]
|
||||
else:
|
||||
negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
assert len(prompt) == len(negative_prompt)
|
||||
negative_prompt_embeds = self._encode_prompt(
|
||||
prompt=negative_prompt,
|
||||
device=device,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
num_condition_images=num_condition_images,
|
||||
)
|
||||
else:
|
||||
negative_prompt_embeds = []
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
max_sequence_length: int = 512,
|
||||
num_condition_images: int = 0,
|
||||
) -> List[torch.FloatTensor]:
|
||||
device = device or self._execution_device
|
||||
|
||||
if prompt_embeds is not None:
|
||||
return prompt_embeds
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
|
||||
for i, prompt_item in enumerate(prompt):
|
||||
if num_condition_images == 0:
|
||||
prompt[i] = ["<|im_start|>user\n" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n"]
|
||||
elif num_condition_images > 0:
|
||||
prompt_list = ["<|im_start|>user\n<|vision_start|>"]
|
||||
prompt_list += ["<|vision_end|><|vision_start|>"] * (num_condition_images - 1)
|
||||
prompt_list += ["<|vision_end|>" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"]
|
||||
prompt_list += ["<|vision_end|><|im_end|>"]
|
||||
prompt[i] = prompt_list
|
||||
|
||||
flattened_prompt = []
|
||||
prompt_list_lengths = []
|
||||
|
||||
for i in range(len(prompt)):
|
||||
prompt_list_lengths.append(len(prompt[i]))
|
||||
flattened_prompt.extend(prompt[i])
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
flattened_prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids.to(device)
|
||||
prompt_masks = text_inputs.attention_mask.to(device).bool()
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
input_ids=text_input_ids,
|
||||
attention_mask=prompt_masks,
|
||||
output_hidden_states=True,
|
||||
).hidden_states[-2]
|
||||
|
||||
embeddings_list = []
|
||||
start_idx = 0
|
||||
for i in range(len(prompt_list_lengths)):
|
||||
batch_embeddings = []
|
||||
end_idx = start_idx + prompt_list_lengths[i]
|
||||
for j in range(start_idx, end_idx):
|
||||
batch_embeddings.append(prompt_embeds[j][prompt_masks[j]])
|
||||
embeddings_list.append(batch_embeddings)
|
||||
start_idx = end_idx
|
||||
|
||||
return embeddings_list
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
latents = latents.to(device)
|
||||
return latents
|
||||
|
||||
def prepare_image_latents(
|
||||
self,
|
||||
images: List[torch.Tensor],
|
||||
batch_size,
|
||||
device,
|
||||
dtype,
|
||||
):
|
||||
image_latents = []
|
||||
for image in images:
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_latent = (
|
||||
self.vae.encode(image.bfloat16()).latent_dist.mode()[0] - self.vae.config.shift_factor
|
||||
) * self.vae.config.scaling_factor
|
||||
image_latent = image_latent.unsqueeze(1).to(dtype)
|
||||
image_latents.append(image_latent) # (16, 128, 128)
|
||||
|
||||
# image_latents = [image_latents] * batch_size
|
||||
image_latents = [image_latents.copy() for _ in range(batch_size)]
|
||||
|
||||
return image_latents
|
||||
|
||||
def prepare_siglip_embeds(
|
||||
self,
|
||||
images: List[torch.Tensor],
|
||||
batch_size,
|
||||
device,
|
||||
dtype,
|
||||
):
|
||||
siglip_embeds = []
|
||||
for image in images:
|
||||
siglip_inputs = self.siglip_processor(images=[image], return_tensors="pt").to(device)
|
||||
shape = siglip_inputs.spatial_shapes[0]
|
||||
hidden_state = self.siglip(**siglip_inputs).last_hidden_state
|
||||
B, N, C = hidden_state.shape
|
||||
hidden_state = hidden_state[:, : shape[0] * shape[1]]
|
||||
hidden_state = hidden_state.view(shape[0], shape[1], C)
|
||||
siglip_embeds.append(hidden_state.to(dtype))
|
||||
|
||||
# siglip_embeds = [siglip_embeds] * batch_size
|
||||
siglip_embeds = [siglip_embeds.copy() for _ in range(batch_size)]
|
||||
|
||||
return siglip_embeds
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def joint_attention_kwargs(self):
|
||||
return self._joint_attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
guidance_scale: float = 5.0,
|
||||
cfg_normalization: bool = False,
|
||||
cfg_truncation: float = 1.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
||||
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
|
||||
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
|
||||
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
|
||||
latents as `image`, but if passing latents directly it is not encoded again.
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, *optional*, defaults to 1024):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 1024):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
cfg_normalization (`bool`, *optional*, defaults to False):
|
||||
Whether to apply configuration normalization.
|
||||
cfg_truncation (`float`, *optional*, defaults to 1.0):
|
||||
The truncation value for configuration.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`List[torch.FloatTensor]`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain
|
||||
tuple.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, *optional*, defaults to 512):
|
||||
Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if
|
||||
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
|
||||
generated images.
|
||||
"""
|
||||
|
||||
if image is not None and not isinstance(image, list):
|
||||
image = [image]
|
||||
num_condition_images = len(image) if image is not None else 0
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._joint_attention_kwargs = joint_attention_kwargs
|
||||
self._interrupt = False
|
||||
self._cfg_normalization = cfg_normalization
|
||||
self._cfg_truncation = cfg_truncation
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = len(prompt_embeds)
|
||||
|
||||
# If prompt_embeds is provided and prompt is None, skip encoding
|
||||
if prompt_embeds is not None and prompt is None:
|
||||
if self.do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"When `prompt_embeds` is provided without `prompt`, "
|
||||
"`negative_prompt_embeds` must also be provided for classifier-free guidance."
|
||||
)
|
||||
else:
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
num_condition_images=num_condition_images,
|
||||
)
|
||||
|
||||
# 3. Process condition images. Copied from diffusers.pipelines.flux2.pipeline_flux2
|
||||
condition_images = []
|
||||
resized_images = []
|
||||
if image is not None:
|
||||
for img in image:
|
||||
self.image_processor.check_image_input(img)
|
||||
for img in image:
|
||||
image_width, image_height = img.size
|
||||
if image_width * image_height > 1024 * 1024:
|
||||
if height is not None and width is not None:
|
||||
img = self.image_processor._resize_to_target_area(img, height * width)
|
||||
else:
|
||||
img = self.image_processor._resize_to_target_area(img, 1024 * 1024)
|
||||
image_width, image_height = img.size
|
||||
resized_images.append(img)
|
||||
|
||||
multiple_of = self.vae_scale_factor * 2
|
||||
image_width = (image_width // multiple_of) * multiple_of
|
||||
image_height = (image_height // multiple_of) * multiple_of
|
||||
img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop")
|
||||
condition_images.append(img)
|
||||
|
||||
if len(condition_images) > 0:
|
||||
height = height or image_height
|
||||
width = width or image_width
|
||||
|
||||
else:
|
||||
height = height or 1024
|
||||
width = width or 1024
|
||||
|
||||
vae_scale = self.vae_scale_factor * 2
|
||||
if height % vae_scale != 0:
|
||||
raise ValueError(
|
||||
f"Height must be divisible by {vae_scale} (got {height}). "
|
||||
f"Please adjust the height to a multiple of {vae_scale}."
|
||||
)
|
||||
if width % vae_scale != 0:
|
||||
raise ValueError(
|
||||
f"Width must be divisible by {vae_scale} (got {width}). "
|
||||
f"Please adjust the width to a multiple of {vae_scale}."
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.transformer.in_channels
|
||||
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
condition_latents = self.prepare_image_latents(
|
||||
images=condition_images,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
condition_latents = [[lat.to(self.transformer.dtype) for lat in lats] for lats in condition_latents]
|
||||
if self.do_classifier_free_guidance:
|
||||
negative_condition_latents = [[lat.clone() for lat in batch] for batch in condition_latents]
|
||||
|
||||
condition_siglip_embeds = self.prepare_siglip_embeds(
|
||||
images=resized_images,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
condition_siglip_embeds = [[se.to(self.transformer.dtype) for se in sels] for sels in condition_siglip_embeds]
|
||||
if self.do_classifier_free_guidance:
|
||||
negative_condition_siglip_embeds = [[se.clone() for se in batch] for batch in condition_siglip_embeds]
|
||||
|
||||
# Repeat prompt_embeds for num_images_per_prompt
|
||||
if num_images_per_prompt > 1:
|
||||
prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
|
||||
if self.do_classifier_free_guidance and negative_prompt_embeds:
|
||||
negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
|
||||
|
||||
condition_siglip_embeds = [None if sels == [] else sels + [None] for sels in condition_siglip_embeds]
|
||||
negative_condition_siglip_embeds = [
|
||||
None if sels == [] else sels + [None] for sels in negative_condition_siglip_embeds
|
||||
]
|
||||
|
||||
actual_batch_size = batch_size * num_images_per_prompt
|
||||
image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.5),
|
||||
self.scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
self.scheduler.sigma_min = 0.0
|
||||
scheduler_kwargs = {"mu": mu}
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
sigmas=sigmas,
|
||||
**scheduler_kwargs,
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0])
|
||||
timestep = (1000 - timestep) / 1000
|
||||
# Normalized time for time-aware config (0 at start, 1 at end)
|
||||
t_norm = timestep[0].item()
|
||||
|
||||
# Handle cfg truncation
|
||||
current_guidance_scale = self.guidance_scale
|
||||
if (
|
||||
self.do_classifier_free_guidance
|
||||
and self._cfg_truncation is not None
|
||||
and float(self._cfg_truncation) <= 1
|
||||
):
|
||||
if t_norm > self._cfg_truncation:
|
||||
current_guidance_scale = 0.0
|
||||
|
||||
# Run CFG only if configured AND scale is non-zero
|
||||
apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
|
||||
|
||||
if apply_cfg:
|
||||
latents_typed = latents.to(self.transformer.dtype)
|
||||
latent_model_input = latents_typed.repeat(2, 1, 1, 1)
|
||||
prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
|
||||
condition_latents_model_input = condition_latents + negative_condition_latents
|
||||
condition_siglip_embeds_model_input = condition_siglip_embeds + negative_condition_siglip_embeds
|
||||
timestep_model_input = timestep.repeat(2)
|
||||
else:
|
||||
latent_model_input = latents.to(self.transformer.dtype)
|
||||
prompt_embeds_model_input = prompt_embeds
|
||||
condition_latents_model_input = condition_latents
|
||||
condition_siglip_embeds_model_input = condition_siglip_embeds
|
||||
timestep_model_input = timestep
|
||||
|
||||
latent_model_input = latent_model_input.unsqueeze(2)
|
||||
latent_model_input_list = list(latent_model_input.unbind(dim=0))
|
||||
|
||||
# Combine condition latents with target latent
|
||||
current_batch_size = len(latent_model_input_list)
|
||||
x_combined = [
|
||||
condition_latents_model_input[i] + [latent_model_input_list[i]] for i in range(current_batch_size)
|
||||
]
|
||||
# Create noise mask: 0 for condition images (clean), 1 for target image (noisy)
|
||||
image_noise_mask = [
|
||||
[0] * len(condition_latents_model_input[i]) + [1] for i in range(current_batch_size)
|
||||
]
|
||||
|
||||
model_out_list = self.transformer(
|
||||
x=x_combined,
|
||||
t=timestep_model_input,
|
||||
cap_feats=prompt_embeds_model_input,
|
||||
siglip_feats=condition_siglip_embeds_model_input,
|
||||
image_noise_mask=image_noise_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if apply_cfg:
|
||||
# Perform CFG
|
||||
pos_out = model_out_list[:actual_batch_size]
|
||||
neg_out = model_out_list[actual_batch_size:]
|
||||
|
||||
noise_pred = []
|
||||
for j in range(actual_batch_size):
|
||||
pos = pos_out[j].float()
|
||||
neg = neg_out[j].float()
|
||||
|
||||
pred = pos + current_guidance_scale * (pos - neg)
|
||||
|
||||
# Renormalization
|
||||
if self._cfg_normalization and float(self._cfg_normalization) > 0.0:
|
||||
ori_pos_norm = torch.linalg.vector_norm(pos)
|
||||
new_pos_norm = torch.linalg.vector_norm(pred)
|
||||
max_new_norm = ori_pos_norm * float(self._cfg_normalization)
|
||||
if new_pos_norm > max_new_norm:
|
||||
pred = pred * (max_new_norm / new_pos_norm)
|
||||
|
||||
noise_pred.append(pred)
|
||||
|
||||
noise_pred = torch.stack(noise_pred, dim=0)
|
||||
else:
|
||||
noise_pred = torch.stack([t.float() for t in model_out_list], dim=0)
|
||||
|
||||
noise_pred = noise_pred.squeeze(2)
|
||||
noise_pred = -noise_pred
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
|
||||
assert latents.dtype == torch.float32
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
|
||||
else:
|
||||
latents = latents.to(self.vae.dtype)
|
||||
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ZImagePipelineOutput(images=image)
|
||||
@@ -217,6 +217,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
rescale_betas_zero_snr: bool = False,
|
||||
use_dynamic_shifting: bool = False,
|
||||
time_shift_type: Literal["exponential"] = "exponential",
|
||||
sigma_min: Optional[float] = None,
|
||||
sigma_max: Optional[float] = None,
|
||||
) -> None:
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
@@ -350,7 +352,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
log_sigmas = np.log(sigmas)
|
||||
sigmas = np.flip(sigmas).copy()
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
||||
if self.config.use_flow_sigmas:
|
||||
sigmas = sigmas / (sigmas + 1)
|
||||
timesteps = (sigmas * self.config.num_train_timesteps).copy()
|
||||
else:
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
||||
|
||||
if self.config.final_sigmas_type == "sigma_min":
|
||||
sigma_last = sigmas[-1]
|
||||
elif self.config.final_sigmas_type == "zero":
|
||||
|
||||
@@ -1777,6 +1777,21 @@ class WanVACETransformer3DModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class ZImageControlNetModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class ZImageTransformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -767,6 +767,21 @@ class ConsisIDPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Cosmos2_5_PredictBasePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Cosmos2TextToImagePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -2297,6 +2312,21 @@ class QwenImageInpaintPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class QwenImageLayeredPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class QwenImagePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -3842,6 +3872,36 @@ class WuerstchenPriorPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class ZImageControlNetInpaintPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class ZImageControlNetPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class ZImageImg2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -3857,6 +3917,21 @@ class ZImageImg2ImgPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class ZImageOmniPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class ZImagePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ class DummyCosmosSafetyChecker(ModelMixin, ConfigMixin):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
self._dtype = torch.float32
|
||||
self.register_buffer("_device_tracker", torch.zeros(1, dtype=torch.float32), persistent=False)
|
||||
|
||||
def check_text_safety(self, prompt: str) -> bool:
|
||||
return True
|
||||
@@ -35,13 +35,14 @@ class DummyCosmosSafetyChecker(ModelMixin, ConfigMixin):
|
||||
def check_video_safety(self, frames: np.ndarray) -> np.ndarray:
|
||||
return frames
|
||||
|
||||
def to(self, device: Union[str, torch.device] = None, dtype: torch.dtype = None) -> None:
|
||||
self._dtype = dtype
|
||||
def to(self, device: Union[str, torch.device] = None, dtype: torch.dtype = None):
|
||||
module = super().to(device=device, dtype=dtype)
|
||||
return module
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return None
|
||||
return self._device_tracker.device
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
return self._dtype
|
||||
return self._device_tracker.dtype
|
||||
|
||||
337
tests/pipelines/cosmos/test_cosmos2_5_predict.py
Normal file
337
tests/pipelines/cosmos/test_cosmos2_5_predict.py
Normal file
@@ -0,0 +1,337 @@
|
||||
# Copyright 2025 The HuggingFace Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLWan,
|
||||
Cosmos2_5_PredictBasePipeline,
|
||||
CosmosTransformer3DModel,
|
||||
UniPCMultistepScheduler,
|
||||
)
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin, to_np
|
||||
from .cosmos_guardrail import DummyCosmosSafetyChecker
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class Cosmos2_5_PredictBaseWrapper(Cosmos2_5_PredictBasePipeline):
|
||||
@staticmethod
|
||||
def from_pretrained(*args, **kwargs):
|
||||
if "safety_checker" not in kwargs or kwargs["safety_checker"] is None:
|
||||
safety_checker = DummyCosmosSafetyChecker()
|
||||
device_map = kwargs.get("device_map", "cpu")
|
||||
torch_dtype = kwargs.get("torch_dtype")
|
||||
if device_map is not None or torch_dtype is not None:
|
||||
safety_checker = safety_checker.to(device_map, dtype=torch_dtype)
|
||||
kwargs["safety_checker"] = safety_checker
|
||||
return Cosmos2_5_PredictBasePipeline.from_pretrained(*args, **kwargs)
|
||||
|
||||
|
||||
class Cosmos2_5_PredictPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = Cosmos2_5_PredictBaseWrapper
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"return_dict",
|
||||
"callback_on_step_end",
|
||||
"callback_on_step_end_tensor_inputs",
|
||||
]
|
||||
)
|
||||
supports_dduf = False
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = CosmosTransformer3DModel(
|
||||
in_channels=16 + 1,
|
||||
out_channels=16,
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=16,
|
||||
num_layers=2,
|
||||
mlp_ratio=2,
|
||||
text_embed_dim=32,
|
||||
adaln_lora_dim=4,
|
||||
max_size=(4, 32, 32),
|
||||
patch_size=(1, 2, 2),
|
||||
rope_scale=(2.0, 1.0, 1.0),
|
||||
concat_padding_mask=True,
|
||||
extra_pos_embed_type="learnable",
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKLWan(
|
||||
base_dim=3,
|
||||
z_dim=16,
|
||||
dim_mult=[1, 1, 1, 1],
|
||||
num_res_blocks=1,
|
||||
temperal_downsample=[False, True, True],
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = UniPCMultistepScheduler()
|
||||
|
||||
torch.manual_seed(0)
|
||||
config = Qwen2_5_VLConfig(
|
||||
text_config={
|
||||
"hidden_size": 16,
|
||||
"intermediate_size": 16,
|
||||
"num_hidden_layers": 2,
|
||||
"num_attention_heads": 2,
|
||||
"num_key_value_heads": 2,
|
||||
"rope_scaling": {
|
||||
"mrope_section": [1, 1, 2],
|
||||
"rope_type": "default",
|
||||
"type": "default",
|
||||
},
|
||||
"rope_theta": 1000000.0,
|
||||
},
|
||||
vision_config={
|
||||
"depth": 2,
|
||||
"hidden_size": 16,
|
||||
"intermediate_size": 16,
|
||||
"num_heads": 2,
|
||||
"out_hidden_size": 16,
|
||||
},
|
||||
hidden_size=16,
|
||||
vocab_size=152064,
|
||||
vision_end_token_id=151653,
|
||||
vision_start_token_id=151652,
|
||||
vision_token_id=151654,
|
||||
)
|
||||
text_encoder = Qwen2_5_VLForConditionalGeneration(config)
|
||||
tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"safety_checker": DummyCosmosSafetyChecker(),
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
inputs = {
|
||||
"prompt": "dance monkey",
|
||||
"negative_prompt": "bad quality",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 3.0,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"num_frames": 3,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "pt",
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_components_function(self):
|
||||
init_components = self.get_dummy_components()
|
||||
init_components = {k: v for k, v in init_components.items() if not isinstance(v, (str, int, float))}
|
||||
pipe = self.pipeline_class(**init_components)
|
||||
self.assertTrue(hasattr(pipe, "components"))
|
||||
self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
video = pipe(**inputs).frames
|
||||
generated_video = video[0]
|
||||
self.assertEqual(generated_video.shape, (3, 3, 32, 32))
|
||||
self.assertTrue(torch.isfinite(generated_video).all())
|
||||
|
||||
def test_callback_inputs(self):
|
||||
sig = inspect.signature(self.pipeline_class.__call__)
|
||||
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
|
||||
has_callback_step_end = "callback_on_step_end" in sig.parameters
|
||||
|
||||
if not (has_callback_tensor_inputs and has_callback_step_end):
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
self.assertTrue(
|
||||
hasattr(pipe, "_callback_tensor_inputs"),
|
||||
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
|
||||
)
|
||||
|
||||
def callback_inputs_subset(pipe, i, t, callback_kwargs):
|
||||
for tensor_name in callback_kwargs.keys():
|
||||
assert tensor_name in pipe._callback_tensor_inputs
|
||||
return callback_kwargs
|
||||
|
||||
def callback_inputs_all(pipe, i, t, callback_kwargs):
|
||||
for tensor_name in pipe._callback_tensor_inputs:
|
||||
assert tensor_name in callback_kwargs
|
||||
for tensor_name in callback_kwargs.keys():
|
||||
assert tensor_name in pipe._callback_tensor_inputs
|
||||
return callback_kwargs
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
inputs["callback_on_step_end"] = callback_inputs_subset
|
||||
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
|
||||
_ = pipe(**inputs)[0]
|
||||
|
||||
inputs["callback_on_step_end"] = callback_inputs_all
|
||||
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
|
||||
_ = pipe(**inputs)[0]
|
||||
|
||||
def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
|
||||
is_last = i == (pipe.num_timesteps - 1)
|
||||
if is_last:
|
||||
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
|
||||
return callback_kwargs
|
||||
|
||||
inputs["callback_on_step_end"] = callback_inputs_change_tensor
|
||||
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
|
||||
output = pipe(**inputs)[0]
|
||||
assert output.abs().sum() < 1e10
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=1e-2)
|
||||
|
||||
def test_attention_slicing_forward_pass(
|
||||
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
|
||||
):
|
||||
if not getattr(self, "test_attention_slicing", True):
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_without_slicing = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=1)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing1 = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=2)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing2 = pipe(**inputs)[0]
|
||||
|
||||
if test_max_difference:
|
||||
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
|
||||
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
|
||||
self.assertLess(
|
||||
max(max_diff1, max_diff2),
|
||||
expected_max_diff,
|
||||
"Attention slicing should not affect the inference results",
|
||||
)
|
||||
|
||||
def test_save_load_optional_components(self, expected_max_difference=1e-4):
|
||||
self.pipeline_class._optional_components.remove("safety_checker")
|
||||
super().test_save_load_optional_components(expected_max_difference=expected_max_difference)
|
||||
self.pipeline_class._optional_components.append("safety_checker")
|
||||
|
||||
def test_serialization_with_variants(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
model_components = [
|
||||
component_name
|
||||
for component_name, component in pipe.components.items()
|
||||
if isinstance(component, torch.nn.Module)
|
||||
]
|
||||
model_components.remove("safety_checker")
|
||||
variant = "fp16"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False)
|
||||
|
||||
with open(f"{tmpdir}/model_index.json", "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
for subfolder in os.listdir(tmpdir):
|
||||
if not os.path.isfile(subfolder) and subfolder in model_components:
|
||||
folder_path = os.path.join(tmpdir, subfolder)
|
||||
is_folder = os.path.isdir(folder_path) and subfolder in config
|
||||
assert is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path))
|
||||
|
||||
def test_torch_dtype_dict(self):
|
||||
components = self.get_dummy_components()
|
||||
if not components:
|
||||
self.skipTest("No dummy components defined.")
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
|
||||
specified_key = next(iter(components.keys()))
|
||||
|
||||
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
|
||||
pipe.save_pretrained(tmpdirname, safe_serialization=False)
|
||||
torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16}
|
||||
loaded_pipe = self.pipeline_class.from_pretrained(
|
||||
tmpdirname, safety_checker=DummyCosmosSafetyChecker(), torch_dtype=torch_dtype_dict
|
||||
)
|
||||
|
||||
for name, component in loaded_pipe.components.items():
|
||||
if name == "safety_checker":
|
||||
continue
|
||||
if isinstance(component, torch.nn.Module) and hasattr(component, "dtype"):
|
||||
expected_dtype = torch_dtype_dict.get(name, torch_dtype_dict.get("default", torch.float32))
|
||||
self.assertEqual(
|
||||
component.dtype,
|
||||
expected_dtype,
|
||||
f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
|
||||
)
|
||||
|
||||
@unittest.skip(
|
||||
"The pipeline should not be runnable without a safety checker. The test creates a pipeline without passing in "
|
||||
"a safety checker, which makes the pipeline default to the actual Cosmos Guardrail. The Cosmos Guardrail is "
|
||||
"too large and slow to run on CI."
|
||||
)
|
||||
def test_encode_prompt_works_in_isolation(self):
|
||||
pass
|
||||
@@ -35,6 +35,7 @@ from diffusers.models.attention_processor import Attention
|
||||
from diffusers.quantizers import PipelineQuantizationConfig
|
||||
|
||||
from ...testing_utils import (
|
||||
Expectations,
|
||||
backend_empty_cache,
|
||||
backend_synchronize,
|
||||
enable_full_determinism,
|
||||
@@ -497,8 +498,23 @@ class TorchAoTest(unittest.TestCase):
|
||||
|
||||
def test_model_memory_usage(self):
|
||||
model_id = "hf-internal-testing/tiny-flux-pipe"
|
||||
expected_memory_saving_ratio = 2.0
|
||||
|
||||
expected_memory_saving_ratios = Expectations(
|
||||
{
|
||||
# XPU: For this tiny model, per-tensor overheads (alignment, fragmentation, metadata) become visible.
|
||||
# While XPU doesn't have the large fixed cuBLAS workspace of A100, these small overheads prevent reaching the ideal 2.0 ratio.
|
||||
# Observed ~1.27x (158k vs 124k) for model size.
|
||||
# The runtime memory overhead is ~88k for both bf16 and int8wo. Adding this to model size: (158k+88k)/(124k+88k) ≈ 1.15.
|
||||
("xpu", None): 1.15,
|
||||
# On Ampere, the cuBLAS kernels used for matrix multiplication often allocate a fixed-size workspace.
|
||||
# Since the tiny-flux model weights are likely smaller than or comparable to this workspace, the total memory is dominated by the workspace.
|
||||
("cuda", 8): 1.02,
|
||||
# On Hopper, TorchAO utilizes newer, highly optimized kernels (via Triton or CUTLASS 3.x) that are designed to be workspace-free or use negligible extra memory.
|
||||
# Additionally, Triton kernels often handle unaligned memory better, avoiding the padding overhead seen on other backends for tiny tensors.
|
||||
# This allows it to achieve the near-ideal 2.0x compression ratio.
|
||||
("cuda", 9): 2.0,
|
||||
}
|
||||
)
|
||||
expected_memory_saving_ratio = expected_memory_saving_ratios.get_expectation()
|
||||
inputs = self.get_dummy_tensor_inputs(device=torch_device)
|
||||
|
||||
transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"]
|
||||
|
||||
@@ -399,3 +399,32 @@ class UniPCMultistepScheduler1DTest(UniPCMultistepSchedulerTest):
|
||||
|
||||
def test_exponential_sigmas(self):
|
||||
self.check_over_configs(use_exponential_sigmas=True)
|
||||
|
||||
def test_flow_and_karras_sigmas(self):
|
||||
self.check_over_configs(use_flow_sigmas=True, use_karras_sigmas=True)
|
||||
|
||||
def test_flow_and_karras_sigmas_values(self):
|
||||
num_train_timesteps = 1000
|
||||
num_inference_steps = 5
|
||||
scheduler = UniPCMultistepScheduler(
|
||||
sigma_min=0.01,
|
||||
sigma_max=200.0,
|
||||
use_flow_sigmas=True,
|
||||
use_karras_sigmas=True,
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
)
|
||||
scheduler.set_timesteps(num_inference_steps=num_inference_steps)
|
||||
|
||||
expected_sigmas = [
|
||||
0.9950248599052429,
|
||||
0.9787454605102539,
|
||||
0.8774884343147278,
|
||||
0.3604971766471863,
|
||||
0.009900986216962337,
|
||||
0.0, # 0 appended as default
|
||||
]
|
||||
expected_sigmas = torch.tensor(expected_sigmas)
|
||||
expected_timesteps = (expected_sigmas * num_train_timesteps).to(torch.int64)
|
||||
expected_timesteps = expected_timesteps[0:-1]
|
||||
self.assertTrue(torch.allclose(scheduler.sigmas, expected_sigmas))
|
||||
self.assertTrue(torch.all(expected_timesteps == scheduler.timesteps))
|
||||
|
||||
Reference in New Issue
Block a user