Compare commits

..

3 Commits

Author SHA1 Message Date
yiyixuxu
19e2ce1b2d refactor qwen modular 2025-12-22 01:02:40 +01:00
yiyixuxu
a1af845169 add conditoinal pipeline 2025-12-22 01:01:16 +01:00
yiyixuxu
3a1ba1a0e2 3 files 2025-12-20 00:27:54 +01:00
27 changed files with 1845 additions and 3248 deletions

View File

@@ -263,8 +263,8 @@ def main():
world_size = dist.get_world_size()
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to(device)
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, device_map=device
)
pipeline.transformer.set_attention_backend("_native_cudnn")
cp_config = ContextParallelConfig(ring_degree=world_size)

View File

@@ -21,8 +21,8 @@ from transformers import (
BertModel,
BertTokenizer,
CLIPImageProcessor,
MT5Tokenizer,
T5EncoderModel,
T5Tokenizer,
)
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
@@ -260,7 +260,7 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
The HunyuanDiT model designed by Tencent Hunyuan.
text_encoder_2 (`T5EncoderModel`):
The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
tokenizer_2 (`T5Tokenizer`):
tokenizer_2 (`MT5Tokenizer`):
The tokenizer for the mT5 embedder.
scheduler ([`DDPMScheduler`]):
A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
@@ -295,7 +295,7 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
text_encoder_2=T5EncoderModel,
tokenizer_2=T5Tokenizer,
tokenizer_2=MT5Tokenizer,
):
super().__init__()

View File

@@ -29,52 +29,13 @@ hf download nvidia/Cosmos-Predict2.5-2B
Convert checkpoint
```bash
# pre-trained
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/pre-trained/d20b7120-df3e-4911-919d-db6e08bad31c_ema_bf16.pt
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Predict-Base-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/2b/d20b7120-df3e-4911-919d-db6e08bad31c \
--save_pipeline
# post-trained
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/post-trained/81edfebe-bd6a-4039-8c1d-737df1a790bf_ema_bf16.pt
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Predict-Base-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/2b/81edfebe-bd6a-4039-8c1d-737df1a790bf \
--save_pipeline
```
## 14B
```bash
hf download nvidia/Cosmos-Predict2.5-14B
```
```bash
# pre-trained
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-14B/snapshots/71ebf3e8af30ecfe440bf0481115975fcc052b46/base/pre-trained/54937b8c-29de-4f04-862c-e67b04ec41e8_ema_bf16.pt
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Predict-Base-14B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/14b/54937b8c-29de-4f04-862c-e67b04ec41e8/ \
--save_pipeline
# post-trained
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-14B/snapshots/71ebf3e8af30ecfe440bf0481115975fcc052b46/base/post-trained/e21d2a49-4747-44c8-ba44-9f6f9243715f_ema_bf16.pt
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Predict-Base-14B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/14b/e21d2a49-4747-44c8-ba44-9f6f9243715f/ \
--output_path converted/cosmos-p2.5-base-2b \
--save_pipeline
```
@@ -337,25 +298,6 @@ TRANSFORMER_CONFIGS = {
"crossattn_proj_in_channels": 100352,
"encoder_hidden_states_channels": 1024,
},
"Cosmos-2.5-Predict-Base-14B": {
"in_channels": 16 + 1,
"out_channels": 16,
"num_attention_heads": 40,
"attention_head_dim": 128,
"num_layers": 36,
"mlp_ratio": 4.0,
"text_embed_dim": 1024,
"adaln_lora_dim": 256,
"max_size": (128, 240, 240),
"patch_size": (1, 2, 2),
"rope_scale": (1.0, 3.0, 3.0),
"concat_padding_mask": True,
# NOTE: source config has pos_emb_learnable: 'True' - but params are missing
"extra_pos_embed_type": None,
"use_crossattn_projection": True,
"crossattn_proj_in_channels": 100352,
"encoder_hidden_states_channels": 1024,
},
}
VAE_KEYS_RENAME_DICT = {

View File

@@ -675,7 +675,6 @@ else:
"ZImageControlNetInpaintPipeline",
"ZImageControlNetPipeline",
"ZImageImg2ImgPipeline",
"ZImageOmniPipeline",
"ZImagePipeline",
]
)
@@ -1387,7 +1386,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ZImageControlNetInpaintPipeline,
ZImageControlNetPipeline,
ZImageImg2ImgPipeline,
ZImageOmniPipeline,
ZImagePipeline,
)

View File

@@ -25,7 +25,6 @@ if is_torch_available():
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
from .guider_utils import BaseGuidance
from .magnitude_aware_guidance import MagnitudeAwareGuidance
from .perturbed_attention_guidance import PerturbedAttentionGuidance
from .skip_layer_guidance import SkipLayerGuidance
from .smoothed_energy_guidance import SmoothedEnergyGuidance

View File

@@ -1,159 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class MagnitudeAwareGuidance(BaseGuidance):
"""
Magnitude-Aware Mitigation for Boosted Guidance (MAMBO-G): https://huggingface.co/papers/2508.03442
Args:
guidance_scale (`float`, defaults to `10.0`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
alpha (`float`, defaults to `8.0`):
The alpha parameter for the magnitude-aware guidance. Higher values cause more aggressive supression of
guidance scale when the magnitude of the guidance update is large.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
@register_to_config
def __init__(
self,
guidance_scale: float = 10.0,
alpha: float = 8.0,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.alpha = alpha
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None
if not self._is_mambo_g_enabled():
pred = pred_cond
else:
pred = mambo_guidance(
pred_cond,
pred_uncond,
self.guidance_scale,
self.alpha,
self.use_original_formulation,
)
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_mambo_g_enabled():
num_conditions += 1
return num_conditions
def _is_mambo_g_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def mambo_guidance(
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
guidance_scale: float,
alpha: float = 8.0,
use_original_formulation: bool = False,
):
dim = list(range(1, len(pred_cond.shape)))
diff = pred_cond - pred_uncond
ratio = torch.norm(diff, dim=dim, keepdim=True) / torch.norm(pred_uncond, dim=dim, keepdim=True)
guidance_scale_final = (
guidance_scale * torch.exp(-alpha * ratio)
if use_original_formulation
else 1.0 + (guidance_scale - 1.0) * torch.exp(-alpha * ratio)
)
pred = pred_cond if use_original_formulation else pred_uncond
pred = pred + guidance_scale_final * diff
return pred

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import List, Literal, Optional, Tuple
from typing import List, Literal, Optional
import torch
import torch.nn as nn
@@ -170,21 +170,6 @@ class FeedForward(nn.Module):
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
# Copied from diffusers.models.transformers.transformer_z_image.select_per_token
def select_per_token(
value_noisy: torch.Tensor,
value_clean: torch.Tensor,
noise_mask: torch.Tensor,
seq_len: int,
) -> torch.Tensor:
noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1)
return torch.where(
noise_mask_expanded == 1,
value_noisy.unsqueeze(1).expand(-1, seq_len, -1),
value_clean.unsqueeze(1).expand(-1, seq_len, -1),
)
@maybe_allow_in_graph
# Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformerBlock
class ZImageTransformerBlock(nn.Module):
@@ -235,37 +220,12 @@ class ZImageTransformerBlock(nn.Module):
attn_mask: torch.Tensor,
freqs_cis: torch.Tensor,
adaln_input: Optional[torch.Tensor] = None,
noise_mask: Optional[torch.Tensor] = None,
adaln_noisy: Optional[torch.Tensor] = None,
adaln_clean: Optional[torch.Tensor] = None,
):
if self.modulation:
seq_len = x.shape[1]
if noise_mask is not None:
# Per-token modulation: different modulation for noisy/clean tokens
mod_noisy = self.adaLN_modulation(adaln_noisy)
mod_clean = self.adaLN_modulation(adaln_clean)
scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1)
scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1)
gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh()
gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh()
scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy
scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean
scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len)
scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len)
gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len)
gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len)
else:
# Global modulation: same modulation for all tokens (avoid double select)
mod = self.adaLN_modulation(adaln_input)
scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2)
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
assert adaln_input is not None
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
# Attention block
attn_out = self.attention(
@@ -533,93 +493,112 @@ class ZImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
def create_coordinate_grid(size, start=None, device=None):
if start is None:
start = (0 for _ in size)
axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
grids = torch.meshgrid(axes, indexing="ij")
return torch.stack(grids, dim=-1)
# Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel._patchify_image
def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int):
"""Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim)."""
pH, pW, pF = patch_size, patch_size, f_patch_size
C, F, H, W = image.size()
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
return image, (F, H, W), (F_tokens, H_tokens, W_tokens)
# Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel._pad_with_ids
def _pad_with_ids(
self,
feat: torch.Tensor,
pos_grid_size: Tuple,
pos_start: Tuple,
device: torch.device,
noise_mask_val: Optional[int] = None,
):
"""Pad feature to SEQ_MULTI_OF, create position IDs and pad mask."""
ori_len = len(feat)
pad_len = (-ori_len) % SEQ_MULTI_OF
total_len = ori_len + pad_len
# Pos IDs
ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2)
if pad_len > 0:
pad_pos_ids = (
self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
.flatten(0, 2)
.repeat(pad_len, 1)
)
pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0)
padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0)
pad_mask = torch.cat(
[
torch.zeros(ori_len, dtype=torch.bool, device=device),
torch.ones(pad_len, dtype=torch.bool, device=device),
]
)
else:
pos_ids = ori_pos_ids
padded_feat = feat
pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device)
noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None # token level
return padded_feat, pos_ids, pad_mask, total_len, noise_mask
# Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel.patchify_and_embed
def patchify_and_embed(
self, all_image: List[torch.Tensor], all_cap_feats: List[torch.Tensor], patch_size: int, f_patch_size: int
self,
all_image: List[torch.Tensor],
all_cap_feats: List[torch.Tensor],
patch_size: int,
f_patch_size: int,
):
"""Patchify for basic mode: single image per batch item."""
pH = pW = patch_size
pF = f_patch_size
device = all_image[0].device
all_img_out, all_img_size, all_img_pos_ids, all_img_pad_mask = [], [], [], []
all_cap_out, all_cap_pos_ids, all_cap_pad_mask = [], [], []
for image, cap_feat in zip(all_image, all_cap_feats):
# Caption
cap_out, cap_pos_ids, cap_pad_mask, cap_len, _ = self._pad_with_ids(
cap_feat, (len(cap_feat) + (-len(cap_feat)) % SEQ_MULTI_OF, 1, 1), (1, 0, 0), device
)
all_cap_out.append(cap_out)
all_cap_pos_ids.append(cap_pos_ids)
all_cap_pad_mask.append(cap_pad_mask)
all_image_out = []
all_image_size = []
all_image_pos_ids = []
all_image_pad_mask = []
all_cap_pos_ids = []
all_cap_pad_mask = []
all_cap_feats_out = []
# Image
img_patches, size, (F_t, H_t, W_t) = self._patchify_image(image, patch_size, f_patch_size)
img_out, img_pos_ids, img_pad_mask, _, _ = self._pad_with_ids(
img_patches, (F_t, H_t, W_t), (cap_len + 1, 0, 0), device
for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)):
### Process Caption
cap_ori_len = len(cap_feat)
cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF
# padded position ids
cap_padded_pos_ids = self.create_coordinate_grid(
size=(cap_ori_len + cap_padding_len, 1, 1),
start=(1, 0, 0),
device=device,
).flatten(0, 2)
all_cap_pos_ids.append(cap_padded_pos_ids)
# pad mask
cap_pad_mask = torch.cat(
[
torch.zeros((cap_ori_len,), dtype=torch.bool, device=device),
torch.ones((cap_padding_len,), dtype=torch.bool, device=device),
],
dim=0,
)
all_img_out.append(img_out)
all_img_size.append(size)
all_img_pos_ids.append(img_pos_ids)
all_img_pad_mask.append(img_pad_mask)
all_cap_pad_mask.append(
cap_pad_mask if cap_padding_len > 0 else torch.zeros((cap_ori_len,), dtype=torch.bool, device=device)
)
# padded feature
cap_padded_feat = torch.cat([cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], dim=0)
all_cap_feats_out.append(cap_padded_feat)
### Process Image
C, F, H, W = image.size()
all_image_size.append((F, H, W))
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
# "c f pf h ph w pw -> (f h w) (pf ph pw c)"
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
image_ori_len = len(image)
image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
image_ori_pos_ids = self.create_coordinate_grid(
size=(F_tokens, H_tokens, W_tokens),
start=(cap_ori_len + cap_padding_len + 1, 0, 0),
device=device,
).flatten(0, 2)
image_padded_pos_ids = torch.cat(
[
image_ori_pos_ids,
self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
.flatten(0, 2)
.repeat(image_padding_len, 1),
],
dim=0,
)
all_image_pos_ids.append(image_padded_pos_ids if image_padding_len > 0 else image_ori_pos_ids)
# pad mask
image_pad_mask = torch.cat(
[
torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
torch.ones((image_padding_len,), dtype=torch.bool, device=device),
],
dim=0,
)
all_image_pad_mask.append(
image_pad_mask
if image_padding_len > 0
else torch.zeros((image_ori_len,), dtype=torch.bool, device=device)
)
# padded feature
image_padded_feat = torch.cat(
[image, image[-1:].repeat(image_padding_len, 1)],
dim=0,
)
all_image_out.append(image_padded_feat if image_padding_len > 0 else image)
return (
all_img_out,
all_cap_out,
all_img_size,
all_img_pos_ids,
all_image_out,
all_cap_feats_out,
all_image_size,
all_image_pos_ids,
all_cap_pos_ids,
all_img_pad_mask,
all_image_pad_mask,
all_cap_pad_mask,
)

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
@@ -32,7 +32,6 @@ from ..modeling_outputs import Transformer2DModelOutput
ADALN_EMBED_DIM = 256
SEQ_MULTI_OF = 32
X_PAD_DIM = 64
class TimestepEmbedder(nn.Module):
@@ -153,20 +152,6 @@ class ZSingleStreamAttnProcessor:
return output
def select_per_token(
value_noisy: torch.Tensor,
value_clean: torch.Tensor,
noise_mask: torch.Tensor,
seq_len: int,
) -> torch.Tensor:
noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1)
return torch.where(
noise_mask_expanded == 1,
value_noisy.unsqueeze(1).expand(-1, seq_len, -1),
value_clean.unsqueeze(1).expand(-1, seq_len, -1),
)
class FeedForward(nn.Module):
def __init__(self, dim: int, hidden_dim: int):
super().__init__()
@@ -230,37 +215,12 @@ class ZImageTransformerBlock(nn.Module):
attn_mask: torch.Tensor,
freqs_cis: torch.Tensor,
adaln_input: Optional[torch.Tensor] = None,
noise_mask: Optional[torch.Tensor] = None,
adaln_noisy: Optional[torch.Tensor] = None,
adaln_clean: Optional[torch.Tensor] = None,
):
if self.modulation:
seq_len = x.shape[1]
if noise_mask is not None:
# Per-token modulation: different modulation for noisy/clean tokens
mod_noisy = self.adaLN_modulation(adaln_noisy)
mod_clean = self.adaLN_modulation(adaln_clean)
scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1)
scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1)
gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh()
gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh()
scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy
scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean
scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len)
scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len)
gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len)
gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len)
else:
# Global modulation: same modulation for all tokens (avoid double select)
mod = self.adaLN_modulation(adaln_input)
scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2)
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
assert adaln_input is not None
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
# Attention block
attn_out = self.attention(
@@ -292,21 +252,9 @@ class FinalLayer(nn.Module):
nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True),
)
def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None):
seq_len = x.shape[1]
if noise_mask is not None:
# Per-token modulation
scale_noisy = 1.0 + self.adaLN_modulation(c_noisy)
scale_clean = 1.0 + self.adaLN_modulation(c_clean)
scale = select_per_token(scale_noisy, scale_clean, noise_mask, seq_len)
else:
# Original global modulation
assert c is not None, "Either c or (c_noisy, c_clean) must be provided"
scale = 1.0 + self.adaLN_modulation(c)
scale = scale.unsqueeze(1)
x = self.norm_final(x) * scale
def forward(self, x, c):
scale = 1.0 + self.adaLN_modulation(c)
x = self.norm_final(x) * scale.unsqueeze(1)
x = self.linear(x)
return x
@@ -377,7 +325,6 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
norm_eps=1e-5,
qk_norm=True,
cap_feat_dim=2560,
siglip_feat_dim=None, # Optional: set to enable SigLIP support for Omni
rope_theta=256.0,
t_scale=1000.0,
axes_dims=[32, 48, 48],
@@ -439,31 +386,6 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
self.cap_embedder = nn.Sequential(RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True))
# Optional SigLIP components (for Omni variant)
if siglip_feat_dim is not None:
self.siglip_embedder = nn.Sequential(
RMSNorm(siglip_feat_dim, eps=norm_eps), nn.Linear(siglip_feat_dim, dim, bias=True)
)
self.siglip_refiner = nn.ModuleList(
[
ZImageTransformerBlock(
2000 + layer_id,
dim,
n_heads,
n_kv_heads,
norm_eps,
qk_norm,
modulation=False,
)
for layer_id in range(n_refiner_layers)
]
)
self.siglip_pad_token = nn.Parameter(torch.empty((1, dim)))
else:
self.siglip_embedder = None
self.siglip_refiner = None
self.siglip_pad_token = None
self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
@@ -480,561 +402,259 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
def unpatchify(
self,
x: List[torch.Tensor],
size: List[Tuple],
patch_size,
f_patch_size,
x_pos_offsets: Optional[List[Tuple[int, int]]] = None,
) -> List[torch.Tensor]:
def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]:
pH = pW = patch_size
pF = f_patch_size
bsz = len(x)
assert len(size) == bsz
if x_pos_offsets is not None:
# Omni: extract target image from unified sequence (cond_images + target)
result = []
for i in range(bsz):
unified_x = x[i][x_pos_offsets[i][0] : x_pos_offsets[i][1]]
cu_len = 0
x_item = None
for j in range(len(size[i])):
if size[i][j] is None:
ori_len = 0
pad_len = SEQ_MULTI_OF
cu_len += pad_len + ori_len
else:
F, H, W = size[i][j]
ori_len = (F // pF) * (H // pH) * (W // pW)
pad_len = (-ori_len) % SEQ_MULTI_OF
x_item = (
unified_x[cu_len : cu_len + ori_len]
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
.permute(6, 0, 3, 1, 4, 2, 5)
.reshape(self.out_channels, F, H, W)
)
cu_len += ori_len + pad_len
result.append(x_item) # Return only the last (target) image
return result
else:
# Original mode: simple unpatchify
for i in range(bsz):
F, H, W = size[i]
ori_len = (F // pF) * (H // pH) * (W // pW)
# "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
x[i] = (
x[i][:ori_len]
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
.permute(6, 0, 3, 1, 4, 2, 5)
.reshape(self.out_channels, F, H, W)
)
return x
for i in range(bsz):
F, H, W = size[i]
ori_len = (F // pF) * (H // pH) * (W // pW)
# "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
x[i] = (
x[i][:ori_len]
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
.permute(6, 0, 3, 1, 4, 2, 5)
.reshape(self.out_channels, F, H, W)
)
return x
@staticmethod
def create_coordinate_grid(size, start=None, device=None):
if start is None:
start = (0 for _ in size)
axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
grids = torch.meshgrid(axes, indexing="ij")
return torch.stack(grids, dim=-1)
def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int):
"""Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim)."""
pH, pW, pF = patch_size, patch_size, f_patch_size
C, F, H, W = image.size()
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
return image, (F, H, W), (F_tokens, H_tokens, W_tokens)
def _pad_with_ids(
self,
feat: torch.Tensor,
pos_grid_size: Tuple,
pos_start: Tuple,
device: torch.device,
noise_mask_val: Optional[int] = None,
):
"""Pad feature to SEQ_MULTI_OF, create position IDs and pad mask."""
ori_len = len(feat)
pad_len = (-ori_len) % SEQ_MULTI_OF
total_len = ori_len + pad_len
# Pos IDs
ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2)
if pad_len > 0:
pad_pos_ids = (
self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
.flatten(0, 2)
.repeat(pad_len, 1)
)
pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0)
padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0)
pad_mask = torch.cat(
[
torch.zeros(ori_len, dtype=torch.bool, device=device),
torch.ones(pad_len, dtype=torch.bool, device=device),
]
)
else:
pos_ids = ori_pos_ids
padded_feat = feat
pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device)
noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None # token level
return padded_feat, pos_ids, pad_mask, total_len, noise_mask
def patchify_and_embed(
self, all_image: List[torch.Tensor], all_cap_feats: List[torch.Tensor], patch_size: int, f_patch_size: int
):
"""Patchify for basic mode: single image per batch item."""
device = all_image[0].device
all_img_out, all_img_size, all_img_pos_ids, all_img_pad_mask = [], [], [], []
all_cap_out, all_cap_pos_ids, all_cap_pad_mask = [], [], []
for image, cap_feat in zip(all_image, all_cap_feats):
# Caption
cap_out, cap_pos_ids, cap_pad_mask, cap_len, _ = self._pad_with_ids(
cap_feat, (len(cap_feat) + (-len(cap_feat)) % SEQ_MULTI_OF, 1, 1), (1, 0, 0), device
)
all_cap_out.append(cap_out)
all_cap_pos_ids.append(cap_pos_ids)
all_cap_pad_mask.append(cap_pad_mask)
# Image
img_patches, size, (F_t, H_t, W_t) = self._patchify_image(image, patch_size, f_patch_size)
img_out, img_pos_ids, img_pad_mask, _, _ = self._pad_with_ids(
img_patches, (F_t, H_t, W_t), (cap_len + 1, 0, 0), device
)
all_img_out.append(img_out)
all_img_size.append(size)
all_img_pos_ids.append(img_pos_ids)
all_img_pad_mask.append(img_pad_mask)
return (
all_img_out,
all_cap_out,
all_img_size,
all_img_pos_ids,
all_cap_pos_ids,
all_img_pad_mask,
all_cap_pad_mask,
)
def patchify_and_embed_omni(
self,
all_x: List[List[torch.Tensor]],
all_cap_feats: List[List[torch.Tensor]],
all_siglip_feats: List[List[torch.Tensor]],
all_image: List[torch.Tensor],
all_cap_feats: List[torch.Tensor],
patch_size: int,
f_patch_size: int,
images_noise_mask: List[List[int]],
):
"""Patchify for omni mode: multiple images per batch item with noise masks."""
bsz = len(all_x)
device = all_x[0][-1].device
dtype = all_x[0][-1].dtype
pH = pW = patch_size
pF = f_patch_size
device = all_image[0].device
all_x_out, all_x_size, all_x_pos_ids, all_x_pad_mask, all_x_len, all_x_noise_mask = [], [], [], [], [], []
all_cap_out, all_cap_pos_ids, all_cap_pad_mask, all_cap_len, all_cap_noise_mask = [], [], [], [], []
all_sig_out, all_sig_pos_ids, all_sig_pad_mask, all_sig_len, all_sig_noise_mask = [], [], [], [], []
all_image_out = []
all_image_size = []
all_image_pos_ids = []
all_image_pad_mask = []
all_cap_pos_ids = []
all_cap_pad_mask = []
all_cap_feats_out = []
for i in range(bsz):
num_images = len(all_x[i])
cap_feats_list, cap_pos_list, cap_mask_list, cap_lens, cap_noise = [], [], [], [], []
cap_end_pos = []
cap_cu_len = 1
for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)):
### Process Caption
cap_ori_len = len(cap_feat)
cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF
# padded position ids
cap_padded_pos_ids = self.create_coordinate_grid(
size=(cap_ori_len + cap_padding_len, 1, 1),
start=(1, 0, 0),
device=device,
).flatten(0, 2)
all_cap_pos_ids.append(cap_padded_pos_ids)
# pad mask
cap_pad_mask = torch.cat(
[
torch.zeros((cap_ori_len,), dtype=torch.bool, device=device),
torch.ones((cap_padding_len,), dtype=torch.bool, device=device),
],
dim=0,
)
all_cap_pad_mask.append(
cap_pad_mask if cap_padding_len > 0 else torch.zeros((cap_ori_len,), dtype=torch.bool, device=device)
)
# Process captions
for j, cap_item in enumerate(all_cap_feats[i]):
noise_val = images_noise_mask[i][j] if j < len(images_noise_mask[i]) else 1
cap_out, cap_pos, cap_mask, cap_len, cap_nm = self._pad_with_ids(
cap_item,
(len(cap_item) + (-len(cap_item)) % SEQ_MULTI_OF, 1, 1),
(cap_cu_len, 0, 0),
device,
noise_val,
)
cap_feats_list.append(cap_out)
cap_pos_list.append(cap_pos)
cap_mask_list.append(cap_mask)
cap_lens.append(cap_len)
cap_noise.extend(cap_nm)
cap_cu_len += len(cap_item)
cap_end_pos.append(cap_cu_len)
cap_cu_len += 2 # for image vae and siglip tokens
# padded feature
cap_padded_feat = torch.cat([cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], dim=0)
all_cap_feats_out.append(cap_padded_feat)
all_cap_out.append(torch.cat(cap_feats_list, dim=0))
all_cap_pos_ids.append(torch.cat(cap_pos_list, dim=0))
all_cap_pad_mask.append(torch.cat(cap_mask_list, dim=0))
all_cap_len.append(cap_lens)
all_cap_noise_mask.append(cap_noise)
### Process Image
C, F, H, W = image.size()
all_image_size.append((F, H, W))
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
# Process images
x_feats_list, x_pos_list, x_mask_list, x_lens, x_size, x_noise = [], [], [], [], [], []
for j, x_item in enumerate(all_x[i]):
noise_val = images_noise_mask[i][j]
if x_item is not None:
x_patches, size, (F_t, H_t, W_t) = self._patchify_image(x_item, patch_size, f_patch_size)
x_out, x_pos, x_mask, x_len, x_nm = self._pad_with_ids(
x_patches, (F_t, H_t, W_t), (cap_end_pos[j], 0, 0), device, noise_val
)
x_size.append(size)
else:
x_len = SEQ_MULTI_OF
x_out = torch.zeros((x_len, X_PAD_DIM), dtype=dtype, device=device)
x_pos = self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(x_len, 1)
x_mask = torch.ones(x_len, dtype=torch.bool, device=device)
x_nm = [noise_val] * x_len
x_size.append(None)
x_feats_list.append(x_out)
x_pos_list.append(x_pos)
x_mask_list.append(x_mask)
x_lens.append(x_len)
x_noise.extend(x_nm)
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
# "c f pf h ph w pw -> (f h w) (pf ph pw c)"
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
all_x_out.append(torch.cat(x_feats_list, dim=0))
all_x_pos_ids.append(torch.cat(x_pos_list, dim=0))
all_x_pad_mask.append(torch.cat(x_mask_list, dim=0))
all_x_size.append(x_size)
all_x_len.append(x_lens)
all_x_noise_mask.append(x_noise)
image_ori_len = len(image)
image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
# Process siglip
if all_siglip_feats[i] is None:
all_sig_len.append([0] * num_images)
all_sig_out.append(None)
else:
sig_feats_list, sig_pos_list, sig_mask_list, sig_lens, sig_noise = [], [], [], [], []
for j, sig_item in enumerate(all_siglip_feats[i]):
noise_val = images_noise_mask[i][j]
if sig_item is not None:
sig_H, sig_W, sig_C = sig_item.size()
sig_flat = sig_item.permute(2, 0, 1).reshape(sig_H * sig_W, sig_C)
sig_out, sig_pos, sig_mask, sig_len, sig_nm = self._pad_with_ids(
sig_flat, (1, sig_H, sig_W), (cap_end_pos[j] + 1, 0, 0), device, noise_val
)
# Scale position IDs to match x resolution
if x_size[j] is not None:
sig_pos = sig_pos.float()
sig_pos[..., 1] = sig_pos[..., 1] / max(sig_H - 1, 1) * (x_size[j][1] - 1)
sig_pos[..., 2] = sig_pos[..., 2] / max(sig_W - 1, 1) * (x_size[j][2] - 1)
sig_pos = sig_pos.to(torch.int32)
else:
sig_len = SEQ_MULTI_OF
sig_out = torch.zeros((sig_len, self.config.siglip_feat_dim), dtype=dtype, device=device)
sig_pos = (
self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(sig_len, 1)
)
sig_mask = torch.ones(sig_len, dtype=torch.bool, device=device)
sig_nm = [noise_val] * sig_len
sig_feats_list.append(sig_out)
sig_pos_list.append(sig_pos)
sig_mask_list.append(sig_mask)
sig_lens.append(sig_len)
sig_noise.extend(sig_nm)
all_sig_out.append(torch.cat(sig_feats_list, dim=0))
all_sig_pos_ids.append(torch.cat(sig_pos_list, dim=0))
all_sig_pad_mask.append(torch.cat(sig_mask_list, dim=0))
all_sig_len.append(sig_lens)
all_sig_noise_mask.append(sig_noise)
# Compute x position offsets
all_x_pos_offsets = [(sum(all_cap_len[i]), sum(all_cap_len[i]) + sum(all_x_len[i])) for i in range(bsz)]
image_ori_pos_ids = self.create_coordinate_grid(
size=(F_tokens, H_tokens, W_tokens),
start=(cap_ori_len + cap_padding_len + 1, 0, 0),
device=device,
).flatten(0, 2)
image_padded_pos_ids = torch.cat(
[
image_ori_pos_ids,
self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
.flatten(0, 2)
.repeat(image_padding_len, 1),
],
dim=0,
)
all_image_pos_ids.append(image_padded_pos_ids if image_padding_len > 0 else image_ori_pos_ids)
# pad mask
image_pad_mask = torch.cat(
[
torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
torch.ones((image_padding_len,), dtype=torch.bool, device=device),
],
dim=0,
)
all_image_pad_mask.append(
image_pad_mask
if image_padding_len > 0
else torch.zeros((image_ori_len,), dtype=torch.bool, device=device)
)
# padded feature
image_padded_feat = torch.cat(
[image, image[-1:].repeat(image_padding_len, 1)],
dim=0,
)
all_image_out.append(image_padded_feat if image_padding_len > 0 else image)
return (
all_x_out,
all_cap_out,
all_sig_out,
all_x_size,
all_x_pos_ids,
all_image_out,
all_cap_feats_out,
all_image_size,
all_image_pos_ids,
all_cap_pos_ids,
all_sig_pos_ids,
all_x_pad_mask,
all_image_pad_mask,
all_cap_pad_mask,
all_sig_pad_mask,
all_x_pos_offsets,
all_x_noise_mask,
all_cap_noise_mask,
all_sig_noise_mask,
)
def _prepare_sequence(
self,
feats: List[torch.Tensor],
pos_ids: List[torch.Tensor],
inner_pad_mask: List[torch.Tensor],
pad_token: torch.nn.Parameter,
noise_mask: Optional[List[List[int]]] = None,
device: torch.device = None,
):
"""Prepare sequence: apply pad token, RoPE embed, pad to batch, create attention mask."""
item_seqlens = [len(f) for f in feats]
max_seqlen = max(item_seqlens)
bsz = len(feats)
# Pad token
feats_cat = torch.cat(feats, dim=0)
feats_cat[torch.cat(inner_pad_mask)] = pad_token
feats = list(feats_cat.split(item_seqlens, dim=0))
# RoPE
freqs_cis = list(self.rope_embedder(torch.cat(pos_ids, dim=0)).split([len(p) for p in pos_ids], dim=0))
# Pad to batch
feats = pad_sequence(feats, batch_first=True, padding_value=0.0)
freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]]
# Attention mask
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(item_seqlens):
attn_mask[i, :seq_len] = 1
# Noise mask
noise_mask_tensor = None
if noise_mask is not None:
noise_mask_tensor = pad_sequence(
[torch.tensor(m, dtype=torch.long, device=device) for m in noise_mask],
batch_first=True,
padding_value=0,
)[:, : feats.shape[1]]
return feats, freqs_cis, attn_mask, item_seqlens, noise_mask_tensor
def _build_unified_sequence(
self,
x: torch.Tensor,
x_freqs: torch.Tensor,
x_seqlens: List[int],
x_noise_mask: Optional[List[List[int]]],
cap: torch.Tensor,
cap_freqs: torch.Tensor,
cap_seqlens: List[int],
cap_noise_mask: Optional[List[List[int]]],
siglip: Optional[torch.Tensor],
siglip_freqs: Optional[torch.Tensor],
siglip_seqlens: Optional[List[int]],
siglip_noise_mask: Optional[List[List[int]]],
omni_mode: bool,
device: torch.device,
):
"""Build unified sequence: x, cap, and optionally siglip.
Basic mode order: [x, cap]; Omni mode order: [cap, x, siglip]
"""
bsz = len(x_seqlens)
unified = []
unified_freqs = []
unified_noise_mask = []
for i in range(bsz):
x_len, cap_len = x_seqlens[i], cap_seqlens[i]
if omni_mode:
# Omni: [cap, x, siglip]
if siglip is not None and siglip_seqlens is not None:
sig_len = siglip_seqlens[i]
unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len], siglip[i][:sig_len]]))
unified_freqs.append(
torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len], siglip_freqs[i][:sig_len]])
)
unified_noise_mask.append(
torch.tensor(
cap_noise_mask[i] + x_noise_mask[i] + siglip_noise_mask[i], dtype=torch.long, device=device
)
)
else:
unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len]]))
unified_freqs.append(torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len]]))
unified_noise_mask.append(
torch.tensor(cap_noise_mask[i] + x_noise_mask[i], dtype=torch.long, device=device)
)
else:
# Basic: [x, cap]
unified.append(torch.cat([x[i][:x_len], cap[i][:cap_len]]))
unified_freqs.append(torch.cat([x_freqs[i][:x_len], cap_freqs[i][:cap_len]]))
# Compute unified seqlens
if omni_mode:
if siglip is not None and siglip_seqlens is not None:
unified_seqlens = [a + b + c for a, b, c in zip(cap_seqlens, x_seqlens, siglip_seqlens)]
else:
unified_seqlens = [a + b for a, b in zip(cap_seqlens, x_seqlens)]
else:
unified_seqlens = [a + b for a, b in zip(x_seqlens, cap_seqlens)]
max_seqlen = max(unified_seqlens)
# Pad to batch
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0)
# Attention mask
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(unified_seqlens):
attn_mask[i, :seq_len] = 1
# Noise mask
noise_mask_tensor = None
if omni_mode:
noise_mask_tensor = pad_sequence(unified_noise_mask, batch_first=True, padding_value=0)[
:, : unified.shape[1]
]
return unified, unified_freqs, attn_mask, noise_mask_tensor
def forward(
self,
x: Union[List[torch.Tensor], List[List[torch.Tensor]]],
x: List[torch.Tensor],
t,
cap_feats: Union[List[torch.Tensor], List[List[torch.Tensor]]],
return_dict: bool = True,
cap_feats: List[torch.Tensor],
controlnet_block_samples: Optional[Dict[int, torch.Tensor]] = None,
siglip_feats: Optional[List[List[torch.Tensor]]] = None,
image_noise_mask: Optional[List[List[int]]] = None,
patch_size: int = 2,
f_patch_size: int = 1,
patch_size=2,
f_patch_size=1,
return_dict: bool = True,
):
"""
Flow: patchify -> t_embed -> x_embed -> x_refine -> cap_embed -> cap_refine
-> [siglip_embed -> siglip_refine] -> build_unified -> main_layers -> final_layer -> unpatchify
"""
assert patch_size in self.all_patch_size and f_patch_size in self.all_f_patch_size
omni_mode = isinstance(x[0], list)
device = x[0][-1].device if omni_mode else x[0].device
assert patch_size in self.all_patch_size
assert f_patch_size in self.all_f_patch_size
if omni_mode:
# Dual embeddings: noisy (t) and clean (t=1)
t_noisy = self.t_embedder(t * self.t_scale).type_as(x[0][-1])
t_clean = self.t_embedder(torch.ones_like(t) * self.t_scale).type_as(x[0][-1])
adaln_input = None
else:
# Single embedding for all tokens
adaln_input = self.t_embedder(t * self.t_scale).type_as(x[0])
t_noisy = t_clean = None
bsz = len(x)
device = x[0].device
t = t * self.t_scale
t = self.t_embedder(t)
# Patchify
if omni_mode:
(
x,
cap_feats,
siglip_feats,
x_size,
x_pos_ids,
cap_pos_ids,
siglip_pos_ids,
x_pad_mask,
cap_pad_mask,
siglip_pad_mask,
x_pos_offsets,
x_noise_mask,
cap_noise_mask,
siglip_noise_mask,
) = self.patchify_and_embed_omni(x, cap_feats, siglip_feats, patch_size, f_patch_size, image_noise_mask)
else:
(
x,
cap_feats,
x_size,
x_pos_ids,
cap_pos_ids,
x_pad_mask,
cap_pad_mask,
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
x_pos_offsets = x_noise_mask = cap_noise_mask = siglip_noise_mask = None
# X embed & refine
x_seqlens = [len(xi) for xi in x]
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](torch.cat(x, dim=0)) # embed
x, x_freqs, x_mask, _, x_noise_tensor = self._prepare_sequence(
list(x.split(x_seqlens, dim=0)), x_pos_ids, x_pad_mask, self.x_pad_token, x_noise_mask, device
)
for layer in self.noise_refiner:
x = (
self._gradient_checkpointing_func(
layer, x, x_mask, x_freqs, adaln_input, x_noise_tensor, t_noisy, t_clean
)
if torch.is_grad_enabled() and self.gradient_checkpointing
else layer(x, x_mask, x_freqs, adaln_input, x_noise_tensor, t_noisy, t_clean)
)
# Cap embed & refine
cap_seqlens = [len(ci) for ci in cap_feats]
cap_feats = self.cap_embedder(torch.cat(cap_feats, dim=0)) # embed
cap_feats, cap_freqs, cap_mask, _, _ = self._prepare_sequence(
list(cap_feats.split(cap_seqlens, dim=0)), cap_pos_ids, cap_pad_mask, self.cap_pad_token, None, device
)
for layer in self.context_refiner:
cap_feats = (
self._gradient_checkpointing_func(layer, cap_feats, cap_mask, cap_freqs)
if torch.is_grad_enabled() and self.gradient_checkpointing
else layer(cap_feats, cap_mask, cap_freqs)
)
# Siglip embed & refine
siglip_seqlens = siglip_freqs = None
if omni_mode and siglip_feats[0] is not None and self.siglip_embedder is not None:
siglip_seqlens = [len(si) for si in siglip_feats]
siglip_feats = self.siglip_embedder(torch.cat(siglip_feats, dim=0)) # embed
siglip_feats, siglip_freqs, siglip_mask, _, _ = self._prepare_sequence(
list(siglip_feats.split(siglip_seqlens, dim=0)),
siglip_pos_ids,
siglip_pad_mask,
self.siglip_pad_token,
None,
device,
)
for layer in self.siglip_refiner:
siglip_feats = (
self._gradient_checkpointing_func(layer, siglip_feats, siglip_mask, siglip_freqs)
if torch.is_grad_enabled() and self.gradient_checkpointing
else layer(siglip_feats, siglip_mask, siglip_freqs)
)
# Unified sequence
unified, unified_freqs, unified_mask, unified_noise_tensor = self._build_unified_sequence(
(
x,
x_freqs,
x_seqlens,
x_noise_mask,
cap_feats,
cap_freqs,
cap_seqlens,
cap_noise_mask,
siglip_feats,
siglip_freqs,
siglip_seqlens,
siglip_noise_mask,
omni_mode,
device,
x_size,
x_pos_ids,
cap_pos_ids,
x_inner_pad_mask,
cap_inner_pad_mask,
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
# x embed & refine
x_item_seqlens = [len(_) for _ in x]
assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
x_max_item_seqlen = max(x_item_seqlens)
x = torch.cat(x, dim=0)
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
# Match t_embedder output dtype to x for layerwise casting compatibility
adaln_input = t.type_as(x)
x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
x = list(x.split(x_item_seqlens, dim=0))
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0))
x = pad_sequence(x, batch_first=True, padding_value=0.0)
x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
# Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors
x_freqs_cis = x_freqs_cis[:, : x.shape[1]]
x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(x_item_seqlens):
x_attn_mask[i, :seq_len] = 1
if torch.is_grad_enabled() and self.gradient_checkpointing:
for layer in self.noise_refiner:
x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input)
else:
for layer in self.noise_refiner:
x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)
# cap embed & refine
cap_item_seqlens = [len(_) for _ in cap_feats]
cap_max_item_seqlen = max(cap_item_seqlens)
cap_feats = torch.cat(cap_feats, dim=0)
cap_feats = self.cap_embedder(cap_feats)
cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
cap_freqs_cis = list(
self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0)
)
# Main transformer layers
for layer_idx, layer in enumerate(self.layers):
unified = (
self._gradient_checkpointing_func(
layer, unified, unified_mask, unified_freqs, adaln_input, unified_noise_tensor, t_noisy, t_clean
cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
# Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors
cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]]
cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(cap_item_seqlens):
cap_attn_mask[i, :seq_len] = 1
if torch.is_grad_enabled() and self.gradient_checkpointing:
for layer in self.context_refiner:
cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis)
else:
for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)
# unified
unified = []
unified_freqs_cis = []
for i in range(bsz):
x_len = x_item_seqlens[i]
cap_len = cap_item_seqlens[i]
unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
assert unified_item_seqlens == [len(_) for _ in unified]
unified_max_item_seqlen = max(unified_item_seqlens)
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(unified_item_seqlens):
unified_attn_mask[i, :seq_len] = 1
if torch.is_grad_enabled() and self.gradient_checkpointing:
for layer_idx, layer in enumerate(self.layers):
unified = self._gradient_checkpointing_func(
layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input
)
if torch.is_grad_enabled() and self.gradient_checkpointing
else layer(unified, unified_mask, unified_freqs, adaln_input, unified_noise_tensor, t_noisy, t_clean)
)
if controlnet_block_samples is not None and layer_idx in controlnet_block_samples:
unified = unified + controlnet_block_samples[layer_idx]
if controlnet_block_samples is not None:
if layer_idx in controlnet_block_samples:
unified = unified + controlnet_block_samples[layer_idx]
else:
for layer_idx, layer in enumerate(self.layers):
unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input)
if controlnet_block_samples is not None:
if layer_idx in controlnet_block_samples:
unified = unified + controlnet_block_samples[layer_idx]
unified = (
self.all_final_layer[f"{patch_size}-{f_patch_size}"](
unified, noise_mask=unified_noise_tensor, c_noisy=t_noisy, c_clean=t_clean
)
if omni_mode
else self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, c=adaln_input)
)
unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
unified = list(unified.unbind(dim=0))
x = self.unpatchify(unified, x_size, patch_size, f_patch_size)
# Unpatchify
x = self.unpatchify(list(unified.unbind(dim=0)), x_size, patch_size, f_patch_size, x_pos_offsets)
if not return_dict:
return (x,)
return (x,) if not return_dict else Transformer2DModelOutput(sample=x)
return Transformer2DModelOutput(sample=x)

View File

@@ -231,7 +231,7 @@ class BlockState:
class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
"""
Base class for all Pipeline Blocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks,
Base class for all Pipeline Blocks: ConditionalPipelineBlocks, AutoPipelineBlocks, SequentialPipelineBlocks,
LoopSequentialPipelineBlocks
[`ModularPipelineBlocks`] provides method to load and save the definition of pipeline blocks.
@@ -527,9 +527,10 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
)
class AutoPipelineBlocks(ModularPipelineBlocks):
class ConditionalPipelineBlocks(ModularPipelineBlocks):
"""
A Pipeline Blocks that automatically selects a block to run based on the inputs.
A Pipeline Blocks that conditionally selects a block to run based on the inputs.
Subclasses must implement the `select_block` method to define the logic for selecting the block.
This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
library implements for all the pipeline blocks (such as loading or saving etc.)
@@ -539,12 +540,13 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
Attributes:
block_classes: List of block classes to be used
block_names: List of prefixes for each block
block_trigger_inputs: List of input names that trigger specific blocks, with None for default
block_trigger_inputs: List of input names that select_block() uses to determine which block to run
"""
block_classes = []
block_names = []
block_trigger_inputs = []
default_block_name = None # name of the default block if no trigger inputs are provided, if None, this block can be skipped if no trigger inputs are provided
def __init__(self):
sub_blocks = InsertableDict()
@@ -554,26 +556,15 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
else:
sub_blocks[block_name] = block
self.sub_blocks = sub_blocks
if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
if not (len(self.block_classes) == len(self.block_names)):
raise ValueError(
f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same."
f"In {self.__class__.__name__}, the number of block_classes and block_names must be the same."
)
default_blocks = [t for t in self.block_trigger_inputs if t is None]
# can only have 1 or 0 default block, and has to put in the last
# the order of blocks matters here because the first block with matching trigger will be dispatched
# e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"]
# as long as mask is provided, it is inpaint; if only image is provided, it is img2img
if len(default_blocks) > 1 or (len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None):
if self.default_block_name is not None and self.default_block_name not in self.block_names:
raise ValueError(
f"In {self.__class__.__name__}, exactly one None must be specified as the last element "
"in block_trigger_inputs."
f"In {self.__class__.__name__}, default_block_name '{self.default_block_name}' must be one of block_names: {self.block_names}"
)
# Map trigger inputs to block objects
self.trigger_to_block_map = dict(zip(self.block_trigger_inputs, self.sub_blocks.values()))
self.trigger_to_block_name_map = dict(zip(self.block_trigger_inputs, self.sub_blocks.keys()))
self.block_to_trigger_map = dict(zip(self.sub_blocks.keys(), self.block_trigger_inputs))
@property
def model_name(self):
return next(iter(self.sub_blocks.values())).model_name
@@ -602,8 +593,11 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
@property
def required_inputs(self) -> List[str]:
if None not in self.block_trigger_inputs:
# no default block means this conditional block can be skipped entirely
if self.default_block_name is None:
return []
first_block = next(iter(self.sub_blocks.values()))
required_by_all = set(getattr(first_block, "required_inputs", set()))
@@ -614,7 +608,7 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
return list(required_by_all)
# YiYi TODO: add test for this
@property
def inputs(self) -> List[Tuple[str, Any]]:
named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()]
@@ -639,36 +633,9 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
combined_outputs = self.combine_outputs(*named_outputs)
return combined_outputs
@torch.no_grad()
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
# Find default block first (if any)
block = self.trigger_to_block_map.get(None)
for input_name in self.block_trigger_inputs:
if input_name is not None and state.get(input_name) is not None:
block = self.trigger_to_block_map[input_name]
break
if block is None:
logger.info(f"skipping auto block: {self.__class__.__name__}")
return pipeline, state
try:
logger.info(f"Running block: {block.__class__.__name__}, trigger: {input_name}")
return block(pipeline, state)
except Exception as e:
error_msg = (
f"\nError in block: {block.__class__.__name__}\n"
f"Error details: {str(e)}\n"
f"Traceback:\n{traceback.format_exc()}"
)
logger.error(error_msg)
raise
def _get_trigger_inputs(self):
def _get_trigger_inputs(self) -> set:
"""
Returns a set of all unique trigger input values found in the blocks. Returns: Set[str] containing all unique
block_trigger_inputs values
Returns a set of all unique trigger input values found in this block and nested blocks.
"""
def fn_recursive_get_trigger(blocks):
@@ -676,9 +643,8 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
if blocks is not None:
for name, block in blocks.items():
# Check if current block has trigger inputs(i.e. auto block)
# Check if current block has block_trigger_inputs
if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None:
# Add all non-None values from the trigger inputs list
trigger_values.update(t for t in block.block_trigger_inputs if t is not None)
# If block has sub_blocks, recursively check them
@@ -688,15 +654,58 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
return trigger_values
trigger_inputs = set(self.block_trigger_inputs)
trigger_inputs.update(fn_recursive_get_trigger(self.sub_blocks))
# Start with this block's block_trigger_inputs
all_triggers = set(t for t in self.block_trigger_inputs if t is not None)
# Add nested triggers
all_triggers.update(fn_recursive_get_trigger(self.sub_blocks))
return trigger_inputs
return all_triggers
@property
def trigger_inputs(self):
"""All trigger inputs including from nested blocks."""
return self._get_trigger_inputs()
def select_block(self, **kwargs) -> Optional[str]:
"""
Select the block to run based on the trigger inputs.
Subclasses must implement this method to define the logic for selecting the block.
Args:
**kwargs: Trigger input names and their values from the state.
Returns:
Optional[str]: The name of the block to run, or None to use default/skip.
"""
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement the `select_block` method.")
@torch.no_grad()
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
trigger_kwargs = {name: state.get(name) for name in self.block_trigger_inputs if name is not None}
block_name = self.select_block(**trigger_kwargs)
if block_name is None:
block_name = self.default_block_name
if block_name is None:
logger.info(f"skipping conditional block: {self.__class__.__name__}")
return pipeline, state
block = self.sub_blocks[block_name]
try:
logger.info(f"Running block: {block.__class__.__name__}")
return block(pipeline, state)
except Exception as e:
error_msg = (
f"\nError in block: {block.__class__.__name__}\n"
f"Error details: {str(e)}\n"
f"Traceback:\n{traceback.format_exc()}"
)
logger.error(error_msg)
raise
def __repr__(self):
class_name = self.__class__.__name__
base_class = self.__class__.__bases__[0].__name__
@@ -708,7 +717,7 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
header += "\n"
header += " " + "=" * 100 + "\n"
header += " This pipeline contains blocks that are selected at runtime based on inputs.\n"
header += f" Trigger Inputs: {[inp for inp in self.trigger_inputs if inp is not None]}\n"
header += f" Trigger Inputs: {sorted(self.trigger_inputs)}\n"
header += " " + "=" * 100 + "\n\n"
# Format description with proper indentation
@@ -729,31 +738,20 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
expected_configs = getattr(self, "expected_configs", [])
configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
# Blocks section - moved to the end with simplified format
# Blocks section
blocks_str = " Sub-Blocks:\n"
for i, (name, block) in enumerate(self.sub_blocks.items()):
# Get trigger input for this block
trigger = None
if hasattr(self, "block_to_trigger_map"):
trigger = self.block_to_trigger_map.get(name)
# Format the trigger info
if trigger is None:
trigger_str = "[default]"
elif isinstance(trigger, (list, tuple)):
trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]"
else:
trigger_str = f"[trigger: {trigger}]"
# For AutoPipelineBlocks, add bullet points
blocks_str += f"{name} {trigger_str} ({block.__class__.__name__})\n"
if name == self.default_block_name:
addtional_str = " [default]"
else:
# For SequentialPipelineBlocks, show execution order
blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
addtional_str = ""
blocks_str += f" {name}{addtional_str} ({block.__class__.__name__})\n"
# Add block description
desc_lines = block.description.split("\n")
indented_desc = desc_lines[0]
if len(desc_lines) > 1:
indented_desc += "\n" + "\n".join(" " + line for line in desc_lines[1:])
block_desc_lines = block.description.split("\n")
indented_desc = block_desc_lines[0]
if len(block_desc_lines) > 1:
indented_desc += "\n" + "\n".join(" " + line for line in block_desc_lines[1:])
blocks_str += f" Description: {indented_desc}\n\n"
# Build the representation with conditional sections
@@ -784,6 +782,35 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
)
class AutoPipelineBlocks(ConditionalPipelineBlocks):
"""
A Pipeline Blocks that automatically selects a block to run based on the presence of trigger inputs.
"""
def __init__(self):
super().__init__()
if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
raise ValueError(
f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same."
)
@property
def default_block_name(self) -> Optional[str]:
"""Derive default_block_name from block_trigger_inputs (None entry)."""
if None in self.block_trigger_inputs:
idx = self.block_trigger_inputs.index(None)
return self.block_names[idx]
return None
def select_block(self, **kwargs) -> Optional[str]:
"""Select block based on which trigger input is present (not None)."""
for trigger_input, block_name in zip(self.block_trigger_inputs, self.block_names):
if trigger_input is not None and kwargs.get(trigger_input) is not None:
return block_name
return None
class SequentialPipelineBlocks(ModularPipelineBlocks):
"""
A Pipeline Blocks that combines multiple pipeline block classes into one. When called, it will call each block in
@@ -885,7 +912,8 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
# Only add outputs if the block cannot be skipped
should_add_outputs = True
if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs:
if isinstance(block, ConditionalPipelineBlocks) and block.default_block_name is None:
# ConditionalPipelineBlocks without default can be skipped
should_add_outputs = False
if should_add_outputs:
@@ -948,8 +976,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
def _get_trigger_inputs(self):
"""
Returns a set of all unique trigger input values found in the blocks. Returns: Set[str] containing all unique
block_trigger_inputs values
Returns a set of all unique trigger input values found in the blocks.
"""
def fn_recursive_get_trigger(blocks):
@@ -957,9 +984,8 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
if blocks is not None:
for name, block in blocks.items():
# Check if current block has trigger inputs(i.e. auto block)
# Check if current block has block_trigger_inputs (ConditionalPipelineBlocks)
if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None:
# Add all non-None values from the trigger inputs list
trigger_values.update(t for t in block.block_trigger_inputs if t is not None)
# If block has sub_blocks, recursively check them
@@ -975,82 +1001,85 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
def trigger_inputs(self):
return self._get_trigger_inputs()
def _traverse_trigger_blocks(self, trigger_inputs):
# Convert trigger_inputs to a set for easier manipulation
active_triggers = set(trigger_inputs)
def _traverse_trigger_blocks(self, active_inputs):
"""
Traverse blocks and select which ones would run given the active inputs.
def fn_recursive_traverse(block, block_name, active_triggers):
Args:
active_inputs: Dict of input names to values that are "present"
Returns:
OrderedDict of block_name -> block that would execute
"""
def fn_recursive_traverse(block, block_name, active_inputs):
result_blocks = OrderedDict()
# sequential(include loopsequential) or PipelineBlock
if not hasattr(block, "block_trigger_inputs"):
if block.sub_blocks:
# sequential or LoopSequentialPipelineBlocks (keep traversing)
for sub_block_name, sub_block in block.sub_blocks.items():
blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers)
blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers)
blocks_to_update = {f"{block_name}.{k}": v for k, v in blocks_to_update.items()}
result_blocks.update(blocks_to_update)
# ConditionalPipelineBlocks (includes AutoPipelineBlocks)
if isinstance(block, ConditionalPipelineBlocks):
trigger_kwargs = {name: active_inputs.get(name) for name in block.block_trigger_inputs}
selected_block_name = block.select_block(**trigger_kwargs)
if selected_block_name is None:
selected_block_name = block.default_block_name
if selected_block_name is None:
return result_blocks
selected_block = block.sub_blocks[selected_block_name]
if selected_block.sub_blocks:
result_blocks.update(fn_recursive_traverse(selected_block, block_name, active_inputs))
else:
# PipelineBlock
result_blocks[block_name] = block
# Add this block's output names to active triggers if defined
if hasattr(block, "outputs"):
active_triggers.update(out.name for out in block.outputs)
result_blocks[block_name] = selected_block
if hasattr(selected_block, "outputs"):
for out in selected_block.outputs:
active_inputs[out.name] = True
return result_blocks
# auto
# SequentialPipelineBlocks or LoopSequentialPipelineBlocks
if block.sub_blocks:
for sub_block_name, sub_block in block.sub_blocks.items():
blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_inputs)
blocks_to_update = {f"{block_name}.{k}": v for k, v in blocks_to_update.items()}
result_blocks.update(blocks_to_update)
else:
# Find first block_trigger_input that matches any value in our active_triggers
this_block = None
for trigger_input in block.block_trigger_inputs:
if trigger_input is not None and trigger_input in active_triggers:
this_block = block.trigger_to_block_map[trigger_input]
break
# If no matches found, try to get the default (None) block
if this_block is None and None in block.block_trigger_inputs:
this_block = block.trigger_to_block_map[None]
if this_block is not None:
# sequential/auto (keep traversing)
if this_block.sub_blocks:
result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers))
else:
# PipelineBlock
result_blocks[block_name] = this_block
# Add this block's output names to active triggers if defined
# YiYi TODO: do we need outputs here? can it just be intermediate_outputs? can we get rid of outputs attribute?
if hasattr(this_block, "outputs"):
active_triggers.update(out.name for out in this_block.outputs)
result_blocks[block_name] = block
if hasattr(block, "outputs"):
for out in block.outputs:
active_inputs[out.name] = True
return result_blocks
all_blocks = OrderedDict()
for block_name, block in self.sub_blocks.items():
blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers)
blocks_to_update = fn_recursive_traverse(block, block_name, active_inputs)
all_blocks.update(blocks_to_update)
return all_blocks
def get_execution_blocks(self, *trigger_inputs):
trigger_inputs_all = self.trigger_inputs
def get_execution_blocks(self, **kwargs):
"""
Get the blocks that would execute given the specified inputs.
if trigger_inputs is not None:
if not isinstance(trigger_inputs, (list, tuple, set)):
trigger_inputs = [trigger_inputs]
invalid_inputs = [x for x in trigger_inputs if x not in trigger_inputs_all]
if invalid_inputs:
logger.warning(
f"The following trigger inputs will be ignored as they are not supported: {invalid_inputs}"
)
trigger_inputs = [x for x in trigger_inputs if x in trigger_inputs_all]
Args:
**kwargs: Input names and values. Only trigger inputs affect block selection.
Pass any inputs that would be non-None at runtime.
if trigger_inputs is None:
if None in trigger_inputs_all:
trigger_inputs = [None]
else:
trigger_inputs = [trigger_inputs_all[0]]
blocks_triggered = self._traverse_trigger_blocks(trigger_inputs)
Returns:
SequentialPipelineBlocks containing only the blocks that would execute
Example:
# Get blocks for inpainting workflow
blocks = pipeline.get_execution_blocks(prompt="a cat", mask=mask, image=image)
# Get blocks for text2image workflow
blocks = pipeline.get_execution_blocks(prompt="a cat")
"""
# Filter out None values
active_inputs = {k: v for k, v in kwargs.items() if v is not None}
blocks_triggered = self._traverse_trigger_blocks(active_inputs)
return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered)
def __repr__(self):
@@ -1067,7 +1096,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
header += f" Trigger Inputs: {[inp for inp in self.trigger_inputs if inp is not None]}\n"
# Get first trigger input as example
example_input = next(t for t in self.trigger_inputs if t is not None)
header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n"
header += f" Use `get_execution_blocks()` to see selected blocks (e.g. `get_execution_blocks({example_input}=...)`).\n"
header += " " + "=" * 100 + "\n\n"
# Format description with proper indentation
@@ -1091,22 +1120,9 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
# Blocks section - moved to the end with simplified format
blocks_str = " Sub-Blocks:\n"
for i, (name, block) in enumerate(self.sub_blocks.items()):
# Get trigger input for this block
trigger = None
if hasattr(self, "block_to_trigger_map"):
trigger = self.block_to_trigger_map.get(name)
# Format the trigger info
if trigger is None:
trigger_str = "[default]"
elif isinstance(trigger, (list, tuple)):
trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]"
else:
trigger_str = f"[trigger: {trigger}]"
# For AutoPipelineBlocks, add bullet points
blocks_str += f"{name} {trigger_str} ({block.__class__.__name__})\n"
else:
# For SequentialPipelineBlocks, show execution order
blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
# show execution order
blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
# Add block description
desc_lines = block.description.split("\n")
@@ -1230,15 +1246,9 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
if inp.name not in outputs and inp not in inputs:
inputs.append(inp)
# Only add outputs if the block cannot be skipped
should_add_outputs = True
if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs:
should_add_outputs = False
if should_add_outputs:
# Add this block's outputs
block_intermediate_outputs = [out.name for out in block.intermediate_outputs]
outputs.update(block_intermediate_outputs)
# Add this block's outputs
block_intermediate_outputs = [out.name for out in block.intermediate_outputs]
outputs.update(block_intermediate_outputs)
for input_param in inputs:
if input_param.name in self.required_inputs:
@@ -1295,6 +1305,14 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
sub_blocks[block_name] = block
self.sub_blocks = sub_blocks
# Validate that sub_blocks are only leaf blocks
for block_name, block in self.sub_blocks.items():
if block.sub_blocks:
raise ValueError(
f"In {self.__class__.__name__}, sub_blocks must be leaf blocks (no sub_blocks). "
f"Block '{block_name}' ({block.__class__.__name__}) has sub_blocks."
)
@classmethod
def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "LoopSequentialPipelineBlocks":
"""

View File

@@ -21,21 +21,16 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["encoders"] = ["QwenImageTextEncoderStep"]
_import_structure["modular_blocks"] = [
"ALL_BLOCKS",
_import_structure["modular_blocks_qwenimage"] = [
"AUTO_BLOCKS",
"CONTROLNET_BLOCKS",
"EDIT_AUTO_BLOCKS",
"EDIT_BLOCKS",
"EDIT_INPAINT_BLOCKS",
"EDIT_PLUS_AUTO_BLOCKS",
"EDIT_PLUS_BLOCKS",
"IMAGE2IMAGE_BLOCKS",
"INPAINT_BLOCKS",
"TEXT2IMAGE_BLOCKS",
"QwenImageAutoBlocks",
]
_import_structure["modular_blocks_qwenimage_edit"] = [
"EDIT_AUTO_BLOCKS",
"QwenImageEditAutoBlocks",
]
_import_structure["modular_blocks_qwenimage_edit_plus"] = [
"EDIT_PLUS_AUTO_BLOCKS",
"QwenImageEditPlusAutoBlocks",
]
_import_structure["modular_pipeline"] = [
@@ -51,23 +46,16 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .encoders import (
QwenImageTextEncoderStep,
)
from .modular_blocks import (
ALL_BLOCKS,
from .modular_blocks_qwenimage import (
AUTO_BLOCKS,
CONTROLNET_BLOCKS,
EDIT_AUTO_BLOCKS,
EDIT_BLOCKS,
EDIT_INPAINT_BLOCKS,
EDIT_PLUS_AUTO_BLOCKS,
EDIT_PLUS_BLOCKS,
IMAGE2IMAGE_BLOCKS,
INPAINT_BLOCKS,
TEXT2IMAGE_BLOCKS,
QwenImageAutoBlocks,
)
from .modular_blocks_qwenimage_edit import (
EDIT_AUTO_BLOCKS,
QwenImageEditAutoBlocks,
)
from .modular_blocks_qwenimage_edit_plus import (
EDIT_PLUS_AUTO_BLOCKS,
QwenImageEditPlusAutoBlocks,
)
from .modular_pipeline import (
@@ -86,4 +74,4 @@ else:
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
setattr(sys.modules[__name__], name, value)

View File

@@ -639,19 +639,65 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
return components, state
class QwenImageEditPlusRoPEInputsStep(QwenImageEditRoPEInputsStep):
class QwenImageEditPlusRoPEInputsStep(ModularPipelineBlocks):
"""RoPE inputs step for Edit Plus that handles lists of image heights/widths."""
model_name = "qwenimage-edit-plus"
@property
def description(self) -> str:
return (
"Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit Plus.\n"
"Unlike Edit, Edit Plus handles lists of image_height/image_width for multiple reference images.\n"
"Should be placed after prepare_latents step."
)
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="batch_size", required=True),
InputParam(name="image_height", required=True, type_hint=List[int]),
InputParam(name="image_width", required=True, type_hint=List[int]),
InputParam(name="height", required=True),
InputParam(name="width", required=True),
InputParam(name="prompt_embeds_mask"),
InputParam(name="negative_prompt_embeds_mask"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name="img_shapes",
type_hint=List[List[Tuple[int, int, int]]],
description="The shapes of the image latents, used for RoPE calculation",
),
OutputParam(
name="txt_seq_lens",
kwargs_type="denoiser_input_fields",
type_hint=List[int],
description="The sequence lengths of the prompt embeds, used for RoPE calculation",
),
OutputParam(
name="negative_txt_seq_lens",
kwargs_type="denoiser_input_fields",
type_hint=List[int],
description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
),
]
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
vae_scale_factor = components.vae_scale_factor
# Edit Plus: image_height and image_width are lists
block_state.img_shapes = [
[
(1, block_state.height // vae_scale_factor // 2, block_state.width // vae_scale_factor // 2),
*[
(1, vae_height // vae_scale_factor // 2, vae_width // vae_scale_factor // 2)
for vae_height, vae_width in zip(block_state.image_height, block_state.image_width)
(1, img_height // vae_scale_factor // 2, img_width // vae_scale_factor // 2)
for img_height, img_width in zip(block_state.image_height, block_state.image_width)
],
]
] * block_state.batch_size

View File

@@ -244,18 +244,19 @@ def encode_vae_image(
class QwenImageEditResizeDynamicStep(ModularPipelineBlocks):
model_name = "qwenimage"
def __init__(self, input_name: str = "image", output_name: str = "resized_image"):
"""Create a configurable step for resizing images to the target area (1024 * 1024) while maintaining the aspect ratio.
This block resizes an input image tensor and exposes the resized result under configurable input and output
names. Use this when you need to wire the resize step to different image fields (e.g., "image",
"control_image")
def __init__(
self,
input_name: str = "image",
output_name: str = "resized_image",
target_area: int = 1024 * 1024,
):
"""Create a configurable step for resizing images to the target area while maintaining the aspect ratio.
Args:
input_name (str, optional): Name of the image field to read from the
pipeline state. Defaults to "image".
output_name (str, optional): Name of the resized image field to write
back to the pipeline state. Defaults to "resized_image".
target_area (int, optional): Target area in pixels. Defaults to 1024*1024.
"""
if not isinstance(input_name, str) or not isinstance(output_name, str):
raise ValueError(
@@ -263,11 +264,12 @@ class QwenImageEditResizeDynamicStep(ModularPipelineBlocks):
)
self._image_input_name = input_name
self._resized_image_output_name = output_name
self._target_area = target_area
super().__init__()
@property
def description(self) -> str:
return f"Image Resize step that resize the {self._image_input_name} to the target area (1024 * 1024) while maintaining the aspect ratio."
return f"Image Resize step that resize the {self._image_input_name} to the target area {self._target_area} while maintaining the aspect ratio."
@property
def expected_components(self) -> List[ComponentSpec]:
@@ -320,48 +322,67 @@ class QwenImageEditResizeDynamicStep(ModularPipelineBlocks):
self.set_block_state(state, block_state)
return components, state
class QwenImageEditPlusResizeDynamicStep(ModularPipelineBlocks):
"""Resize each image independently based on its own aspect ratio. For QwenImage Edit Plus."""
class QwenImageEditPlusResizeDynamicStep(QwenImageEditResizeDynamicStep):
model_name = "qwenimage"
model_name = "qwenimage-edit-plus"
def __init__(
self,
input_name: str = "image",
self,
input_name: str = "image",
output_name: str = "resized_image",
vae_image_output_name: str = "vae_image",
target_area: int = 1024 * 1024,
):
"""Create a configurable step for resizing images to the target area (384 * 384) while maintaining the aspect ratio.
"""Create a step for resizing images to a target area.
This block resizes an input image or a list input images and exposes the resized result under configurable
input and output names. Use this when you need to wire the resize step to different image fields (e.g.,
"image", "control_image")
Each image is resized independently based on its own aspect ratio.
This is suitable for Edit Plus where multiple reference images can have different dimensions.
Args:
input_name (str, optional): Name of the image field to read from the
pipeline state. Defaults to "image".
output_name (str, optional): Name of the resized image field to write
back to the pipeline state. Defaults to "resized_image".
vae_image_output_name (str, optional): Name of the image field
to write back to the pipeline state. This is used by the VAE encoder step later on. QwenImage Edit Plus
processes the input image(s) differently for the VL and the VAE.
input_name (str, optional): Name of the image field to read. Defaults to "image".
output_name (str, optional): Name of the resized image field to write. Defaults to "resized_image".
target_area (int, optional): Target area in pixels. Defaults to 1024*1024.
"""
if not isinstance(input_name, str) or not isinstance(output_name, str):
raise ValueError(
f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}"
)
self.condition_image_size = 384 * 384
self._image_input_name = input_name
self._resized_image_output_name = output_name
self._vae_image_output_name = vae_image_output_name
self._target_area = target_area
super().__init__()
@property
def description(self) -> str:
return (
f"Image Resize step that resizes {self._image_input_name} to target area {self._target_area}.\n"
"Each image is resized independently based on its own aspect ratio."
)
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"image_resize_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 16}),
default_creation_method="from_config",
),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
name=self._image_input_name, required=True, type_hint=torch.Tensor, description="The image(s) to resize"
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return super().intermediate_outputs + [
return [
OutputParam(
name=self._vae_image_output_name,
type_hint=List[PIL.Image.Image],
description="The images to be processed which will be further used by the VAE encoder.",
name=self._resized_image_output_name, type_hint=List[PIL.Image.Image], description="The resized images"
),
]
@@ -374,26 +395,21 @@ class QwenImageEditPlusResizeDynamicStep(QwenImageEditResizeDynamicStep):
if not is_valid_image_imagelist(images):
raise ValueError(f"Images must be image or list of images but are {type(images)}")
if (
not isinstance(images, torch.Tensor)
and isinstance(images, PIL.Image.Image)
and not isinstance(images, list)
):
if is_valid_image(images):
images = [images]
# TODO (sayakpaul): revisit this when the inputs are `torch.Tensor`s
condition_images = []
vae_images = []
for img in images:
image_width, image_height = img.size
condition_width, condition_height, _ = calculate_dimensions(
self.condition_image_size, image_width / image_height
# Resize each image independently based on its own aspect ratio
resized_images = []
for image in images:
image_width, image_height = image.size
calculated_width, calculated_height, _ = calculate_dimensions(
self._target_area, image_width / image_height
)
resized_images.append(
components.image_resize_processor.resize(image, height=calculated_height, width=calculated_width)
)
condition_images.append(components.image_resize_processor.resize(img, condition_height, condition_width))
vae_images.append(img)
setattr(block_state, self._resized_image_output_name, condition_images)
setattr(block_state, self._vae_image_output_name, vae_images)
setattr(block_state, self._resized_image_output_name, resized_images)
self.set_block_state(state, block_state)
return components, state
@@ -647,8 +663,30 @@ class QwenImageEditTextEncoderStep(ModularPipelineBlocks):
return components, state
class QwenImageEditPlusTextEncoderStep(QwenImageEditTextEncoderStep):
model_name = "qwenimage"
class QwenImageEditPlusTextEncoderStep(ModularPipelineBlocks):
"""Text encoder for QwenImage Edit Plus that handles multiple reference images."""
model_name = "qwenimage-edit-plus"
@property
def description(self) -> str:
return (
"Text Encoder step for QwenImage Edit Plus that processes prompt and multiple images together "
"to generate text embeddings for guiding image generation."
)
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration),
ComponentSpec("processor", Qwen2VLProcessor),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 4.0}),
default_creation_method="from_config",
),
]
@property
def expected_configs(self) -> List[ConfigSpec]:
@@ -664,6 +702,60 @@ class QwenImageEditPlusTextEncoderStep(QwenImageEditTextEncoderStep):
ConfigSpec(name="prompt_template_encode_start_idx", default=64),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"),
InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"),
InputParam(
name="resized_cond_image",
required=True,
type_hint=torch.Tensor,
description="The image(s) to encode, can be a single image or list of images, should be resized to 384x384 using resize step",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name="prompt_embeds",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="The prompt embeddings",
),
OutputParam(
name="prompt_embeds_mask",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="The encoder attention mask",
),
OutputParam(
name="negative_prompt_embeds",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="The negative prompt embeddings",
),
OutputParam(
name="negative_prompt_embeds_mask",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="The negative prompt embeddings mask",
),
]
@staticmethod
def check_inputs(prompt, negative_prompt):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if (
negative_prompt is not None
and not isinstance(negative_prompt, str)
and not isinstance(negative_prompt, list)
):
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
block_state = self.get_block_state(state)
@@ -676,7 +768,7 @@ class QwenImageEditPlusTextEncoderStep(QwenImageEditTextEncoderStep):
components.text_encoder,
components.processor,
prompt=block_state.prompt,
image=block_state.resized_image,
image=block_state.resized_cond_image,
prompt_template_encode=components.config.prompt_template_encode,
img_template_encode=components.config.img_template_encode,
prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
@@ -692,7 +784,7 @@ class QwenImageEditPlusTextEncoderStep(QwenImageEditTextEncoderStep):
components.text_encoder,
components.processor,
prompt=negative_prompt,
image=block_state.resized_image,
image=block_state.resized_cond_image,
prompt_template_encode=components.config.prompt_template_encode,
img_template_encode=components.config.img_template_encode,
prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
@@ -846,60 +938,60 @@ class QwenImageProcessImagesInputStep(ModularPipelineBlocks):
self.set_block_state(state, block_state)
return components, state
class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep):
class QwenImageEditPlusProcessImagesInputStep(ModularPipelineBlocks):
model_name = "qwenimage-edit-plus"
def __init__(self):
self.vae_image_size = 1024 * 1024
super().__init__()
@property
def description(self) -> str:
return "Image Preprocess step for QwenImage Edit Plus. Unlike QwenImage Edit, QwenImage Edit Plus doesn't use the same resized image for further preprocessing."
return "Image Preprocess step. Images can be resized first using QwenImageEditResizeDynamicStep."
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 16}),
default_creation_method="from_config",
),
]
@property
def inputs(self) -> List[InputParam]:
return [InputParam("vae_image"), InputParam("image"), InputParam("height"), InputParam("width")]
return [InputParam("resized_image")]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [OutputParam(name="processed_image")]
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
block_state = self.get_block_state(state)
if block_state.vae_image is None and block_state.image is None:
raise ValueError("`vae_image` and `image` cannot be None at the same time")
vae_image_sizes = None
if block_state.vae_image is None:
image = block_state.image
self.check_inputs(
height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
)
height = block_state.height or components.default_height
width = block_state.width or components.default_width
block_state.processed_image = components.image_processor.preprocess(
image=image, height=height, width=width
)
else:
# QwenImage Edit Plus can allow multiple input images with varied resolutions
processed_images = []
vae_image_sizes = []
for img in block_state.vae_image:
width, height = img.size
vae_width, vae_height, _ = calculate_dimensions(self.vae_image_size, width / height)
vae_image_sizes.append((vae_width, vae_height))
processed_images.append(
components.image_processor.preprocess(image=img, height=vae_height, width=vae_width)
)
image = block_state.resized_image
is_image_list = isinstance(image, list)
if not is_image_list:
image = [image]
processed_images = []
for img in image:
img_width, img_height = img.size
processed_images.append(components.image_processor.preprocess(image=img, height=img_height, width=img_width))
block_state.processed_image = processed_images
if is_image_list:
block_state.processed_image = processed_images
block_state.vae_image_sizes = vae_image_sizes
else:
block_state.processed_image = processed_images[0]
self.set_block_state(state, block_state)
return components, state
class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
"""VAE encoder that handles both single images and lists of images with varied resolutions."""
model_name = "qwenimage"
def __init__(
@@ -909,21 +1001,12 @@ class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
):
"""Initialize a VAE encoder step for converting images to latent representations.
Both the input and output names are configurable so this block can be configured to process to different image
inputs (e.g., "processed_image" -> "image_latents", "processed_control_image" -> "control_image_latents").
Handles both single images and lists of images. When input is a list, outputs a list of latents.
When input is a single tensor, outputs a single latent tensor.
Args:
input_name (str, optional): Name of the input image tensor. Defaults to "processed_image".
Examples: "processed_image" or "processed_control_image"
output_name (str, optional): Name of the output latent tensor. Defaults to "image_latents".
Examples: "image_latents" or "control_image_latents"
Examples:
# Basic usage with default settings (includes image processor) QwenImageVaeEncoderDynamicStep()
# Custom input/output names for control image QwenImageVaeEncoderDynamicStep(
input_name="processed_control_image", output_name="control_image_latents"
)
input_name (str, optional): Name of the input image tensor or list. Defaults to "processed_image".
output_name (str, optional): Name of the output latent tensor or list. Defaults to "image_latents".
"""
self._image_input_name = input_name
self._image_latents_output_name = output_name
@@ -931,17 +1014,18 @@ class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
@property
def description(self) -> str:
return f"Dynamic VAE Encoder step that converts {self._image_input_name} into latent representations {self._image_latents_output_name}.\n"
return (
f"VAE Encoder step that converts {self._image_input_name} into latent representations {self._image_latents_output_name}.\n"
"Handles both single images and lists of images with varied resolutions."
)
@property
def expected_components(self) -> List[ComponentSpec]:
components = [ComponentSpec("vae", AutoencoderKLQwenImage)]
return components
return [ComponentSpec("vae", AutoencoderKLQwenImage)]
@property
def inputs(self) -> List[InputParam]:
inputs = [InputParam(self._image_input_name, required=True), InputParam("generator")]
return inputs
return [InputParam(self._image_input_name, required=True), InputParam("generator")]
@property
def intermediate_outputs(self) -> List[OutputParam]:
@@ -949,7 +1033,7 @@ class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
OutputParam(
self._image_latents_output_name,
type_hint=torch.Tensor,
description="The latents representing the reference image",
description="The latents representing the reference image(s). Single tensor or list depending on input.",
)
]
@@ -961,47 +1045,11 @@ class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
dtype = components.vae.dtype
image = getattr(block_state, self._image_input_name)
is_image_list = isinstance(image, list)
if not is_image_list:
image = [image]
# Encode image into latents
image_latents = encode_vae_image(
image=image,
vae=components.vae,
generator=block_state.generator,
device=device,
dtype=dtype,
latent_channels=components.num_channels_latents,
)
setattr(block_state, self._image_latents_output_name, image_latents)
self.set_block_state(state, block_state)
return components, state
class QwenImageEditPlusVaeEncoderDynamicStep(QwenImageVaeEncoderDynamicStep):
model_name = "qwenimage-edit-plus"
@property
def intermediate_outputs(self) -> List[OutputParam]:
# Each reference image latent can have varied resolutions hence we return this as a list.
return [
OutputParam(
self._image_latents_output_name,
type_hint=List[torch.Tensor],
description="The latents representing the reference image(s).",
)
]
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
device = components._execution_device
dtype = components.vae.dtype
image = getattr(block_state, self._image_input_name)
# Encode image into latents
# Handle both single image and list of images
image_latents = []
for img in image:
image_latents.append(
@@ -1014,9 +1062,12 @@ class QwenImageEditPlusVaeEncoderDynamicStep(QwenImageVaeEncoderDynamicStep):
latent_channels=components.num_channels_latents,
)
)
if not is_image_list:
image_latents = image_latents[0]
setattr(block_state, self._image_latents_output_name, image_latents)
self.set_block_state(state, block_state)
return components, state

View File

@@ -222,36 +222,15 @@ class QwenImageTextInputsStep(ModularPipelineBlocks):
class QwenImageInputsDynamicStep(ModularPipelineBlocks):
"""Input step for QwenImage: update height/width, expand batch, patchify."""
model_name = "qwenimage"
def __init__(self, image_latent_inputs: List[str] = ["image_latents"], additional_batch_inputs: List[str] = []):
"""Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
This step handles multiple common tasks to prepare inputs for the denoising step:
1. For encoded image latents, use it update height/width if None, patchifies, and expands batch size
2. For additional_batch_inputs: Only expands batch dimensions to match final batch size
This is a dynamic block that allows you to configure which inputs to process.
Args:
image_latent_inputs (List[str], optional): Names of image latent tensors to process.
These will be used to determine height/width, patchified, and batch-expanded. Can be a single string or
list of strings. Defaults to ["image_latents"]. Examples: ["image_latents"], ["control_image_latents"]
additional_batch_inputs (List[str], optional):
Names of additional conditional input tensors to expand batch size. These tensors will only have their
batch dimensions adjusted to match the final batch size. Can be a single string or list of strings.
Defaults to []. Examples: ["processed_mask_image"]
Examples:
# Configure to process image_latents (default behavior) QwenImageInputsDynamicStep()
# Configure to process multiple image latent inputs
QwenImageInputsDynamicStep(image_latent_inputs=["image_latents", "control_image_latents"])
# Configure to process image latents and additional batch inputs QwenImageInputsDynamicStep(
image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
)
"""
def __init__(
self,
image_latent_inputs: List[str] = ["image_latents"],
additional_batch_inputs: List[str] = [],
):
if not isinstance(image_latent_inputs, list):
image_latent_inputs = [image_latent_inputs]
if not isinstance(additional_batch_inputs, list):
@@ -263,14 +242,12 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
@property
def description(self) -> str:
# Functionality section
summary_section = (
"Input processing step that:\n"
" 1. For image latent inputs: Updates height/width if None, patchifies latents, and expands batch size\n"
" 1. For image latent inputs: Updates height/width if None, patchifies, and expands batch size\n"
" 2. For additional batch inputs: Expands batch dimensions to match final batch size"
)
# Inputs info
inputs_info = ""
if self._image_latent_inputs or self._additional_batch_inputs:
inputs_info = "\n\nConfigured inputs:"
@@ -279,11 +256,16 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
if self._additional_batch_inputs:
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
# Placement guidance
placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
return summary_section + inputs_info + placement_section
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
]
@property
def inputs(self) -> List[InputParam]:
inputs = [
@@ -293,11 +275,9 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
InputParam(name="width"),
]
# Add image latent inputs
for image_latent_input_name in self._image_latent_inputs:
inputs.append(InputParam(name=image_latent_input_name))
# Add additional batch inputs
for input_name in self._additional_batch_inputs:
inputs.append(InputParam(name=input_name))
@@ -310,22 +290,16 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
OutputParam(name="image_width", type_hint=int, description="The width of the image latents"),
]
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
]
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
# Process image latent inputs (height/width calculation, patchify, and batch expansion)
# Process image latent inputs
for image_latent_input_name in self._image_latent_inputs:
image_latent_tensor = getattr(block_state, image_latent_input_name)
if image_latent_tensor is None:
continue
# 1. Calculate height/width from latents
# 1. Calculate height/width from latents and update if not provided
height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor)
block_state.height = block_state.height or height
block_state.width = block_state.width or width
@@ -335,7 +309,7 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
if not hasattr(block_state, "image_width"):
block_state.image_width = width
# 2. Patchify the image latent tensor
# 2. Patchify
image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor)
# 3. Expand batch size
@@ -354,7 +328,6 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
if input_tensor is None:
continue
# Only expand batch size
input_tensor = repeat_tensor_to_batch_size(
input_name=input_name,
input_tensor=input_tensor,
@@ -368,63 +341,130 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
return components, state
class QwenImageEditPlusInputsDynamicStep(QwenImageInputsDynamicStep):
class QwenImageEditPlusInputsDynamicStep(ModularPipelineBlocks):
"""Input step for QwenImage Edit Plus: handles list of latents with different sizes."""
model_name = "qwenimage-edit-plus"
def __init__(
self,
image_latent_inputs: List[str] = ["image_latents"],
additional_batch_inputs: List[str] = [],
):
if not isinstance(image_latent_inputs, list):
image_latent_inputs = [image_latent_inputs]
if not isinstance(additional_batch_inputs, list):
additional_batch_inputs = [additional_batch_inputs]
self._image_latent_inputs = image_latent_inputs
self._additional_batch_inputs = additional_batch_inputs
super().__init__()
@property
def description(self) -> str:
summary_section = (
"Input processing step for Edit Plus that:\n"
" 1. For image latent inputs (list): Collects heights/widths, patchifies each, concatenates, expands batch\n"
" 2. For additional batch inputs: Expands batch dimensions to match final batch size\n"
" Height/width defaults to last image in the list."
)
inputs_info = ""
if self._image_latent_inputs or self._additional_batch_inputs:
inputs_info = "\n\nConfigured inputs:"
if self._image_latent_inputs:
inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
if self._additional_batch_inputs:
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
return summary_section + inputs_info + placement_section
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
]
@property
def inputs(self) -> List[InputParam]:
inputs = [
InputParam(name="num_images_per_prompt", default=1),
InputParam(name="batch_size", required=True),
InputParam(name="height"),
InputParam(name="width"),
]
for image_latent_input_name in self._image_latent_inputs:
inputs.append(InputParam(name=image_latent_input_name))
for input_name in self._additional_batch_inputs:
inputs.append(InputParam(name=input_name))
return inputs
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(name="image_height", type_hint=List[int], description="The height of the image latents"),
OutputParam(name="image_width", type_hint=List[int], description="The width of the image latents"),
OutputParam(name="image_height", type_hint=List[int], description="The heights of the image latents"),
OutputParam(name="image_width", type_hint=List[int], description="The widths of the image latents"),
]
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
# Process image latent inputs (height/width calculation, patchify, and batch expansion)
# Process image latent inputs
for image_latent_input_name in self._image_latent_inputs:
image_latent_tensor = getattr(block_state, image_latent_input_name)
if image_latent_tensor is None:
continue
# Each image latent can have different size in QwenImage Edit Plus.
is_list = isinstance(image_latent_tensor, list)
if not is_list:
image_latent_tensor = [image_latent_tensor]
image_heights = []
image_widths = []
packed_image_latent_tensors = []
for img_latent_tensor in image_latent_tensor:
for i, img_latent_tensor in enumerate(image_latent_tensor):
# 1. Calculate height/width from latents
height, width = calculate_dimension_from_latents(img_latent_tensor, components.vae_scale_factor)
image_heights.append(height)
image_widths.append(width)
# 2. Patchify the image latent tensor
# 2. Patchify
img_latent_tensor = components.pachifier.pack_latents(img_latent_tensor)
# 3. Expand batch size
img_latent_tensor = repeat_tensor_to_batch_size(
input_name=image_latent_input_name,
input_name=f"{image_latent_input_name}[{i}]",
input_tensor=img_latent_tensor,
num_images_per_prompt=block_state.num_images_per_prompt,
batch_size=block_state.batch_size,
)
packed_image_latent_tensors.append(img_latent_tensor)
# Concatenate all packed latents along dim=1
packed_image_latent_tensors = torch.cat(packed_image_latent_tensors, dim=1)
# Output lists of heights/widths
block_state.image_height = image_heights
block_state.image_width = image_widths
setattr(block_state, image_latent_input_name, packed_image_latent_tensors)
# Default height/width from last image
block_state.height = block_state.height or image_heights[-1]
block_state.width = block_state.width or image_widths[-1]
setattr(block_state, image_latent_input_name, packed_image_latent_tensors)
# Process additional batch inputs (only batch expansion)
for input_name in self._additional_batch_inputs:
input_tensor = getattr(block_state, input_name)
if input_tensor is None:
continue
# Only expand batch size
input_tensor = repeat_tensor_to_batch_size(
input_name=input_name,
input_tensor=input_tensor,
@@ -436,8 +476,6 @@ class QwenImageEditPlusInputsDynamicStep(QwenImageInputsDynamicStep):
self.set_block_state(state, block_state)
return components, state
class QwenImageControlNetInputsStep(ModularPipelineBlocks):
model_name = "qwenimage"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,465 @@
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks, ConditionalPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
QwenImageControlNetBeforeDenoiserStep,
QwenImageCreateMaskLatentsStep,
QwenImagePrepareLatentsStep,
QwenImagePrepareLatentsWithStrengthStep,
QwenImageRoPEInputsStep,
QwenImageSetTimestepsStep,
QwenImageSetTimestepsWithStrengthStep,
)
from .decoders import (
QwenImageAfterDenoiseStep,
QwenImageDecoderStep,
QwenImageInpaintProcessImagesOutputStep,
QwenImageProcessImagesOutputStep,
)
from .denoise import (
QwenImageControlNetDenoiseStep,
QwenImageDenoiseStep,
QwenImageInpaintControlNetDenoiseStep,
QwenImageInpaintDenoiseStep,
QwenImageLoopBeforeDenoiserControlNet,
)
from .encoders import (
QwenImageControlNetVaeEncoderStep,
QwenImageInpaintProcessImagesInputStep,
QwenImageProcessImagesInputStep,
QwenImageTextEncoderStep,
QwenImageVaeEncoderDynamicStep,
)
from .inputs import (
QwenImageControlNetInputsStep,
QwenImageInputsDynamicStep,
QwenImageTextInputsStep,
)
logger = logging.get_logger(__name__)
# 1. VAE ENCODER
# inpaint vae encoder
class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = [QwenImageInpaintProcessImagesInputStep(), QwenImageVaeEncoderDynamicStep()]
block_names = ["preprocess", "encode"]
@property
def description(self) -> str:
return (
"This step is used for processing image and mask inputs for inpainting tasks. It:\n"
" - Resizes the image to the target size, based on `height` and `width`.\n"
" - Processes and updates `image` and `mask_image`.\n"
" - Creates `image_latents`."
)
# img2img vae encoder
class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = [QwenImageProcessImagesInputStep(), QwenImageVaeEncoderDynamicStep()]
block_names = ["preprocess", "encode"]
@property
def description(self) -> str:
return "Vae encoder step that preprocess andencode the image inputs into their latent representations."
# auto vae encoder
class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks):
block_classes = [QwenImageInpaintVaeEncoderStep, QwenImageImg2ImgVaeEncoderStep]
block_names = ["inpaint", "img2img"]
block_trigger_inputs = ["mask_image", "image"]
@property
def description(self):
return (
"Vae encoder step that encode the image inputs into their latent representations.\n"
+ "This is an auto pipeline block.\n"
+ " - `QwenImageInpaintVaeEncoderStep` (inpaint) is used when `mask_image` is provided.\n"
+ " - `QwenImageImg2ImgVaeEncoderStep` (img2img) is used when `image` is provided.\n"
+ " - if `mask_image` or `image` is not provided, step will be skipped."
)
# optional controlnet vae encoder
class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks):
block_classes = [QwenImageControlNetVaeEncoderStep]
block_names = ["controlnet"]
block_trigger_inputs = ["control_image"]
@property
def description(self):
return (
"Vae encoder step that encode the image inputs into their latent representations.\n"
+ "This is an auto pipeline block.\n"
+ " - `QwenImageControlNetVaeEncoderStep` (controlnet) is used when `control_image` is provided.\n"
+ " - if `control_image` is not provided, step will be skipped."
)
# 2. DENOISE
# input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise
# img2img input
class QwenImageImg2ImgInputStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = [QwenImageTextInputsStep(), QwenImageInputsDynamicStep(image_latent_inputs=["image_latents"])]
block_names = ["text_inputs", "additional_inputs"]
@property
def description(self):
return "Input step that prepares the inputs for the img2img denoising step. It:\n"
" - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n"
" - update height/width based `image_latents`, patchify `image_latents`."
# inpaint input
class QwenImageInpaintInputStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = [QwenImageTextInputsStep(), QwenImageInputsDynamicStep(image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"])]
block_names = ["text_inputs", "additional_inputs"]
@property
def description(self):
return "Input step that prepares the inputs for the inpainting denoising step. It:\n"
" - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents` and `processed_mask_image`).\n"
" - update height/width based `image_latents`, patchify `image_latents`."
# inpaint prepare latents
class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()]
block_names = ["add_noise_to_latents", "create_mask_latents"]
@property
def description(self) -> str:
return (
"This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. It:\n"
" - Add noise to the image latents to create the latents input for the denoiser.\n"
" - Create the pachified latents `mask` based on the processedmask image.\n"
)
# CoreDenoiseStep:
# (input + prepare_latents + set_timesteps + prepare_rope_inputs + denoise + after_denoise)
# 1. text2image
class QwenImageCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = [
QwenImageTextInputsStep(),
QwenImagePrepareLatentsStep(),
QwenImageSetTimestepsStep(),
QwenImageRoPEInputsStep(),
QwenImageDenoiseStep(),
QwenImageAfterDenoiseStep(),
]
block_names = [
"input",
"prepare_latents",
"set_timesteps",
"prepare_rope_inputs",
"denoise",
"after_denoise",
]
@property
def description(self):
return "step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.)."
# 2.inpaint
class QwenImageInpaintCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = [
QwenImageInpaintInputStep(),
QwenImagePrepareLatentsStep(),
QwenImageSetTimestepsWithStrengthStep(),
QwenImageInpaintPrepareLatentsStep(),
QwenImageRoPEInputsStep(),
QwenImageInpaintDenoiseStep(),
QwenImageAfterDenoiseStep(),
]
block_names = [
"input",
"prepare_latents",
"set_timesteps",
"prepare_inpaint_latents",
"prepare_rope_inputs",
"denoise",
"after_denoise",
]
@property
def description(self):
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task."
# 3. img2img
class QwenImageImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = [
QwenImageImg2ImgInputStep(),
QwenImagePrepareLatentsStep(),
QwenImageSetTimestepsWithStrengthStep(),
QwenImagePrepareLatentsWithStrengthStep(),
QwenImageRoPEInputsStep(),
QwenImageDenoiseStep(),
QwenImageAfterDenoiseStep(),
]
block_names = [
"input",
"prepare_latents",
"set_timesteps",
"prepare_img2img_latents",
"prepare_rope_inputs",
"denoise",
"after_denoise",
]
@property
def description(self):
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task."
# 4. text2image + controlnet
class QwenImageControlNetCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = [
QwenImageTextInputsStep(),
QwenImageControlNetInputsStep(),
QwenImagePrepareLatentsStep(),
QwenImageSetTimestepsStep(),
QwenImageRoPEInputsStep(),
QwenImageControlNetBeforeDenoiserStep(),
QwenImageControlNetDenoiseStep(),
QwenImageAfterDenoiseStep(),
]
block_names = [
"input",
"controlnet_input",
"prepare_latents",
"set_timesteps",
"prepare_rope_inputs",
"controlnet_before_denoise",
"controlnet_denoise",
"after_denoise",
]
@property
def description(self):
return "step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.)."
# 5. inpaint + controlnet
class QwenImageControlNetInpaintCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = [
QwenImageInpaintInputStep(),
QwenImageControlNetInputsStep(),
QwenImagePrepareLatentsStep(),
QwenImageSetTimestepsWithStrengthStep(),
QwenImageInpaintPrepareLatentsStep(),
QwenImageRoPEInputsStep(),
QwenImageControlNetBeforeDenoiserStep(),
QwenImageInpaintControlNetDenoiseStep(),
QwenImageAfterDenoiseStep(),
]
block_names = [
"input",
"controlnet_input",
"prepare_latents",
"set_timesteps",
"prepare_inpaint_latents",
"prepare_rope_inputs",
"controlnet_before_denoise",
"controlnet_denoise",
"after_denoise",
]
@property
def description(self):
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task."
# 6. img2img + controlnet
class QwenImageControlNetImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = [
QwenImageImg2ImgInputStep(),
QwenImageControlNetInputsStep(),
QwenImagePrepareLatentsStep(),
QwenImageSetTimestepsWithStrengthStep(),
QwenImagePrepareLatentsWithStrengthStep(),
QwenImageRoPEInputsStep(),
QwenImageControlNetBeforeDenoiserStep(),
QwenImageControlNetDenoiseStep(),
QwenImageAfterDenoiseStep(),
]
block_names = [
"input",
"controlnet_input",
"prepare_latents",
"set_timesteps",
"prepare_img2img_latents",
"prepare_rope_inputs",
"controlnet_before_denoise",
"controlnet_denoise",
"after_denoise",
]
@property
def description(self):
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task."
# auto denoise
# auto denoise step for controlnet tasks: works for all tasks with controlnet
class QwenImageAutoCoreDenoiseStep(ConditionalPipelineBlocks):
block_classes = [
QwenImageCoreDenoiseStep,
QwenImageInpaintCoreDenoiseStep,
QwenImageImg2ImgCoreDenoiseStep,
QwenImageControlNetCoreDenoiseStep,
QwenImageControlNetInpaintCoreDenoiseStep,
QwenImageControlNetImg2ImgCoreDenoiseStep,
]
block_names = [
"text2image",
"inpaint",
"img2img",
"controlnet_text2image",
"controlnet_inpaint",
"controlnet_img2img"]
block_trigger_inputs = ["control_image_latents", "processed_mask_image", "image_latents"]
default_block_name = "text2image"
def select_block(self, control_image_latents=None, processed_mask_image=None, image_latents=None):
if control_image_latents is not None:
if processed_mask_image is not None:
return "controlnet_inpaint"
elif image_latents is not None:
return "controlnet_img2img"
else:
return "controlnet_text2image"
else:
if processed_mask_image is not None:
return "inpaint"
elif image_latents is not None:
return "img2img"
else:
return "text2image"
@property
def description(self):
return (
"Core step that performs the denoising process. \n"
+ " - `QwenImageCoreDenoiseStep` (text2image) for text2image tasks.\n"
+ " - `QwenImageInpaintCoreDenoiseStep` (inpaint) for inpaint tasks.\n"
+ " - `QwenImageImg2ImgCoreDenoiseStep` (img2img) for img2img tasks.\n"
+ " - `QwenImageControlNetCoreDenoiseStep` (controlnet_text2image) for text2image tasks with controlnet.\n"
+ " - `QwenImageControlNetInpaintCoreDenoiseStep` (controlnet_inpaint) for inpaint tasks with controlnet.\n"
+ " - `QwenImageControlNetImg2ImgCoreDenoiseStep` (controlnet_img2img) for img2img tasks with controlnet.\n"
+ "This step support text-to-image, image-to-image, inpainting, and controlnet tasks for QwenImage:\n"
+ " - for image-to-image generation, you need to provide `image_latents`\n"
+ " - for inpainting, you need to provide `processed_mask_image` and `image_latents`\n"
+ " - to run the controlnet workflow, you need to provide `control_image_latents`\n"
+ " - for text-to-image generation, all you need to provide is prompt embeddings"
)
# 4. DECODE
## 1.1 text2image
#### decode
#### (standard decode step works for most tasks except for inpaint)
class QwenImageDecodeStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
block_names = ["decode", "postprocess"]
@property
def description(self):
return "Decode step that decodes the latents to images and postprocess the generated image."
#### inpaint decode
class QwenImageInpaintDecodeStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()]
block_names = ["decode", "postprocess"]
@property
def description(self):
return "Decode step that decodes the latents to images and postprocess the generated image, optional apply the mask overally to the original image."
# auto decode step for inpaint and text2image tasks
class QwenImageAutoDecodeStep(AutoPipelineBlocks):
block_classes = [QwenImageInpaintDecodeStep, QwenImageDecodeStep]
block_names = ["inpaint_decode", "decode"]
block_trigger_inputs = ["mask", None]
@property
def description(self):
return (
"Decode step that decode the latents into images. \n"
" This is an auto pipeline block that works for inpaint/text2image/img2img tasks, for both QwenImage and QwenImage-Edit.\n"
+ " - `QwenImageInpaintDecodeStep` (inpaint) is used when `mask` is provided.\n"
+ " - `QwenImageDecodeStep` (text2image/img2img) is used when `mask` is not provided.\n"
)
## 1.10 QwenImage/auto block & presets
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", QwenImageTextEncoderStep()),
("vae_encoder", QwenImageAutoVaeEncoderStep()),
("controlnet_vae_encoder", QwenImageOptionalControlNetVaeEncoderStep()),
("denoise", QwenImageAutoCoreDenoiseStep()),
("decode", QwenImageAutoDecodeStep()),
]
)
class QwenImageAutoBlocks(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = AUTO_BLOCKS.values()
block_names = AUTO_BLOCKS.keys()
@property
def description(self):
return (
"Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.\n"
+ "- for image-to-image generation, you need to provide `image`\n"
+ "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n"
+ "- to run the controlnet workflow, you need to provide `control_image`\n"
+ "- for text-to-image generation, all you need to provide is `prompt`"
)

View File

@@ -0,0 +1,329 @@
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
QwenImageCreateMaskLatentsStep,
QwenImageEditRoPEInputsStep,
QwenImagePrepareLatentsStep,
QwenImagePrepareLatentsWithStrengthStep,
QwenImageSetTimestepsStep,
QwenImageSetTimestepsWithStrengthStep,
)
from .decoders import (
QwenImageAfterDenoiseStep,
QwenImageDecoderStep,
QwenImageInpaintProcessImagesOutputStep,
QwenImageProcessImagesOutputStep,
)
from .denoise import (
QwenImageEditDenoiseStep,
QwenImageEditInpaintDenoiseStep,
)
from .encoders import (
QwenImageEditResizeDynamicStep,
QwenImageEditTextEncoderStep,
QwenImageInpaintProcessImagesInputStep,
QwenImageProcessImagesInputStep,
QwenImageVaeEncoderDynamicStep,
)
from .inputs import (
QwenImageInputsDynamicStep,
QwenImageTextInputsStep,
)
logger = logging.get_logger(__name__)
# ====================
# 1. TEXT ENCODER
# ====================
class QwenImageEditVLEncoderStep(SequentialPipelineBlocks):
"""VL encoder that takes both image and text prompts."""
model_name = "qwenimage-edit"
block_classes = [
QwenImageEditResizeDynamicStep(),
QwenImageEditTextEncoderStep(),
]
block_names = ["resize", "encode"]
@property
def description(self) -> str:
return "QwenImage-Edit VL encoder step that encode the image and text prompts together."
# ====================
# 2. VAE ENCODER
# ====================
# Edit VAE encoder
class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit"
block_classes = [
QwenImageEditResizeDynamicStep(),
QwenImageProcessImagesInputStep(),
QwenImageVaeEncoderDynamicStep(),
]
block_names = ["resize", "preprocess", "encode"]
@property
def description(self) -> str:
return "Vae encoder step that encode the image inputs into their latent representations."
# Edit Inpaint VAE encoder
class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit"
block_classes = [
QwenImageEditResizeDynamicStep(),
QwenImageInpaintProcessImagesInputStep(),
QwenImageVaeEncoderDynamicStep(input_name="processed_image", output_name="image_latents"),
]
block_names = ["resize", "preprocess", "encode"]
@property
def description(self) -> str:
return (
"This step is used for processing image and mask inputs for QwenImage-Edit inpaint tasks. It:\n"
" - resize the image for target area (1024 * 1024) while maintaining the aspect ratio.\n"
" - process the resized image and mask image.\n"
" - create image latents."
)
# Auto VAE encoder
class QwenImageEditAutoVaeEncoderStep(AutoPipelineBlocks):
block_classes = [QwenImageEditInpaintVaeEncoderStep, QwenImageEditVaeEncoderStep]
block_names = ["edit_inpaint", "edit"]
block_trigger_inputs = ["mask_image", "image"]
@property
def description(self):
return (
"Vae encoder step that encode the image inputs into their latent representations.\n"
"This is an auto pipeline block.\n"
" - `QwenImageEditInpaintVaeEncoderStep` (edit_inpaint) is used when `mask_image` is provided.\n"
" - `QwenImageEditVaeEncoderStep` (edit) is used when `image` is provided.\n"
" - if `mask_image` or `image` is not provided, step will be skipped."
)
# ====================
# 3. DENOISE - input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise
# ====================
# Edit input step
class QwenImageEditInputStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit"
block_classes = [
QwenImageTextInputsStep(),
QwenImageInputsDynamicStep(image_latent_inputs=["image_latents"]),
]
block_names = ["text_inputs", "additional_inputs"]
@property
def description(self):
return (
"Input step that prepares the inputs for the edit denoising step. It:\n"
" - make sure the text embeddings have consistent batch size as well as the additional inputs.\n"
" - update height/width based `image_latents`, patchify `image_latents`."
)
# Edit Inpaint input step
class QwenImageEditInpaintInputStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit"
block_classes = [
QwenImageTextInputsStep(),
QwenImageInputsDynamicStep(image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]),
]
block_names = ["text_inputs", "additional_inputs"]
@property
def description(self):
return (
"Input step that prepares the inputs for the edit inpaint denoising step. It:\n"
" - make sure the text embeddings have consistent batch size as well as the additional inputs.\n"
" - update height/width based `image_latents`, patchify `image_latents`."
)
# Edit Inpaint prepare latents step
class QwenImageEditInpaintPrepareLatentsStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit"
block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()]
block_names = ["add_noise_to_latents", "create_mask_latents"]
@property
def description(self) -> str:
return (
"This step prepares the latents/image_latents and mask inputs for the edit inpainting denoising step. It:\n"
" - Add noise to the image latents to create the latents input for the denoiser.\n"
" - Create the patchified latents `mask` based on the processed mask image.\n"
)
# 1. Edit (img2img) core denoise
class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit"
block_classes = [
QwenImageEditInputStep(),
QwenImagePrepareLatentsStep(),
QwenImageSetTimestepsStep(),
QwenImageEditRoPEInputsStep(),
QwenImageEditDenoiseStep(),
QwenImageAfterDenoiseStep(),
]
block_names = [
"input",
"prepare_latents",
"set_timesteps",
"prepare_rope_inputs",
"denoise",
"after_denoise",
]
@property
def description(self):
return "Core denoising workflow for QwenImage-Edit edit (img2img) task."
# 2. Edit Inpaint core denoise
class QwenImageEditInpaintCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit"
block_classes = [
QwenImageEditInpaintInputStep(),
QwenImagePrepareLatentsStep(),
QwenImageSetTimestepsWithStrengthStep(),
QwenImageEditInpaintPrepareLatentsStep(),
QwenImageEditRoPEInputsStep(),
QwenImageEditInpaintDenoiseStep(),
QwenImageAfterDenoiseStep(),
]
block_names = [
"input",
"prepare_latents",
"set_timesteps",
"prepare_inpaint_latents",
"prepare_rope_inputs",
"denoise",
"after_denoise",
]
@property
def description(self):
return "Core denoising workflow for QwenImage-Edit edit inpaint task."
# Auto core denoise step
class QwenImageEditAutoCoreDenoiseStep(ConditionalPipelineBlocks):
block_classes = [
QwenImageEditInpaintCoreDenoiseStep,
QwenImageEditCoreDenoiseStep,
]
block_names = ["edit_inpaint", "edit"]
block_trigger_inputs = ["processed_mask_image", "image_latents"]
default_block_name = "edit"
def select_block(self, processed_mask_image=None, image_latents=None) -> Optional[str]:
if processed_mask_image is not None:
return "edit_inpaint"
elif image_latents is not None:
return "edit"
return None
@property
def description(self):
return (
"Auto core denoising step that selects the appropriate workflow based on inputs.\n"
" - `QwenImageEditInpaintCoreDenoiseStep` when `processed_mask_image` is provided\n"
" - `QwenImageEditCoreDenoiseStep` when `image_latents` is provided\n"
"Supports edit (img2img) and edit inpainting tasks for QwenImage-Edit."
)
# ====================
# 4. DECODE
# ====================
# Decode step (standard)
class QwenImageEditDecodeStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit"
block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
block_names = ["decode", "postprocess"]
@property
def description(self):
return "Decode step that decodes the latents to images and postprocess the generated image."
# Inpaint decode step
class QwenImageEditInpaintDecodeStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit"
block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()]
block_names = ["decode", "postprocess"]
@property
def description(self):
return "Decode step that decodes the latents to images and postprocess the generated image, optionally apply the mask overlay to the original image."
# Auto decode step
class QwenImageEditAutoDecodeStep(AutoPipelineBlocks):
block_classes = [QwenImageEditInpaintDecodeStep, QwenImageEditDecodeStep]
block_names = ["inpaint_decode", "decode"]
block_trigger_inputs = ["mask", None]
@property
def description(self):
return (
"Decode step that decode the latents into images.\n"
"This is an auto pipeline block.\n"
" - `QwenImageEditInpaintDecodeStep` (inpaint) is used when `mask` is provided.\n"
" - `QwenImageEditDecodeStep` (edit) is used when `mask` is not provided.\n"
)
# ====================
# 5. AUTO BLOCKS & PRESETS
# ====================
EDIT_AUTO_BLOCKS = InsertableDict(
[
("text_encoder", QwenImageEditVLEncoderStep()),
("vae_encoder", QwenImageEditAutoVaeEncoderStep()),
("denoise", QwenImageEditAutoCoreDenoiseStep()),
("decode", QwenImageEditAutoDecodeStep()),
]
)
class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
model_name = "qwenimage-edit"
block_classes = EDIT_AUTO_BLOCKS.values()
block_names = EDIT_AUTO_BLOCKS.keys()
@property
def description(self):
return (
"Auto Modular pipeline for edit (img2img) and edit inpaint tasks using QwenImage-Edit.\n"
"- for edit (img2img) generation, you need to provide `image`\n"
"- for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`\n"
)

View File

@@ -0,0 +1,175 @@
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
QwenImageEditPlusRoPEInputsStep,
QwenImagePrepareLatentsStep,
QwenImageSetTimestepsStep,
)
from .decoders import (
QwenImageAfterDenoiseStep,
QwenImageDecoderStep,
QwenImageProcessImagesOutputStep,
)
from .denoise import (
QwenImageEditDenoiseStep,
)
from .encoders import (
QwenImageEditPlusResizeDynamicStep,
QwenImageEditPlusTextEncoderStep,
QwenImageEditPlusProcessImagesInputStep,
QwenImageVaeEncoderDynamicStep,
)
from .inputs import (
QwenImageEditPlusInputsDynamicStep,
QwenImageTextInputsStep,
)
logger = logging.get_logger(__name__)
# ====================
# 1. TEXT ENCODER
# ====================
class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks):
"""VL encoder that takes both image and text prompts. Uses 384x384 target area."""
model_name = "qwenimage-edit-plus"
block_classes = [
QwenImageEditPlusResizeDynamicStep(target_area=384 * 384, output_name="resized_cond_image"),
QwenImageEditPlusTextEncoderStep(),
]
block_names = ["resize", "encode"]
@property
def description(self) -> str:
return "QwenImage-Edit Plus VL encoder step that encodes the image and text prompts together."
# ====================
# 2. VAE ENCODER
# ====================
class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):
"""VAE encoder that handles multiple images with different sizes. Uses 1024x1024 target area."""
model_name = "qwenimage-edit-plus"
block_classes = [
QwenImageEditPlusResizeDynamicStep(target_area=1024 * 1024, output_name="resized_image"),
QwenImageEditPlusProcessImagesInputStep(),
QwenImageVaeEncoderDynamicStep(),
]
block_names = ["resize", "preprocess", "encode"]
@property
def description(self) -> str:
return (
"VAE encoder step that encodes image inputs into latent representations.\n"
"Each image is resized independently based on its own aspect ratio to 1024x1024 target area."
)
# ====================
# 3. DENOISE - input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise
# ====================
# Edit Plus input step
class QwenImageEditPlusInputStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit-plus"
block_classes = [
QwenImageTextInputsStep(),
QwenImageEditPlusInputsDynamicStep(image_latent_inputs=["image_latents"]),
]
block_names = ["text_inputs", "additional_inputs"]
@property
def description(self):
return (
"Input step that prepares the inputs for the Edit Plus denoising step. It:\n"
" - Standardizes text embeddings batch size.\n"
" - Processes list of image latents: patchifies, concatenates along dim=1, expands batch.\n"
" - Outputs lists of image_height/image_width for RoPE calculation.\n"
" - Defaults height/width from last image in the list."
)
# Edit Plus core denoise
class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit-plus"
block_classes = [
QwenImageEditPlusInputStep(),
QwenImagePrepareLatentsStep(),
QwenImageSetTimestepsStep(),
QwenImageEditPlusRoPEInputsStep(),
QwenImageEditDenoiseStep(),
QwenImageAfterDenoiseStep(),
]
block_names = [
"input",
"prepare_latents",
"set_timesteps",
"prepare_rope_inputs",
"denoise",
"after_denoise",
]
@property
def description(self):
return "Core denoising workflow for QwenImage-Edit Plus edit (img2img) task."
# ====================
# 4. DECODE
# ====================
class QwenImageEditPlusDecodeStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit-plus"
block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
block_names = ["decode", "postprocess"]
@property
def description(self):
return "Decode step that decodes the latents to images and postprocesses the generated image."
# ====================
# 5. AUTO BLOCKS & PRESETS
# ====================
EDIT_PLUS_AUTO_BLOCKS = InsertableDict(
[
("text_encoder", QwenImageEditPlusVLEncoderStep()),
("vae_encoder", QwenImageEditPlusVaeEncoderStep()),
("denoise", QwenImageEditPlusCoreDenoiseStep()),
("decode", QwenImageEditPlusDecodeStep()),
]
)
class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks):
model_name = "qwenimage-edit-plus"
block_classes = EDIT_PLUS_AUTO_BLOCKS.values()
block_names = EDIT_PLUS_AUTO_BLOCKS.keys()
@property
def description(self):
return (
"Auto Modular pipeline for edit (img2img) tasks using QwenImage-Edit Plus.\n"
"- `image` is required input (can be single image or list of images).\n"
"- Each image is resized independently based on its own aspect ratio.\n"
"- VL encoder uses 384x384 target area, VAE encoder uses 1024x1024 target area."
)

View File

@@ -411,7 +411,6 @@ else:
"ZImagePipeline",
"ZImageControlNetPipeline",
"ZImageControlNetInpaintPipeline",
"ZImageOmniPipeline",
]
_import_structure["skyreels_v2"] = [
"SkyReelsV2DiffusionForcingPipeline",
@@ -857,7 +856,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ZImageControlNetInpaintPipeline,
ZImageControlNetPipeline,
ZImageImg2ImgPipeline,
ZImageOmniPipeline,
ZImagePipeline,
)

View File

@@ -73,7 +73,6 @@ from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
from .lumina import LuminaPipeline
from .lumina2 import Lumina2Pipeline
from .ovis_image import OvisImagePipeline
from .pag import (
HunyuanDiTPAGPipeline,
PixArtSigmaPAGPipeline,
@@ -120,13 +119,7 @@ from .stable_diffusion_xl import (
)
from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline
from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline
from .z_image import (
ZImageControlNetInpaintPipeline,
ZImageControlNetPipeline,
ZImageImg2ImgPipeline,
ZImageOmniPipeline,
ZImagePipeline,
)
from .z_image import ZImageImg2ImgPipeline, ZImagePipeline
AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
@@ -171,10 +164,6 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("qwenimage", QwenImagePipeline),
("qwenimage-controlnet", QwenImageControlNetPipeline),
("z-image", ZImagePipeline),
("z-image-controlnet", ZImageControlNetPipeline),
("z-image-controlnet-inpaint", ZImageControlNetInpaintPipeline),
("z-image-omni", ZImageOmniPipeline),
("ovis", OvisImagePipeline),
]
)

View File

@@ -17,7 +17,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, T5EncoderModel, T5Tokenizer
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
@@ -185,7 +185,7 @@ class HunyuanDiTControlNetPipeline(DiffusionPipeline):
The HunyuanDiT model designed by Tencent Hunyuan.
text_encoder_2 (`T5EncoderModel`):
The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
tokenizer_2 (`T5Tokenizer`):
tokenizer_2 (`MT5Tokenizer`):
The tokenizer for the mT5 embedder.
scheduler ([`DDPMScheduler`]):
A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
@@ -229,7 +229,7 @@ class HunyuanDiTControlNetPipeline(DiffusionPipeline):
HunyuanDiT2DMultiControlNetModel,
],
text_encoder_2: Optional[T5EncoderModel] = None,
tokenizer_2: Optional[T5Tokenizer] = None,
tokenizer_2: Optional[MT5Tokenizer] = None,
requires_safety_checker: bool = True,
):
super().__init__()

View File

@@ -133,7 +133,7 @@ EXAMPLE_DOC_STRING = """
... num_frames=93,
... generator=torch.Generator().manual_seed(1),
... ).frames[0]
>>> export_to_video(video, "image2world.mp4", fps=16)
>>> # export_to_video(video, "image2world.mp4", fps=16)
>>> # Video2World: condition on an input clip and predict a 93-frame world video.
>>> prompt = (

View File

@@ -17,7 +17,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, T5EncoderModel, T5Tokenizer
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
@@ -169,7 +169,7 @@ class HunyuanDiTPipeline(DiffusionPipeline):
The HunyuanDiT model designed by Tencent Hunyuan.
text_encoder_2 (`T5EncoderModel`):
The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
tokenizer_2 (`T5Tokenizer`):
tokenizer_2 (`MT5Tokenizer`):
The tokenizer for the mT5 embedder.
scheduler ([`DDPMScheduler`]):
A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
@@ -204,7 +204,7 @@ class HunyuanDiTPipeline(DiffusionPipeline):
feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
text_encoder_2: Optional[T5EncoderModel] = None,
tokenizer_2: Optional[T5Tokenizer] = None,
tokenizer_2: Optional[MT5Tokenizer] = None,
):
super().__init__()

View File

@@ -17,7 +17,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, T5EncoderModel, T5Tokenizer
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
@@ -173,7 +173,7 @@ class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin):
The HunyuanDiT model designed by Tencent Hunyuan.
text_encoder_2 (`T5EncoderModel`):
The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
tokenizer_2 (`T5Tokenizer`):
tokenizer_2 (`MT5Tokenizer`):
The tokenizer for the mT5 embedder.
scheduler ([`DDPMScheduler`]):
A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
@@ -208,7 +208,7 @@ class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin):
feature_extractor: Optional[CLIPImageProcessor] = None,
requires_safety_checker: bool = True,
text_encoder_2: Optional[T5EncoderModel] = None,
tokenizer_2: Optional[T5Tokenizer] = None,
tokenizer_2: Optional[MT5Tokenizer] = None,
pag_applied_layers: Union[str, List[str]] = "blocks.1", # "blocks.16.attn1", "blocks.16", "16", 16
):
super().__init__()

View File

@@ -26,7 +26,6 @@ else:
_import_structure["pipeline_z_image_controlnet"] = ["ZImageControlNetPipeline"]
_import_structure["pipeline_z_image_controlnet_inpaint"] = ["ZImageControlNetInpaintPipeline"]
_import_structure["pipeline_z_image_img2img"] = ["ZImageImg2ImgPipeline"]
_import_structure["pipeline_z_image_omni"] = ["ZImageOmniPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -42,7 +41,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_z_image_controlnet import ZImageControlNetPipeline
from .pipeline_z_image_controlnet_inpaint import ZImageControlNetInpaintPipeline
from .pipeline_z_image_img2img import ZImageImg2ImgPipeline
from .pipeline_z_image_omni import ZImageOmniPipeline
else:
import sys

View File

@@ -1,742 +0,0 @@
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import PIL
import torch
from transformers import AutoTokenizer, PreTrainedModel, Siglip2ImageProcessorFast, Siglip2VisionModel
from ...loaders import FromSingleFileMixin, ZImageLoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import ZImageTransformer2DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..flux2.image_processor import Flux2ImageProcessor
from .pipeline_output import ZImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import ZImageOmniPipeline
>>> pipe = ZImageOmniPipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch.
>>> # (1) Use flash attention 2
>>> # pipe.transformer.set_attention_backend("flash")
>>> # (2) Use flash attention 3
>>> # pipe.transformer.set_attention_backend("_flash_3")
>>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化一辆复古蒸汽小火车化身为巨大的拉链头正拉开厚厚的冬日积雪展露出一个生机盎然的春天。"
>>> image = pipe(
... prompt,
... height=1024,
... width=1024,
... num_inference_steps=9,
... guidance_scale=0.0,
... generator=torch.Generator("cuda").manual_seed(42),
... ).images[0]
>>> image.save("zimage.png")
```
"""
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class ZImageOmniPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin):
model_cpu_offload_seq = "text_encoder->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: PreTrainedModel,
tokenizer: AutoTokenizer,
transformer: ZImageTransformer2DModel,
siglip: Siglip2VisionModel,
siglip_processor: Siglip2ImageProcessorFast,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler,
transformer=transformer,
siglip=siglip,
siglip_processor=siglip_processor,
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
# self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
def encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 512,
num_condition_images: int = 0,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt_embeds = self._encode_prompt(
prompt=prompt,
device=device,
prompt_embeds=prompt_embeds,
max_sequence_length=max_sequence_length,
num_condition_images=num_condition_images,
)
if do_classifier_free_guidance:
if negative_prompt is None:
negative_prompt = ["" for _ in prompt]
else:
negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
assert len(prompt) == len(negative_prompt)
negative_prompt_embeds = self._encode_prompt(
prompt=negative_prompt,
device=device,
prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
num_condition_images=num_condition_images,
)
else:
negative_prompt_embeds = []
return prompt_embeds, negative_prompt_embeds
def _encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
max_sequence_length: int = 512,
num_condition_images: int = 0,
) -> List[torch.FloatTensor]:
device = device or self._execution_device
if prompt_embeds is not None:
return prompt_embeds
if isinstance(prompt, str):
prompt = [prompt]
for i, prompt_item in enumerate(prompt):
if num_condition_images == 0:
prompt[i] = ["<|im_start|>user\n" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n"]
elif num_condition_images > 0:
prompt_list = ["<|im_start|>user\n<|vision_start|>"]
prompt_list += ["<|vision_end|><|vision_start|>"] * (num_condition_images - 1)
prompt_list += ["<|vision_end|>" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"]
prompt_list += ["<|vision_end|><|im_end|>"]
prompt[i] = prompt_list
flattened_prompt = []
prompt_list_lengths = []
for i in range(len(prompt)):
prompt_list_lengths.append(len(prompt[i]))
flattened_prompt.extend(prompt[i])
text_inputs = self.tokenizer(
flattened_prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
prompt_masks = text_inputs.attention_mask.to(device).bool()
prompt_embeds = self.text_encoder(
input_ids=text_input_ids,
attention_mask=prompt_masks,
output_hidden_states=True,
).hidden_states[-2]
embeddings_list = []
start_idx = 0
for i in range(len(prompt_list_lengths)):
batch_embeddings = []
end_idx = start_idx + prompt_list_lengths[i]
for j in range(start_idx, end_idx):
batch_embeddings.append(prompt_embeds[j][prompt_masks[j]])
embeddings_list.append(batch_embeddings)
start_idx = end_idx
return embeddings_list
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
return latents
def prepare_image_latents(
self,
images: List[torch.Tensor],
batch_size,
device,
dtype,
):
image_latents = []
for image in images:
image = image.to(device=device, dtype=dtype)
image_latent = (
self.vae.encode(image.bfloat16()).latent_dist.mode()[0] - self.vae.config.shift_factor
) * self.vae.config.scaling_factor
image_latent = image_latent.unsqueeze(1).to(dtype)
image_latents.append(image_latent) # (16, 128, 128)
# image_latents = [image_latents] * batch_size
image_latents = [image_latents.copy() for _ in range(batch_size)]
return image_latents
def prepare_siglip_embeds(
self,
images: List[torch.Tensor],
batch_size,
device,
dtype,
):
siglip_embeds = []
for image in images:
siglip_inputs = self.siglip_processor(images=[image], return_tensors="pt").to(device)
shape = siglip_inputs.spatial_shapes[0]
hidden_state = self.siglip(**siglip_inputs).last_hidden_state
B, N, C = hidden_state.shape
hidden_state = hidden_state[:, : shape[0] * shape[1]]
hidden_state = hidden_state.view(shape[0], shape[1], C)
siglip_embeds.append(hidden_state.to(dtype))
# siglip_embeds = [siglip_embeds] * batch_size
siglip_embeds = [siglip_embeds.copy() for _ in range(batch_size)]
return siglip_embeds
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 5.0,
cfg_normalization: bool = False,
cfg_truncation: float = 1.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
latents as `image`, but if passing latents directly it is not encoded again.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
height (`int`, *optional*, defaults to 1024):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 1024):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 5.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
cfg_normalization (`bool`, *optional*, defaults to False):
Whether to apply configuration normalization.
cfg_truncation (`float`, *optional*, defaults to 1.0):
The truncation value for configuration.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`List[torch.FloatTensor]`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain
tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, *optional*, defaults to 512):
Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated images.
"""
if image is not None and not isinstance(image, list):
image = [image]
num_condition_images = len(image) if image is not None else 0
device = self._execution_device
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
self._cfg_normalization = cfg_normalization
self._cfg_truncation = cfg_truncation
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = len(prompt_embeds)
# If prompt_embeds is provided and prompt is None, skip encoding
if prompt_embeds is not None and prompt is None:
if self.do_classifier_free_guidance and negative_prompt_embeds is None:
raise ValueError(
"When `prompt_embeds` is provided without `prompt`, "
"`negative_prompt_embeds` must also be provided for classifier-free guidance."
)
else:
(
prompt_embeds,
negative_prompt_embeds,
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
device=device,
max_sequence_length=max_sequence_length,
num_condition_images=num_condition_images,
)
# 3. Process condition images. Copied from diffusers.pipelines.flux2.pipeline_flux2
condition_images = []
resized_images = []
if image is not None:
for img in image:
self.image_processor.check_image_input(img)
for img in image:
image_width, image_height = img.size
if image_width * image_height > 1024 * 1024:
if height is not None and width is not None:
img = self.image_processor._resize_to_target_area(img, height * width)
else:
img = self.image_processor._resize_to_target_area(img, 1024 * 1024)
image_width, image_height = img.size
resized_images.append(img)
multiple_of = self.vae_scale_factor * 2
image_width = (image_width // multiple_of) * multiple_of
image_height = (image_height // multiple_of) * multiple_of
img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop")
condition_images.append(img)
if len(condition_images) > 0:
height = height or image_height
width = width or image_width
else:
height = height or 1024
width = width or 1024
vae_scale = self.vae_scale_factor * 2
if height % vae_scale != 0:
raise ValueError(
f"Height must be divisible by {vae_scale} (got {height}). "
f"Please adjust the height to a multiple of {vae_scale}."
)
if width % vae_scale != 0:
raise ValueError(
f"Width must be divisible by {vae_scale} (got {width}). "
f"Please adjust the width to a multiple of {vae_scale}."
)
# 4. Prepare latent variables
num_channels_latents = self.transformer.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
torch.float32,
device,
generator,
latents,
)
condition_latents = self.prepare_image_latents(
images=condition_images,
batch_size=batch_size * num_images_per_prompt,
device=device,
dtype=torch.float32,
)
condition_latents = [[lat.to(self.transformer.dtype) for lat in lats] for lats in condition_latents]
if self.do_classifier_free_guidance:
negative_condition_latents = [[lat.clone() for lat in batch] for batch in condition_latents]
condition_siglip_embeds = self.prepare_siglip_embeds(
images=resized_images,
batch_size=batch_size * num_images_per_prompt,
device=device,
dtype=torch.float32,
)
condition_siglip_embeds = [[se.to(self.transformer.dtype) for se in sels] for sels in condition_siglip_embeds]
if self.do_classifier_free_guidance:
negative_condition_siglip_embeds = [[se.clone() for se in batch] for batch in condition_siglip_embeds]
# Repeat prompt_embeds for num_images_per_prompt
if num_images_per_prompt > 1:
prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
if self.do_classifier_free_guidance and negative_prompt_embeds:
negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
condition_siglip_embeds = [None if sels == [] else sels + [None] for sels in condition_siglip_embeds]
negative_condition_siglip_embeds = [
None if sels == [] else sels + [None] for sels in negative_condition_siglip_embeds
]
actual_batch_size = batch_size * num_images_per_prompt
image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
# 5. Prepare timesteps
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
self.scheduler.sigma_min = 0.0
scheduler_kwargs = {"mu": mu}
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
**scheduler_kwargs,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0])
timestep = (1000 - timestep) / 1000
# Normalized time for time-aware config (0 at start, 1 at end)
t_norm = timestep[0].item()
# Handle cfg truncation
current_guidance_scale = self.guidance_scale
if (
self.do_classifier_free_guidance
and self._cfg_truncation is not None
and float(self._cfg_truncation) <= 1
):
if t_norm > self._cfg_truncation:
current_guidance_scale = 0.0
# Run CFG only if configured AND scale is non-zero
apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
if apply_cfg:
latents_typed = latents.to(self.transformer.dtype)
latent_model_input = latents_typed.repeat(2, 1, 1, 1)
prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
condition_latents_model_input = condition_latents + negative_condition_latents
condition_siglip_embeds_model_input = condition_siglip_embeds + negative_condition_siglip_embeds
timestep_model_input = timestep.repeat(2)
else:
latent_model_input = latents.to(self.transformer.dtype)
prompt_embeds_model_input = prompt_embeds
condition_latents_model_input = condition_latents
condition_siglip_embeds_model_input = condition_siglip_embeds
timestep_model_input = timestep
latent_model_input = latent_model_input.unsqueeze(2)
latent_model_input_list = list(latent_model_input.unbind(dim=0))
# Combine condition latents with target latent
current_batch_size = len(latent_model_input_list)
x_combined = [
condition_latents_model_input[i] + [latent_model_input_list[i]] for i in range(current_batch_size)
]
# Create noise mask: 0 for condition images (clean), 1 for target image (noisy)
image_noise_mask = [
[0] * len(condition_latents_model_input[i]) + [1] for i in range(current_batch_size)
]
model_out_list = self.transformer(
x=x_combined,
t=timestep_model_input,
cap_feats=prompt_embeds_model_input,
siglip_feats=condition_siglip_embeds_model_input,
image_noise_mask=image_noise_mask,
return_dict=False,
)[0]
if apply_cfg:
# Perform CFG
pos_out = model_out_list[:actual_batch_size]
neg_out = model_out_list[actual_batch_size:]
noise_pred = []
for j in range(actual_batch_size):
pos = pos_out[j].float()
neg = neg_out[j].float()
pred = pos + current_guidance_scale * (pos - neg)
# Renormalization
if self._cfg_normalization and float(self._cfg_normalization) > 0.0:
ori_pos_norm = torch.linalg.vector_norm(pos)
new_pos_norm = torch.linalg.vector_norm(pred)
max_new_norm = ori_pos_norm * float(self._cfg_normalization)
if new_pos_norm > max_new_norm:
pred = pred * (max_new_norm / new_pos_norm)
noise_pred.append(pred)
noise_pred = torch.stack(noise_pred, dim=0)
else:
noise_pred = torch.stack([t.float() for t in model_out_list], dim=0)
noise_pred = noise_pred.squeeze(2)
noise_pred = -noise_pred
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
assert latents.dtype == torch.float32
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if output_type == "latent":
image = latents
else:
latents = latents.to(self.vae.dtype)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return ZImagePipelineOutput(images=image)

View File

@@ -36,9 +36,6 @@ from ...utils import (
from ..base import DiffusersQuantizer
logger = logging.get_logger(__name__)
if TYPE_CHECKING:
from ...models.modeling_utils import ModelMixin
@@ -86,19 +83,11 @@ def _update_torch_safe_globals():
]
try:
from torchao.dtypes import NF4Tensor
from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor
safe_globals.extend([UintxTensor, UintxAQTTensorImpl, NF4Tensor])
# note: is_torchao_version(">=", "0.16.0") does not work correctly
# with torchao nightly, so using a ">" check which does work correctly
if is_torchao_version(">", "0.15.0"):
pass
else:
from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
safe_globals.extend([UInt4Tensor, Float8AQTTensorImpl])
safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor])
except (ImportError, ModuleNotFoundError) as e:
logger.warning(
@@ -134,6 +123,9 @@ def fuzzy_match_size(config_name: str) -> Optional[str]:
return None
logger = logging.get_logger(__name__)
def _quantization_type(weight):
from torchao.dtypes import AffineQuantizedTensor
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor

View File

@@ -3917,21 +3917,6 @@ class ZImageImg2ImgPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class ZImageOmniPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class ZImagePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]