Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-24 21:34:55 +08:00)

Compare commits: modular-qw ... main (5 commits)

| Author | SHA1 | Date |
|---|---|---|
| | f6b6a7181e | |
| | 52766e6a69 | |
| | 973a077c6a | |
| | 0c4f6c9cff | |
| | 262ce19bff | |
@@ -21,8 +21,8 @@ from transformers import (
    BertModel,
    BertTokenizer,
    CLIPImageProcessor,
    MT5Tokenizer,
    T5EncoderModel,
    T5Tokenizer,
)

from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
@@ -260,7 +260,7 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
            The HunyuanDiT model designed by Tencent Hunyuan.
        text_encoder_2 (`T5EncoderModel`):
            The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
        tokenizer_2 (`MT5Tokenizer`):
        tokenizer_2 (`T5Tokenizer`):
            The tokenizer for the mT5 embedder.
        scheduler ([`DDPMScheduler`]):
            A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
@@ -295,7 +295,7 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
        text_encoder_2=T5EncoderModel,
        tokenizer_2=MT5Tokenizer,
        tokenizer_2=T5Tokenizer,
    ):
        super().__init__()

@@ -29,13 +29,52 @@ hf download nvidia/Cosmos-Predict2.5-2B

Convert checkpoint
```bash
# pre-trained
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/pre-trained/d20b7120-df3e-4911-919d-db6e08bad31c_ema_bf16.pt

python scripts/convert_cosmos_to_diffusers.py \
    --transformer_type Cosmos-2.5-Predict-Base-2B \
    --transformer_ckpt_path $transformer_ckpt_path \
    --vae_type wan2.1 \
    --output_path converted/cosmos-p2.5-base-2b \
    --output_path converted/2b/d20b7120-df3e-4911-919d-db6e08bad31c \
    --save_pipeline

# post-trained
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/post-trained/81edfebe-bd6a-4039-8c1d-737df1a790bf_ema_bf16.pt

python scripts/convert_cosmos_to_diffusers.py \
    --transformer_type Cosmos-2.5-Predict-Base-2B \
    --transformer_ckpt_path $transformer_ckpt_path \
    --vae_type wan2.1 \
    --output_path converted/2b/81edfebe-bd6a-4039-8c1d-737df1a790bf \
    --save_pipeline
```

## 14B

```bash
hf download nvidia/Cosmos-Predict2.5-14B
```

```bash
# pre-trained
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-14B/snapshots/71ebf3e8af30ecfe440bf0481115975fcc052b46/base/pre-trained/54937b8c-29de-4f04-862c-e67b04ec41e8_ema_bf16.pt

python scripts/convert_cosmos_to_diffusers.py \
    --transformer_type Cosmos-2.5-Predict-Base-14B \
    --transformer_ckpt_path $transformer_ckpt_path \
    --vae_type wan2.1 \
    --output_path converted/14b/54937b8c-29de-4f04-862c-e67b04ec41e8/ \
    --save_pipeline

# post-trained
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-14B/snapshots/71ebf3e8af30ecfe440bf0481115975fcc052b46/base/post-trained/e21d2a49-4747-44c8-ba44-9f6f9243715f_ema_bf16.pt

python scripts/convert_cosmos_to_diffusers.py \
    --transformer_type Cosmos-2.5-Predict-Base-14B \
    --transformer_ckpt_path $transformer_ckpt_path \
    --vae_type wan2.1 \
    --output_path converted/14b/e21d2a49-4747-44c8-ba44-9f6f9243715f/ \
    --save_pipeline
```
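If the converted folder is a standard diffusers pipeline directory (which `--save_pipeline` suggests), loading it would look like the usual local `from_pretrained` call. A minimal sketch only; the concrete pipeline class is resolved from the `model_index.json` the conversion script writes, and the dtype choice here is an assumption:

```python
import torch
from diffusers import DiffusionPipeline

# Path produced by the pre-trained 2B conversion command above.
pipe = DiffusionPipeline.from_pretrained(
    "converted/2b/d20b7120-df3e-4911-919d-db6e08bad31c",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")
```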
@@ -298,6 +337,25 @@ TRANSFORMER_CONFIGS = {
        "crossattn_proj_in_channels": 100352,
        "encoder_hidden_states_channels": 1024,
    },
    "Cosmos-2.5-Predict-Base-14B": {
        "in_channels": 16 + 1,
        "out_channels": 16,
        "num_attention_heads": 40,
        "attention_head_dim": 128,
        "num_layers": 36,
        "mlp_ratio": 4.0,
        "text_embed_dim": 1024,
        "adaln_lora_dim": 256,
        "max_size": (128, 240, 240),
        "patch_size": (1, 2, 2),
        "rope_scale": (1.0, 3.0, 3.0),
        "concat_padding_mask": True,
        # NOTE: source config has pos_emb_learnable: 'True' - but params are missing
        "extra_pos_embed_type": None,
        "use_crossattn_projection": True,
        "crossattn_proj_in_channels": 100352,
        "encoder_hidden_states_channels": 1024,
    },
}

VAE_KEYS_RENAME_DICT = {
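For orientation, the 14B entry implies the following derived sizes. This is plain arithmetic over the config values above; reading `in_channels = 16 + 1` as 16 latent channels plus one concatenated padding-mask channel is an assumption based on `concat_padding_mask: True`:

```python
config = {"num_attention_heads": 40, "attention_head_dim": 128, "num_layers": 36, "mlp_ratio": 4.0}

hidden_size = config["num_attention_heads"] * config["attention_head_dim"]  # 40 * 128 = 5120
mlp_hidden = int(hidden_size * config["mlp_ratio"])                         # 5120 * 4.0 = 20480
latent_channels = 16  # "in_channels": 16 + 1 -> assumed 16 latent channels + 1 padding-mask channel

print(hidden_size, mlp_hidden, config["num_layers"])  # 5120 20480 36
```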
@@ -675,6 +675,7 @@ else:
                "ZImageControlNetInpaintPipeline",
                "ZImageControlNetPipeline",
                "ZImageImg2ImgPipeline",
                "ZImageOmniPipeline",
                "ZImagePipeline",
            ]
        )
@@ -1386,6 +1387,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
            ZImageControlNetInpaintPipeline,
            ZImageControlNetPipeline,
            ZImageImg2ImgPipeline,
            ZImageOmniPipeline,
            ZImagePipeline,
        )
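Assuming the rest of the branch wires up the pipeline module itself, adding the name to both the lazy-import table and the `TYPE_CHECKING` branch is what exposes it at the top level; a sketch of the resulting import:

```python
# Sketch only: works once both import tables in diffusers/__init__.py list the class.
from diffusers import ZImageOmniPipeline
```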
@@ -25,6 +25,7 @@ if is_torch_available():
    from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
    from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
    from .guider_utils import BaseGuidance
    from .magnitude_aware_guidance import MagnitudeAwareGuidance
    from .perturbed_attention_guidance import PerturbedAttentionGuidance
    from .skip_layer_guidance import SkipLayerGuidance
    from .smoothed_energy_guidance import SmoothedEnergyGuidance
src/diffusers/guiders/magnitude_aware_guidance.py (159 lines, new file)
@@ -0,0 +1,159 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch

from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg


if TYPE_CHECKING:
    from ..modular_pipelines.modular_pipeline import BlockState


class MagnitudeAwareGuidance(BaseGuidance):
    """
    Magnitude-Aware Mitigation for Boosted Guidance (MAMBO-G): https://huggingface.co/papers/2508.03442

    Args:
        guidance_scale (`float`, defaults to `10.0`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
            deterioration of image quality.
        alpha (`float`, defaults to `8.0`):
            The alpha parameter for the magnitude-aware guidance. Higher values cause more aggressive suppression of
            the guidance scale when the magnitude of the guidance update is large.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    _input_predictions = ["pred_cond", "pred_uncond"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 10.0,
        alpha: float = 8.0,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
        enabled: bool = True,
    ):
        super().__init__(start, stop, enabled)

        self.guidance_scale = guidance_scale
        self.alpha = alpha
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

    def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
            data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
            data_batches.append(data_batch)
        return data_batches

    def prepare_inputs_from_block_state(
        self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
    ) -> List["BlockState"]:
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
            data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
            data_batches.append(data_batch)
        return data_batches

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
        pred = None

        if not self._is_mambo_g_enabled():
            pred = pred_cond
        else:
            pred = mambo_guidance(
                pred_cond,
                pred_uncond,
                self.guidance_scale,
                self.alpha,
                self.use_original_formulation,
            )

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

    @property
    def is_conditional(self) -> bool:
        return self._count_prepared == 1

    @property
    def num_conditions(self) -> int:
        num_conditions = 1
        if self._is_mambo_g_enabled():
            num_conditions += 1
        return num_conditions

    def _is_mambo_g_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close


def mambo_guidance(
    pred_cond: torch.Tensor,
    pred_uncond: torch.Tensor,
    guidance_scale: float,
    alpha: float = 8.0,
    use_original_formulation: bool = False,
):
    dim = list(range(1, len(pred_cond.shape)))
    diff = pred_cond - pred_uncond
    ratio = torch.norm(diff, dim=dim, keepdim=True) / torch.norm(pred_uncond, dim=dim, keepdim=True)
    guidance_scale_final = (
        guidance_scale * torch.exp(-alpha * ratio)
        if use_original_formulation
        else 1.0 + (guidance_scale - 1.0) * torch.exp(-alpha * ratio)
    )
    pred = pred_cond if use_original_formulation else pred_uncond
    pred = pred + guidance_scale_final * diff

    return pred
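To see what the magnitude-aware rescaling does, the snippet below evaluates the effective scale `1 + (guidance_scale - 1) * exp(-alpha * ratio)` from `mambo_guidance` for a few values of the update-to-unconditional magnitude ratio. Illustration only, using the defaults `guidance_scale=10.0`, `alpha=8.0` (the same knobs you would pass to `MagnitudeAwareGuidance(...)`):

```python
import math

guidance_scale, alpha = 10.0, 8.0
for ratio in (0.0, 0.1, 0.3, 0.5):
    effective = 1.0 + (guidance_scale - 1.0) * math.exp(-alpha * ratio)
    print(f"ratio={ratio:.1f} -> effective scale ~ {effective:.2f}")
# Large guidance updates (large ratio) are damped toward a scale of 1.0,
# while small updates keep close to the full guidance_scale.
```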
@@ -13,7 +13,7 @@
# limitations under the License.

import math
from typing import List, Literal, Optional
from typing import List, Literal, Optional, Tuple

import torch
import torch.nn as nn
@@ -170,6 +170,21 @@ class FeedForward(nn.Module):
        return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))


# Copied from diffusers.models.transformers.transformer_z_image.select_per_token
def select_per_token(
    value_noisy: torch.Tensor,
    value_clean: torch.Tensor,
    noise_mask: torch.Tensor,
    seq_len: int,
) -> torch.Tensor:
    noise_mask_expanded = noise_mask.unsqueeze(-1)  # (batch, seq_len, 1)
    return torch.where(
        noise_mask_expanded == 1,
        value_noisy.unsqueeze(1).expand(-1, seq_len, -1),
        value_clean.unsqueeze(1).expand(-1, seq_len, -1),
    )

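A toy call makes the broadcasting in `select_per_token` explicit: the per-sample noisy and clean vectors are expanded along the sequence dimension, and the integer `noise_mask` picks one of them per token. Shapes below are made up for illustration and assume `select_per_token` as defined above is in scope:

```python
import torch

batch, seq_len, dim = 1, 4, 3
value_noisy = torch.ones(batch, dim)       # modulation vector for noisy tokens
value_clean = torch.zeros(batch, dim)      # modulation vector for clean tokens
noise_mask = torch.tensor([[1, 1, 0, 0]])  # first two tokens noisy, last two clean

out = select_per_token(value_noisy, value_clean, noise_mask, seq_len)
print(out.shape)     # torch.Size([1, 4, 3])
print(out[0, :, 0])  # tensor([1., 1., 0., 0.])
```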
@maybe_allow_in_graph
# Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformerBlock
class ZImageTransformerBlock(nn.Module):
@@ -220,12 +235,37 @@ class ZImageTransformerBlock(nn.Module):
        attn_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        adaln_input: Optional[torch.Tensor] = None,
        noise_mask: Optional[torch.Tensor] = None,
        adaln_noisy: Optional[torch.Tensor] = None,
        adaln_clean: Optional[torch.Tensor] = None,
    ):
        if self.modulation:
            assert adaln_input is not None
            scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
            gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
            scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
            seq_len = x.shape[1]

            if noise_mask is not None:
                # Per-token modulation: different modulation for noisy/clean tokens
                mod_noisy = self.adaLN_modulation(adaln_noisy)
                mod_clean = self.adaLN_modulation(adaln_clean)

                scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1)
                scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1)

                gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh()
                gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh()

                scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy
                scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean

                scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len)
                scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len)
                gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len)
                gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len)
            else:
                # Global modulation: same modulation for all tokens (avoid double select)
                mod = self.adaLN_modulation(adaln_input)
                scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2)
                gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
                scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp

            # Attention block
            attn_out = self.attention(
@@ -493,112 +533,93 @@ class ZImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
    def create_coordinate_grid(size, start=None, device=None):
        if start is None:
            start = (0 for _ in size)

        axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
        grids = torch.meshgrid(axes, indexing="ij")
        return torch.stack(grids, dim=-1)

    # Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel._patchify_image
    def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int):
        """Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim)."""
        pH, pW, pF = patch_size, patch_size, f_patch_size
        C, F, H, W = image.size()
        F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
        image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
        image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
        return image, (F, H, W), (F_tokens, H_tokens, W_tokens)

    # Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel._pad_with_ids
    def _pad_with_ids(
        self,
        feat: torch.Tensor,
        pos_grid_size: Tuple,
        pos_start: Tuple,
        device: torch.device,
        noise_mask_val: Optional[int] = None,
    ):
        """Pad feature to SEQ_MULTI_OF, create position IDs and pad mask."""
        ori_len = len(feat)
        pad_len = (-ori_len) % SEQ_MULTI_OF
        total_len = ori_len + pad_len

        # Pos IDs
        ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2)
        if pad_len > 0:
            pad_pos_ids = (
                self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
                .flatten(0, 2)
                .repeat(pad_len, 1)
            )
            pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0)
            padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0)
            pad_mask = torch.cat(
                [
                    torch.zeros(ori_len, dtype=torch.bool, device=device),
                    torch.ones(pad_len, dtype=torch.bool, device=device),
                ]
            )
        else:
            pos_ids = ori_pos_ids
            padded_feat = feat
            pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device)

        noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None  # token level
        return padded_feat, pos_ids, pad_mask, total_len, noise_mask
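The padding length in `_pad_with_ids` is simply the distance to the next multiple of `SEQ_MULTI_OF` (32). A quick check of the arithmetic, as a standalone illustration:

```python
SEQ_MULTI_OF = 32

for ori_len in (32, 70, 100):
    pad_len = (-ori_len) % SEQ_MULTI_OF
    total_len = ori_len + pad_len
    print(ori_len, pad_len, total_len)
# 32  -> pad 0,  total 32
# 70  -> pad 26, total 96
# 100 -> pad 28, total 128
```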
    # Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel.patchify_and_embed
    def patchify_and_embed(
        self,
        all_image: List[torch.Tensor],
        all_cap_feats: List[torch.Tensor],
        patch_size: int,
        f_patch_size: int,
        self, all_image: List[torch.Tensor], all_cap_feats: List[torch.Tensor], patch_size: int, f_patch_size: int
    ):
        pH = pW = patch_size
        pF = f_patch_size
        """Patchify for basic mode: single image per batch item."""
        device = all_image[0].device
        all_img_out, all_img_size, all_img_pos_ids, all_img_pad_mask = [], [], [], []
        all_cap_out, all_cap_pos_ids, all_cap_pad_mask = [], [], []

        all_image_out = []
        all_image_size = []
        all_image_pos_ids = []
        all_image_pad_mask = []
        all_cap_pos_ids = []
        all_cap_pad_mask = []
        all_cap_feats_out = []

        for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)):
            ### Process Caption
            cap_ori_len = len(cap_feat)
            cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF
            # padded position ids
            cap_padded_pos_ids = self.create_coordinate_grid(
                size=(cap_ori_len + cap_padding_len, 1, 1),
                start=(1, 0, 0),
                device=device,
            ).flatten(0, 2)
            all_cap_pos_ids.append(cap_padded_pos_ids)
            # pad mask
            cap_pad_mask = torch.cat(
                [
                    torch.zeros((cap_ori_len,), dtype=torch.bool, device=device),
                    torch.ones((cap_padding_len,), dtype=torch.bool, device=device),
                ],
                dim=0,
            )
            all_cap_pad_mask.append(
                cap_pad_mask if cap_padding_len > 0 else torch.zeros((cap_ori_len,), dtype=torch.bool, device=device)
        for image, cap_feat in zip(all_image, all_cap_feats):
            # Caption
            cap_out, cap_pos_ids, cap_pad_mask, cap_len, _ = self._pad_with_ids(
                cap_feat, (len(cap_feat) + (-len(cap_feat)) % SEQ_MULTI_OF, 1, 1), (1, 0, 0), device
            )
            all_cap_out.append(cap_out)
            all_cap_pos_ids.append(cap_pos_ids)
            all_cap_pad_mask.append(cap_pad_mask)

            # padded feature
            cap_padded_feat = torch.cat([cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], dim=0)
            all_cap_feats_out.append(cap_padded_feat)

            ### Process Image
            C, F, H, W = image.size()
            all_image_size.append((F, H, W))
            F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW

            image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
            # "c f pf h ph w pw -> (f h w) (pf ph pw c)"
            image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)

            image_ori_len = len(image)
            image_padding_len = (-image_ori_len) % SEQ_MULTI_OF

            image_ori_pos_ids = self.create_coordinate_grid(
                size=(F_tokens, H_tokens, W_tokens),
                start=(cap_ori_len + cap_padding_len + 1, 0, 0),
                device=device,
            ).flatten(0, 2)
            image_padded_pos_ids = torch.cat(
                [
                    image_ori_pos_ids,
                    self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
                    .flatten(0, 2)
                    .repeat(image_padding_len, 1),
                ],
                dim=0,
            # Image
            img_patches, size, (F_t, H_t, W_t) = self._patchify_image(image, patch_size, f_patch_size)
            img_out, img_pos_ids, img_pad_mask, _, _ = self._pad_with_ids(
                img_patches, (F_t, H_t, W_t), (cap_len + 1, 0, 0), device
            )
            all_image_pos_ids.append(image_padded_pos_ids if image_padding_len > 0 else image_ori_pos_ids)
            # pad mask
            image_pad_mask = torch.cat(
                [
                    torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
                    torch.ones((image_padding_len,), dtype=torch.bool, device=device),
                ],
                dim=0,
            )
            all_image_pad_mask.append(
                image_pad_mask
                if image_padding_len > 0
                else torch.zeros((image_ori_len,), dtype=torch.bool, device=device)
            )
            # padded feature
            image_padded_feat = torch.cat(
                [image, image[-1:].repeat(image_padding_len, 1)],
                dim=0,
            )
            all_image_out.append(image_padded_feat if image_padding_len > 0 else image)
            all_img_out.append(img_out)
            all_img_size.append(size)
            all_img_pos_ids.append(img_pos_ids)
            all_img_pad_mask.append(img_pad_mask)

        return (
            all_image_out,
            all_cap_feats_out,
            all_image_size,
            all_image_pos_ids,
            all_img_out,
            all_cap_out,
            all_img_size,
            all_img_pos_ids,
            all_cap_pos_ids,
            all_image_pad_mask,
            all_img_pad_mask,
            all_cap_pad_mask,
        )

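As a concrete shape check for the patchify step used above (`_patchify_image` and the inlined view/permute it replaces): a single-frame latent of shape (C, F, H, W) = (16, 1, 64, 64) with `patch_size=2` and `f_patch_size=1` yields 32 * 32 = 1024 tokens of dimension 1 * 2 * 2 * 16 = 64. The numbers here are illustrative only:

```python
import torch

C, F, H, W = 16, 1, 64, 64
patch_size, f_patch_size = 2, 1

image = torch.randn(C, F, H, W)
pH = pW = patch_size
pF = f_patch_size
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW

# "c f pf h ph w pw -> (f h w) (pf ph pw c)", same reshape as in the model code
patches = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
patches = patches.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
print(patches.shape)  # torch.Size([1024, 64])
```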
@@ -13,7 +13,7 @@
# limitations under the License.

import math
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -32,6 +32,7 @@ from ..modeling_outputs import Transformer2DModelOutput

ADALN_EMBED_DIM = 256
SEQ_MULTI_OF = 32
X_PAD_DIM = 64


class TimestepEmbedder(nn.Module):
@@ -152,6 +153,20 @@ class ZSingleStreamAttnProcessor:
        return output


def select_per_token(
    value_noisy: torch.Tensor,
    value_clean: torch.Tensor,
    noise_mask: torch.Tensor,
    seq_len: int,
) -> torch.Tensor:
    noise_mask_expanded = noise_mask.unsqueeze(-1)  # (batch, seq_len, 1)
    return torch.where(
        noise_mask_expanded == 1,
        value_noisy.unsqueeze(1).expand(-1, seq_len, -1),
        value_clean.unsqueeze(1).expand(-1, seq_len, -1),
    )


class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
@@ -215,12 +230,37 @@ class ZImageTransformerBlock(nn.Module):
|
||||
attn_mask: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
adaln_input: Optional[torch.Tensor] = None,
|
||||
noise_mask: Optional[torch.Tensor] = None,
|
||||
adaln_noisy: Optional[torch.Tensor] = None,
|
||||
adaln_clean: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if self.modulation:
|
||||
assert adaln_input is not None
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
|
||||
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
|
||||
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
|
||||
seq_len = x.shape[1]
|
||||
|
||||
if noise_mask is not None:
|
||||
# Per-token modulation: different modulation for noisy/clean tokens
|
||||
mod_noisy = self.adaLN_modulation(adaln_noisy)
|
||||
mod_clean = self.adaLN_modulation(adaln_clean)
|
||||
|
||||
scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1)
|
||||
scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1)
|
||||
|
||||
gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh()
|
||||
gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh()
|
||||
|
||||
scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy
|
||||
scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean
|
||||
|
||||
scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len)
|
||||
scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len)
|
||||
gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len)
|
||||
gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len)
|
||||
else:
|
||||
# Global modulation: same modulation for all tokens (avoid double select)
|
||||
mod = self.adaLN_modulation(adaln_input)
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2)
|
||||
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
|
||||
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
|
||||
|
||||
# Attention block
|
||||
attn_out = self.attention(
|
||||
@@ -252,9 +292,21 @@ class FinalLayer(nn.Module):
            nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True),
        )

    def forward(self, x, c):
        scale = 1.0 + self.adaLN_modulation(c)
        x = self.norm_final(x) * scale.unsqueeze(1)
    def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None):
        seq_len = x.shape[1]

        if noise_mask is not None:
            # Per-token modulation
            scale_noisy = 1.0 + self.adaLN_modulation(c_noisy)
            scale_clean = 1.0 + self.adaLN_modulation(c_clean)
            scale = select_per_token(scale_noisy, scale_clean, noise_mask, seq_len)
        else:
            # Original global modulation
            assert c is not None, "Either c or (c_noisy, c_clean) must be provided"
            scale = 1.0 + self.adaLN_modulation(c)
            scale = scale.unsqueeze(1)

        x = self.norm_final(x) * scale
        x = self.linear(x)
        return x

@@ -325,6 +377,7 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
        norm_eps=1e-5,
        qk_norm=True,
        cap_feat_dim=2560,
        siglip_feat_dim=None,  # Optional: set to enable SigLIP support for Omni
        rope_theta=256.0,
        t_scale=1000.0,
        axes_dims=[32, 48, 48],
@@ -386,6 +439,31 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
        self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
        self.cap_embedder = nn.Sequential(RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True))

        # Optional SigLIP components (for Omni variant)
        if siglip_feat_dim is not None:
            self.siglip_embedder = nn.Sequential(
                RMSNorm(siglip_feat_dim, eps=norm_eps), nn.Linear(siglip_feat_dim, dim, bias=True)
            )
            self.siglip_refiner = nn.ModuleList(
                [
                    ZImageTransformerBlock(
                        2000 + layer_id,
                        dim,
                        n_heads,
                        n_kv_heads,
                        norm_eps,
                        qk_norm,
                        modulation=False,
                    )
                    for layer_id in range(n_refiner_layers)
                ]
            )
            self.siglip_pad_token = nn.Parameter(torch.empty((1, dim)))
        else:
            self.siglip_embedder = None
            self.siglip_refiner = None
            self.siglip_pad_token = None

        self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
        self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))

@@ -402,259 +480,561 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
|
||||
|
||||
self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
|
||||
|
||||
def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]:
|
||||
def unpatchify(
|
||||
self,
|
||||
x: List[torch.Tensor],
|
||||
size: List[Tuple],
|
||||
patch_size,
|
||||
f_patch_size,
|
||||
x_pos_offsets: Optional[List[Tuple[int, int]]] = None,
|
||||
) -> List[torch.Tensor]:
|
||||
pH = pW = patch_size
|
||||
pF = f_patch_size
|
||||
bsz = len(x)
|
||||
assert len(size) == bsz
|
||||
for i in range(bsz):
|
||||
F, H, W = size[i]
|
||||
ori_len = (F // pF) * (H // pH) * (W // pW)
|
||||
# "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
|
||||
x[i] = (
|
||||
x[i][:ori_len]
|
||||
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
|
||||
.permute(6, 0, 3, 1, 4, 2, 5)
|
||||
.reshape(self.out_channels, F, H, W)
|
||||
)
|
||||
return x
|
||||
|
||||
if x_pos_offsets is not None:
|
||||
# Omni: extract target image from unified sequence (cond_images + target)
|
||||
result = []
|
||||
for i in range(bsz):
|
||||
unified_x = x[i][x_pos_offsets[i][0] : x_pos_offsets[i][1]]
|
||||
cu_len = 0
|
||||
x_item = None
|
||||
for j in range(len(size[i])):
|
||||
if size[i][j] is None:
|
||||
ori_len = 0
|
||||
pad_len = SEQ_MULTI_OF
|
||||
cu_len += pad_len + ori_len
|
||||
else:
|
||||
F, H, W = size[i][j]
|
||||
ori_len = (F // pF) * (H // pH) * (W // pW)
|
||||
pad_len = (-ori_len) % SEQ_MULTI_OF
|
||||
x_item = (
|
||||
unified_x[cu_len : cu_len + ori_len]
|
||||
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
|
||||
.permute(6, 0, 3, 1, 4, 2, 5)
|
||||
.reshape(self.out_channels, F, H, W)
|
||||
)
|
||||
cu_len += ori_len + pad_len
|
||||
result.append(x_item) # Return only the last (target) image
|
||||
return result
|
||||
else:
|
||||
# Original mode: simple unpatchify
|
||||
for i in range(bsz):
|
||||
F, H, W = size[i]
|
||||
ori_len = (F // pF) * (H // pH) * (W // pW)
|
||||
# "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
|
||||
x[i] = (
|
||||
x[i][:ori_len]
|
||||
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
|
||||
.permute(6, 0, 3, 1, 4, 2, 5)
|
||||
.reshape(self.out_channels, F, H, W)
|
||||
)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def create_coordinate_grid(size, start=None, device=None):
|
||||
if start is None:
|
||||
start = (0 for _ in size)
|
||||
|
||||
axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
|
||||
grids = torch.meshgrid(axes, indexing="ij")
|
||||
return torch.stack(grids, dim=-1)
|
||||
|
||||
def patchify_and_embed(
|
||||
def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int):
|
||||
"""Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim)."""
|
||||
pH, pW, pF = patch_size, patch_size, f_patch_size
|
||||
C, F, H, W = image.size()
|
||||
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
|
||||
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
|
||||
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
|
||||
return image, (F, H, W), (F_tokens, H_tokens, W_tokens)
|
||||
|
||||
def _pad_with_ids(
|
||||
self,
|
||||
all_image: List[torch.Tensor],
|
||||
all_cap_feats: List[torch.Tensor],
|
||||
patch_size: int,
|
||||
f_patch_size: int,
|
||||
feat: torch.Tensor,
|
||||
pos_grid_size: Tuple,
|
||||
pos_start: Tuple,
|
||||
device: torch.device,
|
||||
noise_mask_val: Optional[int] = None,
|
||||
):
|
||||
pH = pW = patch_size
|
||||
pF = f_patch_size
|
||||
"""Pad feature to SEQ_MULTI_OF, create position IDs and pad mask."""
|
||||
ori_len = len(feat)
|
||||
pad_len = (-ori_len) % SEQ_MULTI_OF
|
||||
total_len = ori_len + pad_len
|
||||
|
||||
# Pos IDs
|
||||
ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2)
|
||||
if pad_len > 0:
|
||||
pad_pos_ids = (
|
||||
self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
|
||||
.flatten(0, 2)
|
||||
.repeat(pad_len, 1)
|
||||
)
|
||||
pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0)
|
||||
padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0)
|
||||
pad_mask = torch.cat(
|
||||
[
|
||||
torch.zeros(ori_len, dtype=torch.bool, device=device),
|
||||
torch.ones(pad_len, dtype=torch.bool, device=device),
|
||||
]
|
||||
)
|
||||
else:
|
||||
pos_ids = ori_pos_ids
|
||||
padded_feat = feat
|
||||
pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device)
|
||||
|
||||
noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None # token level
|
||||
return padded_feat, pos_ids, pad_mask, total_len, noise_mask
|
||||
|
||||
def patchify_and_embed(
|
||||
self, all_image: List[torch.Tensor], all_cap_feats: List[torch.Tensor], patch_size: int, f_patch_size: int
|
||||
):
|
||||
"""Patchify for basic mode: single image per batch item."""
|
||||
device = all_image[0].device
|
||||
all_img_out, all_img_size, all_img_pos_ids, all_img_pad_mask = [], [], [], []
|
||||
all_cap_out, all_cap_pos_ids, all_cap_pad_mask = [], [], []
|
||||
|
||||
all_image_out = []
|
||||
all_image_size = []
|
||||
all_image_pos_ids = []
|
||||
all_image_pad_mask = []
|
||||
all_cap_pos_ids = []
|
||||
all_cap_pad_mask = []
|
||||
all_cap_feats_out = []
|
||||
|
||||
for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)):
|
||||
### Process Caption
|
||||
cap_ori_len = len(cap_feat)
|
||||
cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF
|
||||
# padded position ids
|
||||
cap_padded_pos_ids = self.create_coordinate_grid(
|
||||
size=(cap_ori_len + cap_padding_len, 1, 1),
|
||||
start=(1, 0, 0),
|
||||
device=device,
|
||||
).flatten(0, 2)
|
||||
all_cap_pos_ids.append(cap_padded_pos_ids)
|
||||
# pad mask
|
||||
cap_pad_mask = torch.cat(
|
||||
[
|
||||
torch.zeros((cap_ori_len,), dtype=torch.bool, device=device),
|
||||
torch.ones((cap_padding_len,), dtype=torch.bool, device=device),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
all_cap_pad_mask.append(
|
||||
cap_pad_mask if cap_padding_len > 0 else torch.zeros((cap_ori_len,), dtype=torch.bool, device=device)
|
||||
for image, cap_feat in zip(all_image, all_cap_feats):
|
||||
# Caption
|
||||
cap_out, cap_pos_ids, cap_pad_mask, cap_len, _ = self._pad_with_ids(
|
||||
cap_feat, (len(cap_feat) + (-len(cap_feat)) % SEQ_MULTI_OF, 1, 1), (1, 0, 0), device
|
||||
)
|
||||
all_cap_out.append(cap_out)
|
||||
all_cap_pos_ids.append(cap_pos_ids)
|
||||
all_cap_pad_mask.append(cap_pad_mask)
|
||||
|
||||
# padded feature
|
||||
cap_padded_feat = torch.cat([cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], dim=0)
|
||||
all_cap_feats_out.append(cap_padded_feat)
|
||||
|
||||
### Process Image
|
||||
C, F, H, W = image.size()
|
||||
all_image_size.append((F, H, W))
|
||||
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
|
||||
|
||||
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
|
||||
# "c f pf h ph w pw -> (f h w) (pf ph pw c)"
|
||||
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
|
||||
|
||||
image_ori_len = len(image)
|
||||
image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
|
||||
|
||||
image_ori_pos_ids = self.create_coordinate_grid(
|
||||
size=(F_tokens, H_tokens, W_tokens),
|
||||
start=(cap_ori_len + cap_padding_len + 1, 0, 0),
|
||||
device=device,
|
||||
).flatten(0, 2)
|
||||
image_padded_pos_ids = torch.cat(
|
||||
[
|
||||
image_ori_pos_ids,
|
||||
self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
|
||||
.flatten(0, 2)
|
||||
.repeat(image_padding_len, 1),
|
||||
],
|
||||
dim=0,
|
||||
# Image
|
||||
img_patches, size, (F_t, H_t, W_t) = self._patchify_image(image, patch_size, f_patch_size)
|
||||
img_out, img_pos_ids, img_pad_mask, _, _ = self._pad_with_ids(
|
||||
img_patches, (F_t, H_t, W_t), (cap_len + 1, 0, 0), device
|
||||
)
|
||||
all_image_pos_ids.append(image_padded_pos_ids if image_padding_len > 0 else image_ori_pos_ids)
|
||||
# pad mask
|
||||
image_pad_mask = torch.cat(
|
||||
[
|
||||
torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
|
||||
torch.ones((image_padding_len,), dtype=torch.bool, device=device),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
all_image_pad_mask.append(
|
||||
image_pad_mask
|
||||
if image_padding_len > 0
|
||||
else torch.zeros((image_ori_len,), dtype=torch.bool, device=device)
|
||||
)
|
||||
# padded feature
|
||||
image_padded_feat = torch.cat(
|
||||
[image, image[-1:].repeat(image_padding_len, 1)],
|
||||
dim=0,
|
||||
)
|
||||
all_image_out.append(image_padded_feat if image_padding_len > 0 else image)
|
||||
all_img_out.append(img_out)
|
||||
all_img_size.append(size)
|
||||
all_img_pos_ids.append(img_pos_ids)
|
||||
all_img_pad_mask.append(img_pad_mask)
|
||||
|
||||
return (
|
||||
all_image_out,
|
||||
all_cap_feats_out,
|
||||
all_image_size,
|
||||
all_image_pos_ids,
|
||||
all_img_out,
|
||||
all_cap_out,
|
||||
all_img_size,
|
||||
all_img_pos_ids,
|
||||
all_cap_pos_ids,
|
||||
all_image_pad_mask,
|
||||
all_img_pad_mask,
|
||||
all_cap_pad_mask,
|
||||
)
|
||||
|
||||
def forward(
|
||||
def patchify_and_embed_omni(
|
||||
self,
|
||||
x: List[torch.Tensor],
|
||||
t,
|
||||
cap_feats: List[torch.Tensor],
|
||||
controlnet_block_samples: Optional[Dict[int, torch.Tensor]] = None,
|
||||
patch_size=2,
|
||||
f_patch_size=1,
|
||||
return_dict: bool = True,
|
||||
all_x: List[List[torch.Tensor]],
|
||||
all_cap_feats: List[List[torch.Tensor]],
|
||||
all_siglip_feats: List[List[torch.Tensor]],
|
||||
patch_size: int,
|
||||
f_patch_size: int,
|
||||
images_noise_mask: List[List[int]],
|
||||
):
|
||||
assert patch_size in self.all_patch_size
|
||||
assert f_patch_size in self.all_f_patch_size
|
||||
"""Patchify for omni mode: multiple images per batch item with noise masks."""
|
||||
bsz = len(all_x)
|
||||
device = all_x[0][-1].device
|
||||
dtype = all_x[0][-1].dtype
|
||||
|
||||
bsz = len(x)
|
||||
device = x[0].device
|
||||
t = t * self.t_scale
|
||||
t = self.t_embedder(t)
|
||||
all_x_out, all_x_size, all_x_pos_ids, all_x_pad_mask, all_x_len, all_x_noise_mask = [], [], [], [], [], []
|
||||
all_cap_out, all_cap_pos_ids, all_cap_pad_mask, all_cap_len, all_cap_noise_mask = [], [], [], [], []
|
||||
all_sig_out, all_sig_pos_ids, all_sig_pad_mask, all_sig_len, all_sig_noise_mask = [], [], [], [], []
|
||||
|
||||
(
|
||||
x,
|
||||
cap_feats,
|
||||
x_size,
|
||||
x_pos_ids,
|
||||
cap_pos_ids,
|
||||
x_inner_pad_mask,
|
||||
cap_inner_pad_mask,
|
||||
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
|
||||
for i in range(bsz):
|
||||
num_images = len(all_x[i])
|
||||
cap_feats_list, cap_pos_list, cap_mask_list, cap_lens, cap_noise = [], [], [], [], []
|
||||
cap_end_pos = []
|
||||
cap_cu_len = 1
|
||||
|
||||
# x embed & refine
|
||||
x_item_seqlens = [len(_) for _ in x]
|
||||
assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
|
||||
x_max_item_seqlen = max(x_item_seqlens)
|
||||
# Process captions
|
||||
for j, cap_item in enumerate(all_cap_feats[i]):
|
||||
noise_val = images_noise_mask[i][j] if j < len(images_noise_mask[i]) else 1
|
||||
cap_out, cap_pos, cap_mask, cap_len, cap_nm = self._pad_with_ids(
|
||||
cap_item,
|
||||
(len(cap_item) + (-len(cap_item)) % SEQ_MULTI_OF, 1, 1),
|
||||
(cap_cu_len, 0, 0),
|
||||
device,
|
||||
noise_val,
|
||||
)
|
||||
cap_feats_list.append(cap_out)
|
||||
cap_pos_list.append(cap_pos)
|
||||
cap_mask_list.append(cap_mask)
|
||||
cap_lens.append(cap_len)
|
||||
cap_noise.extend(cap_nm)
|
||||
cap_cu_len += len(cap_item)
|
||||
cap_end_pos.append(cap_cu_len)
|
||||
cap_cu_len += 2 # for image vae and siglip tokens
|
||||
|
||||
x = torch.cat(x, dim=0)
|
||||
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
|
||||
all_cap_out.append(torch.cat(cap_feats_list, dim=0))
|
||||
all_cap_pos_ids.append(torch.cat(cap_pos_list, dim=0))
|
||||
all_cap_pad_mask.append(torch.cat(cap_mask_list, dim=0))
|
||||
all_cap_len.append(cap_lens)
|
||||
all_cap_noise_mask.append(cap_noise)
|
||||
|
||||
# Match t_embedder output dtype to x for layerwise casting compatibility
|
||||
adaln_input = t.type_as(x)
|
||||
x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
|
||||
x = list(x.split(x_item_seqlens, dim=0))
|
||||
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0))
|
||||
# Process images
|
||||
x_feats_list, x_pos_list, x_mask_list, x_lens, x_size, x_noise = [], [], [], [], [], []
|
||||
for j, x_item in enumerate(all_x[i]):
|
||||
noise_val = images_noise_mask[i][j]
|
||||
if x_item is not None:
|
||||
x_patches, size, (F_t, H_t, W_t) = self._patchify_image(x_item, patch_size, f_patch_size)
|
||||
x_out, x_pos, x_mask, x_len, x_nm = self._pad_with_ids(
|
||||
x_patches, (F_t, H_t, W_t), (cap_end_pos[j], 0, 0), device, noise_val
|
||||
)
|
||||
x_size.append(size)
|
||||
else:
|
||||
x_len = SEQ_MULTI_OF
|
||||
x_out = torch.zeros((x_len, X_PAD_DIM), dtype=dtype, device=device)
|
||||
x_pos = self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(x_len, 1)
|
||||
x_mask = torch.ones(x_len, dtype=torch.bool, device=device)
|
||||
x_nm = [noise_val] * x_len
|
||||
x_size.append(None)
|
||||
x_feats_list.append(x_out)
|
||||
x_pos_list.append(x_pos)
|
||||
x_mask_list.append(x_mask)
|
||||
x_lens.append(x_len)
|
||||
x_noise.extend(x_nm)
|
||||
|
||||
x = pad_sequence(x, batch_first=True, padding_value=0.0)
|
||||
x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
# Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors
|
||||
x_freqs_cis = x_freqs_cis[:, : x.shape[1]]
|
||||
all_x_out.append(torch.cat(x_feats_list, dim=0))
|
||||
all_x_pos_ids.append(torch.cat(x_pos_list, dim=0))
|
||||
all_x_pad_mask.append(torch.cat(x_mask_list, dim=0))
|
||||
all_x_size.append(x_size)
|
||||
all_x_len.append(x_lens)
|
||||
all_x_noise_mask.append(x_noise)
|
||||
|
||||
x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(x_item_seqlens):
|
||||
x_attn_mask[i, :seq_len] = 1
|
||||
# Process siglip
|
||||
if all_siglip_feats[i] is None:
|
||||
all_sig_len.append([0] * num_images)
|
||||
all_sig_out.append(None)
|
||||
else:
|
||||
sig_feats_list, sig_pos_list, sig_mask_list, sig_lens, sig_noise = [], [], [], [], []
|
||||
for j, sig_item in enumerate(all_siglip_feats[i]):
|
||||
noise_val = images_noise_mask[i][j]
|
||||
if sig_item is not None:
|
||||
sig_H, sig_W, sig_C = sig_item.size()
|
||||
sig_flat = sig_item.permute(2, 0, 1).reshape(sig_H * sig_W, sig_C)
|
||||
sig_out, sig_pos, sig_mask, sig_len, sig_nm = self._pad_with_ids(
|
||||
sig_flat, (1, sig_H, sig_W), (cap_end_pos[j] + 1, 0, 0), device, noise_val
|
||||
)
|
||||
# Scale position IDs to match x resolution
|
||||
if x_size[j] is not None:
|
||||
sig_pos = sig_pos.float()
|
||||
sig_pos[..., 1] = sig_pos[..., 1] / max(sig_H - 1, 1) * (x_size[j][1] - 1)
|
||||
sig_pos[..., 2] = sig_pos[..., 2] / max(sig_W - 1, 1) * (x_size[j][2] - 1)
|
||||
sig_pos = sig_pos.to(torch.int32)
|
||||
else:
|
||||
sig_len = SEQ_MULTI_OF
|
||||
sig_out = torch.zeros((sig_len, self.config.siglip_feat_dim), dtype=dtype, device=device)
|
||||
sig_pos = (
|
||||
self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(sig_len, 1)
|
||||
)
|
||||
sig_mask = torch.ones(sig_len, dtype=torch.bool, device=device)
|
||||
sig_nm = [noise_val] * sig_len
|
||||
sig_feats_list.append(sig_out)
|
||||
sig_pos_list.append(sig_pos)
|
||||
sig_mask_list.append(sig_mask)
|
||||
sig_lens.append(sig_len)
|
||||
sig_noise.extend(sig_nm)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for layer in self.noise_refiner:
|
||||
x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input)
|
||||
else:
|
||||
for layer in self.noise_refiner:
|
||||
x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)
|
||||
all_sig_out.append(torch.cat(sig_feats_list, dim=0))
|
||||
all_sig_pos_ids.append(torch.cat(sig_pos_list, dim=0))
|
||||
all_sig_pad_mask.append(torch.cat(sig_mask_list, dim=0))
|
||||
all_sig_len.append(sig_lens)
|
||||
all_sig_noise_mask.append(sig_noise)
|
||||
|
||||
# cap embed & refine
|
||||
cap_item_seqlens = [len(_) for _ in cap_feats]
|
||||
cap_max_item_seqlen = max(cap_item_seqlens)
|
||||
# Compute x position offsets
|
||||
all_x_pos_offsets = [(sum(all_cap_len[i]), sum(all_cap_len[i]) + sum(all_x_len[i])) for i in range(bsz)]
|
||||
|
||||
cap_feats = torch.cat(cap_feats, dim=0)
|
||||
cap_feats = self.cap_embedder(cap_feats)
|
||||
cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
|
||||
cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
|
||||
cap_freqs_cis = list(
|
||||
self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0)
|
||||
return (
|
||||
all_x_out,
|
||||
all_cap_out,
|
||||
all_sig_out,
|
||||
all_x_size,
|
||||
all_x_pos_ids,
|
||||
all_cap_pos_ids,
|
||||
all_sig_pos_ids,
|
||||
all_x_pad_mask,
|
||||
all_cap_pad_mask,
|
||||
all_sig_pad_mask,
|
||||
all_x_pos_offsets,
|
||||
all_x_noise_mask,
|
||||
all_cap_noise_mask,
|
||||
all_sig_noise_mask,
|
||||
)
|
||||
|
||||
cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
|
||||
cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
# Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors
|
||||
cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]]
|
||||
def _prepare_sequence(
|
||||
self,
|
||||
feats: List[torch.Tensor],
|
||||
pos_ids: List[torch.Tensor],
|
||||
inner_pad_mask: List[torch.Tensor],
|
||||
pad_token: torch.nn.Parameter,
|
||||
noise_mask: Optional[List[List[int]]] = None,
|
||||
device: torch.device = None,
|
||||
):
|
||||
"""Prepare sequence: apply pad token, RoPE embed, pad to batch, create attention mask."""
|
||||
item_seqlens = [len(f) for f in feats]
|
||||
max_seqlen = max(item_seqlens)
|
||||
bsz = len(feats)
|
||||
|
||||
cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(cap_item_seqlens):
|
||||
cap_attn_mask[i, :seq_len] = 1
|
||||
# Pad token
|
||||
feats_cat = torch.cat(feats, dim=0)
|
||||
feats_cat[torch.cat(inner_pad_mask)] = pad_token
|
||||
feats = list(feats_cat.split(item_seqlens, dim=0))
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis)
|
||||
else:
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)
|
||||
# RoPE
|
||||
freqs_cis = list(self.rope_embedder(torch.cat(pos_ids, dim=0)).split([len(p) for p in pos_ids], dim=0))
|
||||
|
||||
# unified
|
||||
# Pad to batch
|
||||
feats = pad_sequence(feats, batch_first=True, padding_value=0.0)
|
||||
freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]]
|
||||
|
||||
# Attention mask
|
||||
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(item_seqlens):
|
||||
attn_mask[i, :seq_len] = 1
|
||||
|
||||
# Noise mask
|
||||
noise_mask_tensor = None
|
||||
if noise_mask is not None:
|
||||
noise_mask_tensor = pad_sequence(
|
||||
[torch.tensor(m, dtype=torch.long, device=device) for m in noise_mask],
|
||||
batch_first=True,
|
||||
padding_value=0,
|
||||
)[:, : feats.shape[1]]
|
||||
|
||||
return feats, freqs_cis, attn_mask, item_seqlens, noise_mask_tensor
|
||||
|
||||
def _build_unified_sequence(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_freqs: torch.Tensor,
|
||||
x_seqlens: List[int],
|
||||
x_noise_mask: Optional[List[List[int]]],
|
||||
cap: torch.Tensor,
|
||||
cap_freqs: torch.Tensor,
|
||||
cap_seqlens: List[int],
|
||||
cap_noise_mask: Optional[List[List[int]]],
|
||||
siglip: Optional[torch.Tensor],
|
||||
siglip_freqs: Optional[torch.Tensor],
|
||||
siglip_seqlens: Optional[List[int]],
|
||||
siglip_noise_mask: Optional[List[List[int]]],
|
||||
omni_mode: bool,
|
||||
device: torch.device,
|
||||
):
|
||||
"""Build unified sequence: x, cap, and optionally siglip.
|
||||
Basic mode order: [x, cap]; Omni mode order: [cap, x, siglip]
|
||||
"""
|
||||
bsz = len(x_seqlens)
|
||||
unified = []
|
||||
unified_freqs_cis = []
|
||||
unified_freqs = []
|
||||
unified_noise_mask = []
|
||||
|
||||
for i in range(bsz):
|
||||
x_len = x_item_seqlens[i]
|
||||
cap_len = cap_item_seqlens[i]
|
||||
unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
|
||||
unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
|
||||
unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
|
||||
assert unified_item_seqlens == [len(_) for _ in unified]
|
||||
unified_max_item_seqlen = max(unified_item_seqlens)
|
||||
x_len, cap_len = x_seqlens[i], cap_seqlens[i]
|
||||
|
||||
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
|
||||
unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(unified_item_seqlens):
|
||||
unified_attn_mask[i, :seq_len] = 1
|
||||
if omni_mode:
|
||||
# Omni: [cap, x, siglip]
|
||||
if siglip is not None and siglip_seqlens is not None:
|
||||
sig_len = siglip_seqlens[i]
|
||||
unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len], siglip[i][:sig_len]]))
|
||||
unified_freqs.append(
|
||||
torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len], siglip_freqs[i][:sig_len]])
|
||||
)
|
||||
unified_noise_mask.append(
|
||||
torch.tensor(
|
||||
cap_noise_mask[i] + x_noise_mask[i] + siglip_noise_mask[i], dtype=torch.long, device=device
|
||||
)
|
||||
)
|
||||
else:
|
||||
unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len]]))
|
||||
unified_freqs.append(torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len]]))
|
||||
unified_noise_mask.append(
|
||||
torch.tensor(cap_noise_mask[i] + x_noise_mask[i], dtype=torch.long, device=device)
|
||||
)
|
||||
else:
|
||||
# Basic: [x, cap]
|
||||
unified.append(torch.cat([x[i][:x_len], cap[i][:cap_len]]))
|
||||
unified_freqs.append(torch.cat([x_freqs[i][:x_len], cap_freqs[i][:cap_len]]))
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for layer_idx, layer in enumerate(self.layers):
|
||||
unified = self._gradient_checkpointing_func(
|
||||
layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input
|
||||
)
|
||||
if controlnet_block_samples is not None:
|
||||
if layer_idx in controlnet_block_samples:
|
||||
unified = unified + controlnet_block_samples[layer_idx]
|
||||
# Compute unified seqlens
|
||||
if omni_mode:
|
||||
if siglip is not None and siglip_seqlens is not None:
|
||||
unified_seqlens = [a + b + c for a, b, c in zip(cap_seqlens, x_seqlens, siglip_seqlens)]
|
||||
else:
|
||||
unified_seqlens = [a + b for a, b in zip(cap_seqlens, x_seqlens)]
|
||||
else:
|
||||
for layer_idx, layer in enumerate(self.layers):
|
||||
unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input)
|
||||
if controlnet_block_samples is not None:
|
||||
if layer_idx in controlnet_block_samples:
|
||||
unified = unified + controlnet_block_samples[layer_idx]
|
||||
unified_seqlens = [a + b for a, b in zip(x_seqlens, cap_seqlens)]
|
||||
|
||||
unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
|
||||
unified = list(unified.unbind(dim=0))
|
||||
x = self.unpatchify(unified, x_size, patch_size, f_patch_size)
|
||||
max_seqlen = max(unified_seqlens)
|
||||
|
||||
if not return_dict:
|
||||
return (x,)
|
||||
# Pad to batch
|
||||
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
|
||||
unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0)
|
||||
|
||||
return Transformer2DModelOutput(sample=x)
|
||||
# Attention mask
|
||||
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(unified_seqlens):
|
||||
attn_mask[i, :seq_len] = 1
|
||||
|
||||
# Noise mask
|
||||
noise_mask_tensor = None
|
||||
if omni_mode:
|
||||
noise_mask_tensor = pad_sequence(unified_noise_mask, batch_first=True, padding_value=0)[
|
||||
:, : unified.shape[1]
|
||||
]
|
||||
|
||||
return unified, unified_freqs, attn_mask, noise_mask_tensor
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Union[List[torch.Tensor], List[List[torch.Tensor]]],
|
||||
t,
|
||||
cap_feats: Union[List[torch.Tensor], List[List[torch.Tensor]]],
|
||||
return_dict: bool = True,
|
||||
controlnet_block_samples: Optional[Dict[int, torch.Tensor]] = None,
|
||||
siglip_feats: Optional[List[List[torch.Tensor]]] = None,
|
||||
image_noise_mask: Optional[List[List[int]]] = None,
|
||||
patch_size: int = 2,
|
||||
f_patch_size: int = 1,
|
||||
):
|
||||
"""
|
||||
Flow: patchify -> t_embed -> x_embed -> x_refine -> cap_embed -> cap_refine
|
||||
-> [siglip_embed -> siglip_refine] -> build_unified -> main_layers -> final_layer -> unpatchify
|
||||
"""
|
||||
assert patch_size in self.all_patch_size and f_patch_size in self.all_f_patch_size
|
||||
omni_mode = isinstance(x[0], list)
|
||||
device = x[0][-1].device if omni_mode else x[0].device
|
||||
|
||||
if omni_mode:
|
||||
# Dual embeddings: noisy (t) and clean (t=1)
|
||||
t_noisy = self.t_embedder(t * self.t_scale).type_as(x[0][-1])
|
||||
t_clean = self.t_embedder(torch.ones_like(t) * self.t_scale).type_as(x[0][-1])
|
||||
adaln_input = None
|
||||
else:
|
||||
# Single embedding for all tokens
|
||||
adaln_input = self.t_embedder(t * self.t_scale).type_as(x[0])
|
||||
t_noisy = t_clean = None
|
||||
|
||||
# Patchify
|
||||
if omni_mode:
|
||||
(
|
||||
x,
|
||||
cap_feats,
|
||||
siglip_feats,
|
||||
x_size,
|
||||
x_pos_ids,
|
||||
cap_pos_ids,
|
||||
siglip_pos_ids,
|
||||
x_pad_mask,
|
||||
cap_pad_mask,
|
||||
siglip_pad_mask,
|
||||
x_pos_offsets,
|
||||
x_noise_mask,
|
||||
cap_noise_mask,
|
||||
siglip_noise_mask,
|
||||
) = self.patchify_and_embed_omni(x, cap_feats, siglip_feats, patch_size, f_patch_size, image_noise_mask)
|
||||
else:
|
||||
(
|
||||
x,
|
||||
cap_feats,
|
||||
x_size,
|
||||
x_pos_ids,
|
||||
cap_pos_ids,
|
||||
x_pad_mask,
|
||||
cap_pad_mask,
|
||||
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
|
||||
x_pos_offsets = x_noise_mask = cap_noise_mask = siglip_noise_mask = None
|
||||
|
||||
# X embed & refine
|
||||
x_seqlens = [len(xi) for xi in x]
|
||||
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](torch.cat(x, dim=0)) # embed
|
||||
x, x_freqs, x_mask, _, x_noise_tensor = self._prepare_sequence(
|
||||
list(x.split(x_seqlens, dim=0)), x_pos_ids, x_pad_mask, self.x_pad_token, x_noise_mask, device
|
||||
)
|
||||
|
||||
for layer in self.noise_refiner:
|
||||
x = (
|
||||
self._gradient_checkpointing_func(
|
||||
layer, x, x_mask, x_freqs, adaln_input, x_noise_tensor, t_noisy, t_clean
|
||||
)
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing
|
||||
else layer(x, x_mask, x_freqs, adaln_input, x_noise_tensor, t_noisy, t_clean)
|
||||
)
|
||||
|
||||
# Cap embed & refine
|
||||
cap_seqlens = [len(ci) for ci in cap_feats]
|
||||
cap_feats = self.cap_embedder(torch.cat(cap_feats, dim=0)) # embed
|
||||
cap_feats, cap_freqs, cap_mask, _, _ = self._prepare_sequence(
|
||||
list(cap_feats.split(cap_seqlens, dim=0)), cap_pos_ids, cap_pad_mask, self.cap_pad_token, None, device
|
||||
)
|
||||
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = (
|
||||
self._gradient_checkpointing_func(layer, cap_feats, cap_mask, cap_freqs)
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing
|
||||
else layer(cap_feats, cap_mask, cap_freqs)
|
||||
)
|
||||
|
||||
# Siglip embed & refine
|
||||
siglip_seqlens = siglip_freqs = None
|
||||
if omni_mode and siglip_feats[0] is not None and self.siglip_embedder is not None:
|
||||
siglip_seqlens = [len(si) for si in siglip_feats]
|
||||
siglip_feats = self.siglip_embedder(torch.cat(siglip_feats, dim=0)) # embed
|
||||
siglip_feats, siglip_freqs, siglip_mask, _, _ = self._prepare_sequence(
|
||||
list(siglip_feats.split(siglip_seqlens, dim=0)),
|
||||
siglip_pos_ids,
|
||||
siglip_pad_mask,
|
||||
self.siglip_pad_token,
|
||||
None,
|
||||
device,
|
||||
)
|
||||
|
||||
for layer in self.siglip_refiner:
|
||||
siglip_feats = (
|
||||
self._gradient_checkpointing_func(layer, siglip_feats, siglip_mask, siglip_freqs)
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing
|
||||
else layer(siglip_feats, siglip_mask, siglip_freqs)
|
||||
)
|
||||
|
||||
# Unified sequence
|
||||
unified, unified_freqs, unified_mask, unified_noise_tensor = self._build_unified_sequence(
|
||||
x,
|
||||
x_freqs,
|
||||
x_seqlens,
|
||||
x_noise_mask,
|
||||
cap_feats,
|
||||
cap_freqs,
|
||||
cap_seqlens,
|
||||
cap_noise_mask,
|
||||
siglip_feats,
|
||||
siglip_freqs,
|
||||
siglip_seqlens,
|
||||
siglip_noise_mask,
|
||||
omni_mode,
|
||||
device,
|
||||
)
|
||||
|
||||
# Main transformer layers
|
||||
for layer_idx, layer in enumerate(self.layers):
|
||||
unified = (
|
||||
self._gradient_checkpointing_func(
|
||||
layer, unified, unified_mask, unified_freqs, adaln_input, unified_noise_tensor, t_noisy, t_clean
|
||||
)
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing
|
||||
else layer(unified, unified_mask, unified_freqs, adaln_input, unified_noise_tensor, t_noisy, t_clean)
|
||||
)
|
||||
if controlnet_block_samples is not None and layer_idx in controlnet_block_samples:
|
||||
unified = unified + controlnet_block_samples[layer_idx]
|
||||
|
||||
unified = (
|
||||
self.all_final_layer[f"{patch_size}-{f_patch_size}"](
|
||||
unified, noise_mask=unified_noise_tensor, c_noisy=t_noisy, c_clean=t_clean
|
||||
)
|
||||
if omni_mode
|
||||
else self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, c=adaln_input)
|
||||
)
|
||||
|
||||
        # Unpatchify
        x = self.unpatchify(list(unified.unbind(dim=0)), x_size, patch_size, f_patch_size, x_pos_offsets)

        return (x,) if not return_dict else Transformer2DModelOutput(sample=x)
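As a quick illustration of the two input layouts this `forward` distinguishes, here is a minimal, self-contained sketch. The tensor shapes are made up for illustration and no model is actually invoked; it only mirrors the `omni_mode` detection shown above.

```python
import torch

# Standard mode: one latent tensor per sample, so x is a List[torch.Tensor].
x_standard = [torch.randn(16, 32, 32) for _ in range(2)]

# Omni mode: each sample is itself a list of items (e.g. clean reference latents plus
# the noisy target), so x is a List[List[torch.Tensor]] and the model switches on isinstance.
x_omni = [[torch.randn(16, 32, 32), torch.randn(16, 32, 32)] for _ in range(2)]

for x in (x_standard, x_omni):
    omni_mode = isinstance(x[0], list)
    device = x[0][-1].device if omni_mode else x[0].device
    print(omni_mode, device)
```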
@@ -231,7 +231,7 @@ class BlockState:


class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
    """
    Base class for all Pipeline Blocks: ConditionalPipelineBlocks, AutoPipelineBlocks, SequentialPipelineBlocks,
    Base class for all Pipeline Blocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks,
    LoopSequentialPipelineBlocks

    [`ModularPipelineBlocks`] provides methods to load and save the definition of pipeline blocks.
@@ -527,10 +527,9 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
        )


class ConditionalPipelineBlocks(ModularPipelineBlocks):
class AutoPipelineBlocks(ModularPipelineBlocks):
    """
    A Pipeline Blocks that conditionally selects a block to run based on the inputs.
    Subclasses must implement the `select_block` method to define the logic for selecting the block.
    A Pipeline Blocks that automatically selects a block to run based on the inputs.

    This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
    library implements for all the pipeline blocks (such as loading or saving etc.)
@@ -540,13 +539,12 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
    Attributes:
        block_classes: List of block classes to be used
        block_names: List of prefixes for each block
        block_trigger_inputs: List of input names that select_block() uses to determine which block to run
        block_trigger_inputs: List of input names that trigger specific blocks, with None for default
    """

    block_classes = []
    block_names = []
    block_trigger_inputs = []
    default_block_name = None  # name of the default block to run when no trigger input is provided; if None, the whole block can be skipped

    def __init__(self):
|
||||
sub_blocks = InsertableDict()
|
||||
@@ -556,15 +554,26 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
|
||||
else:
|
||||
sub_blocks[block_name] = block
|
||||
self.sub_blocks = sub_blocks
|
||||
if not (len(self.block_classes) == len(self.block_names)):
|
||||
if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
|
||||
raise ValueError(
|
||||
f"In {self.__class__.__name__}, the number of block_classes and block_names must be the same."
|
||||
f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same."
|
||||
)
|
||||
if self.default_block_name is not None and self.default_block_name not in self.block_names:
|
||||
default_blocks = [t for t in self.block_trigger_inputs if t is None]
|
||||
# can only have 1 or 0 default block, and has to put in the last
|
||||
# the order of blocks matters here because the first block with matching trigger will be dispatched
|
||||
# e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"]
|
||||
# as long as mask is provided, it is inpaint; if only image is provided, it is img2img
|
||||
if len(default_blocks) > 1 or (len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None):
|
||||
raise ValueError(
|
||||
f"In {self.__class__.__name__}, default_block_name '{self.default_block_name}' must be one of block_names: {self.block_names}"
|
||||
f"In {self.__class__.__name__}, exactly one None must be specified as the last element "
|
||||
"in block_trigger_inputs."
|
||||
)
|
||||
|
||||
# Map trigger inputs to block objects
|
||||
self.trigger_to_block_map = dict(zip(self.block_trigger_inputs, self.sub_blocks.values()))
|
||||
self.trigger_to_block_name_map = dict(zip(self.block_trigger_inputs, self.sub_blocks.keys()))
|
||||
self.block_to_trigger_map = dict(zip(self.sub_blocks.keys(), self.block_trigger_inputs))
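To make the dispatch rule in the comments above concrete, here is a small, self-contained toy re-implementation of the ordering logic. It is not the library code; the block and input names are invented for illustration only.

```python
# Toy version of the rule: the first trigger whose input is present wins,
# and a trailing None entry acts as the default block.
def select_block(block_names, block_trigger_inputs, **inputs):
    for name, trigger in zip(block_names, block_trigger_inputs):
        if trigger is not None and inputs.get(trigger) is not None:
            return name
    if None in block_trigger_inputs:
        return block_names[block_trigger_inputs.index(None)]
    return None

names = ["inpaint", "img2img", "text2img"]
triggers = ["mask", "image", None]
print(select_block(names, triggers, mask="m", image="i"))  # inpaint: "mask" is listed first
print(select_block(names, triggers, image="i"))            # img2img
print(select_block(names, triggers))                       # text2img (the trailing None default)
```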
|
||||
|
||||
@property
|
||||
def model_name(self):
|
||||
return next(iter(self.sub_blocks.values())).model_name
|
||||
@@ -593,11 +602,8 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
@property
|
||||
def required_inputs(self) -> List[str]:
|
||||
|
||||
# no default block means this conditional block can be skipped entirely
|
||||
if self.default_block_name is None:
|
||||
if None not in self.block_trigger_inputs:
|
||||
return []
|
||||
|
||||
first_block = next(iter(self.sub_blocks.values()))
|
||||
required_by_all = set(getattr(first_block, "required_inputs", set()))
|
||||
|
||||
@@ -608,7 +614,7 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
return list(required_by_all)
|
||||
|
||||
|
||||
# YiYi TODO: add test for this
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()]
|
||||
@@ -633,69 +639,22 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
|
||||
combined_outputs = self.combine_outputs(*named_outputs)
|
||||
return combined_outputs
|
||||
|
||||
def _get_trigger_inputs(self) -> set:
|
||||
"""
|
||||
Returns a set of all unique trigger input values found in this block and nested blocks.
|
||||
"""
|
||||
|
||||
def fn_recursive_get_trigger(blocks):
|
||||
trigger_values = set()
|
||||
|
||||
if blocks is not None:
|
||||
for name, block in blocks.items():
|
||||
# Check if current block has block_trigger_inputs
|
||||
if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None:
|
||||
trigger_values.update(t for t in block.block_trigger_inputs if t is not None)
|
||||
|
||||
# If block has sub_blocks, recursively check them
|
||||
if block.sub_blocks:
|
||||
nested_triggers = fn_recursive_get_trigger(block.sub_blocks)
|
||||
trigger_values.update(nested_triggers)
|
||||
|
||||
return trigger_values
|
||||
|
||||
# Start with this block's block_trigger_inputs
|
||||
all_triggers = set(t for t in self.block_trigger_inputs if t is not None)
|
||||
# Add nested triggers
|
||||
all_triggers.update(fn_recursive_get_trigger(self.sub_blocks))
|
||||
|
||||
return all_triggers
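A tiny, self-contained toy of the recursive collection described above; the classes here are stand-ins rather than the library's block types, and only show how triggers from nested blocks bubble up into one set.

```python
# Stand-in "blocks": anything exposing block_trigger_inputs contributes its non-None triggers,
# and sub_blocks are traversed recursively.
class Leaf:
    sub_blocks = {}
    block_trigger_inputs = None

class Auto:
    def __init__(self, triggers, sub_blocks=None):
        self.block_trigger_inputs = triggers
        self.sub_blocks = sub_blocks or {}

def collect(block):
    found = {t for t in (block.block_trigger_inputs or []) if t is not None}
    for sub in block.sub_blocks.values():
        found |= collect(sub)
    return found

inner = Auto(["control_image", None])
outer = Auto(["mask", "image", None], {"inner": inner, "leaf": Leaf()})
print(collect(outer))  # {'mask', 'image', 'control_image'} (set order may vary)
```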
|
||||
|
||||
@property
|
||||
def trigger_inputs(self):
|
||||
"""All trigger inputs including from nested blocks."""
|
||||
return self._get_trigger_inputs()
|
||||
|
||||
def select_block(self, **kwargs) -> Optional[str]:
|
||||
"""
|
||||
Select the block to run based on the trigger inputs.
|
||||
Subclasses must implement this method to define the logic for selecting the block.
|
||||
|
||||
Args:
|
||||
**kwargs: Trigger input names and their values from the state.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The name of the block to run, or None to use default/skip.
|
||||
"""
|
||||
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement the `select_block` method.")
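For orientation, here is a minimal sketch of what a `select_block` override can look like. The base class is replaced by a local stub so the snippet runs on its own, and the block and input names are hypothetical; only the contract (trigger kwargs in, block name or None out) comes from the docstring above.

```python
class ConditionalPipelineBlocksStub:
    # Stand-in for the real base class so this sketch is self-contained.
    block_names = []
    block_trigger_inputs = []
    default_block_name = None

class MyConditionalBlocks(ConditionalPipelineBlocksStub):
    block_names = ["inpaint", "img2img"]
    block_trigger_inputs = ["mask", "image"]
    default_block_name = "img2img"

    def select_block(self, mask=None, image=None):
        # Run inpainting only when both a mask and an image are provided.
        if mask is not None and image is not None:
            return "inpaint"
        if image is not None:
            return "img2img"
        return None  # None falls back to default_block_name (or skips if that is None)

print(MyConditionalBlocks().select_block(image="img"))            # img2img
print(MyConditionalBlocks().select_block(mask="m", image="img"))  # inpaint
```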
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
|
||||
|
||||
trigger_kwargs = {name: state.get(name) for name in self.block_trigger_inputs if name is not None}
|
||||
block_name = self.select_block(**trigger_kwargs)
|
||||
# Find default block first (if any)
|
||||
|
||||
if block_name is None:
|
||||
block_name = self.default_block_name
|
||||
block = self.trigger_to_block_map.get(None)
|
||||
for input_name in self.block_trigger_inputs:
|
||||
if input_name is not None and state.get(input_name) is not None:
|
||||
block = self.trigger_to_block_map[input_name]
|
||||
break
|
||||
|
||||
if block_name is None:
|
||||
logger.info(f"skipping conditional block: {self.__class__.__name__}")
|
||||
if block is None:
|
||||
logger.info(f"skipping auto block: {self.__class__.__name__}")
|
||||
return pipeline, state
|
||||
|
||||
block = self.sub_blocks[block_name]
|
||||
|
||||
try:
|
||||
logger.info(f"Running block: {block.__class__.__name__}")
|
||||
logger.info(f"Running block: {block.__class__.__name__}, trigger: {input_name}")
|
||||
return block(pipeline, state)
|
||||
except Exception as e:
|
||||
error_msg = (
|
||||
@@ -706,6 +665,38 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
|
||||
def _get_trigger_inputs(self):
|
||||
"""
|
||||
Returns a set of all unique trigger input values found in the blocks. Returns: Set[str] containing all unique
|
||||
block_trigger_inputs values
|
||||
"""
|
||||
|
||||
def fn_recursive_get_trigger(blocks):
|
||||
trigger_values = set()
|
||||
|
||||
if blocks is not None:
|
||||
for name, block in blocks.items():
|
||||
# Check if current block has trigger inputs(i.e. auto block)
|
||||
if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None:
|
||||
# Add all non-None values from the trigger inputs list
|
||||
trigger_values.update(t for t in block.block_trigger_inputs if t is not None)
|
||||
|
||||
# If block has sub_blocks, recursively check them
|
||||
if block.sub_blocks:
|
||||
nested_triggers = fn_recursive_get_trigger(block.sub_blocks)
|
||||
trigger_values.update(nested_triggers)
|
||||
|
||||
return trigger_values
|
||||
|
||||
trigger_inputs = set(self.block_trigger_inputs)
|
||||
trigger_inputs.update(fn_recursive_get_trigger(self.sub_blocks))
|
||||
|
||||
return trigger_inputs
|
||||
|
||||
@property
|
||||
def trigger_inputs(self):
|
||||
return self._get_trigger_inputs()
|
||||
|
||||
def __repr__(self):
|
||||
class_name = self.__class__.__name__
|
||||
base_class = self.__class__.__bases__[0].__name__
|
||||
@@ -717,7 +708,7 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
|
||||
header += "\n"
|
||||
header += " " + "=" * 100 + "\n"
|
||||
header += " This pipeline contains blocks that are selected at runtime based on inputs.\n"
|
||||
header += f" Trigger Inputs: {sorted(self.trigger_inputs)}\n"
|
||||
header += f" Trigger Inputs: {[inp for inp in self.trigger_inputs if inp is not None]}\n"
|
||||
header += " " + "=" * 100 + "\n\n"
|
||||
|
||||
# Format description with proper indentation
|
||||
@@ -738,20 +729,31 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
|
||||
expected_configs = getattr(self, "expected_configs", [])
|
||||
configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
|
||||
|
||||
# Blocks section
|
||||
# Blocks section - moved to the end with simplified format
|
||||
blocks_str = " Sub-Blocks:\n"
|
||||
for i, (name, block) in enumerate(self.sub_blocks.items()):
|
||||
if name == self.default_block_name:
|
||||
addtional_str = " [default]"
|
||||
# Get trigger input for this block
|
||||
trigger = None
|
||||
if hasattr(self, "block_to_trigger_map"):
|
||||
trigger = self.block_to_trigger_map.get(name)
|
||||
# Format the trigger info
|
||||
if trigger is None:
|
||||
trigger_str = "[default]"
|
||||
elif isinstance(trigger, (list, tuple)):
|
||||
trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]"
|
||||
else:
|
||||
trigger_str = f"[trigger: {trigger}]"
|
||||
# For AutoPipelineBlocks, add bullet points
|
||||
blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n"
|
||||
else:
|
||||
addtional_str = ""
|
||||
blocks_str += f" • {name}{addtional_str} ({block.__class__.__name__})\n"
|
||||
# For SequentialPipelineBlocks, show execution order
|
||||
blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
|
||||
|
||||
# Add block description
|
||||
block_desc_lines = block.description.split("\n")
|
||||
indented_desc = block_desc_lines[0]
|
||||
if len(block_desc_lines) > 1:
|
||||
indented_desc += "\n" + "\n".join(" " + line for line in block_desc_lines[1:])
|
||||
desc_lines = block.description.split("\n")
|
||||
indented_desc = desc_lines[0]
|
||||
if len(desc_lines) > 1:
|
||||
indented_desc += "\n" + "\n".join(" " + line for line in desc_lines[1:])
|
||||
blocks_str += f" Description: {indented_desc}\n\n"
|
||||
|
||||
# Build the representation with conditional sections
|
||||
@@ -782,35 +784,6 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
|
||||
)
|
||||
|
||||
|
||||
class AutoPipelineBlocks(ConditionalPipelineBlocks):
|
||||
"""
|
||||
A Pipeline Blocks that automatically selects a block to run based on the presence of trigger inputs.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
|
||||
raise ValueError(
|
||||
f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same."
|
||||
)
|
||||
|
||||
@property
|
||||
def default_block_name(self) -> Optional[str]:
|
||||
"""Derive default_block_name from block_trigger_inputs (None entry)."""
|
||||
if None in self.block_trigger_inputs:
|
||||
idx = self.block_trigger_inputs.index(None)
|
||||
return self.block_names[idx]
|
||||
return None
|
||||
|
||||
def select_block(self, **kwargs) -> Optional[str]:
|
||||
"""Select block based on which trigger input is present (not None)."""
|
||||
for trigger_input, block_name in zip(self.block_trigger_inputs, self.block_names):
|
||||
if trigger_input is not None and kwargs.get(trigger_input) is not None:
|
||||
return block_name
|
||||
return None
|
||||
|
||||
|
||||
class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
"""
|
||||
A Pipeline Blocks that combines multiple pipeline block classes into one. When called, it will call each block in
|
||||
@@ -912,8 +885,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
# Only add outputs if the block cannot be skipped
|
||||
should_add_outputs = True
|
||||
if isinstance(block, ConditionalPipelineBlocks) and block.default_block_name is None:
|
||||
# ConditionalPipelineBlocks without default can be skipped
|
||||
if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs:
|
||||
should_add_outputs = False
|
||||
|
||||
if should_add_outputs:
|
||||
@@ -976,7 +948,8 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
def _get_trigger_inputs(self):
|
||||
"""
|
||||
Returns a set of all unique trigger input values found in the blocks.
|
||||
Returns a set of all unique trigger input values found in the blocks. Returns: Set[str] containing all unique
|
||||
block_trigger_inputs values
|
||||
"""
|
||||
|
||||
def fn_recursive_get_trigger(blocks):
|
||||
@@ -984,8 +957,9 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
if blocks is not None:
|
||||
for name, block in blocks.items():
|
||||
# Check if current block has block_trigger_inputs (ConditionalPipelineBlocks)
|
||||
# Check if current block has trigger inputs(i.e. auto block)
|
||||
if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None:
|
||||
# Add all non-None values from the trigger inputs list
|
||||
trigger_values.update(t for t in block.block_trigger_inputs if t is not None)
|
||||
|
||||
# If block has sub_blocks, recursively check them
|
||||
@@ -1001,85 +975,82 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
def trigger_inputs(self):
|
||||
return self._get_trigger_inputs()
|
||||
|
||||
def _traverse_trigger_blocks(self, active_inputs):
|
||||
"""
|
||||
Traverse blocks and select which ones would run given the active inputs.
|
||||
def _traverse_trigger_blocks(self, trigger_inputs):
|
||||
# Convert trigger_inputs to a set for easier manipulation
|
||||
active_triggers = set(trigger_inputs)
|
||||
|
||||
Args:
|
||||
active_inputs: Dict of input names to values that are "present"
|
||||
|
||||
Returns:
|
||||
OrderedDict of block_name -> block that would execute
|
||||
"""
|
||||
|
||||
def fn_recursive_traverse(block, block_name, active_inputs):
|
||||
def fn_recursive_traverse(block, block_name, active_triggers):
|
||||
result_blocks = OrderedDict()
|
||||
|
||||
# ConditionalPipelineBlocks (includes AutoPipelineBlocks)
|
||||
if isinstance(block, ConditionalPipelineBlocks):
|
||||
trigger_kwargs = {name: active_inputs.get(name) for name in block.block_trigger_inputs}
|
||||
selected_block_name = block.select_block(**trigger_kwargs)
|
||||
|
||||
if selected_block_name is None:
|
||||
selected_block_name = block.default_block_name
|
||||
|
||||
if selected_block_name is None:
|
||||
return result_blocks
|
||||
|
||||
selected_block = block.sub_blocks[selected_block_name]
|
||||
|
||||
if selected_block.sub_blocks:
|
||||
result_blocks.update(fn_recursive_traverse(selected_block, block_name, active_inputs))
|
||||
# sequential(include loopsequential) or PipelineBlock
|
||||
if not hasattr(block, "block_trigger_inputs"):
|
||||
if block.sub_blocks:
|
||||
# sequential or LoopSequentialPipelineBlocks (keep traversing)
|
||||
for sub_block_name, sub_block in block.sub_blocks.items():
|
||||
blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers)
|
||||
blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers)
|
||||
blocks_to_update = {f"{block_name}.{k}": v for k, v in blocks_to_update.items()}
|
||||
result_blocks.update(blocks_to_update)
|
||||
else:
|
||||
result_blocks[block_name] = selected_block
|
||||
if hasattr(selected_block, "outputs"):
|
||||
for out in selected_block.outputs:
|
||||
active_inputs[out.name] = True
|
||||
|
||||
# PipelineBlock
|
||||
result_blocks[block_name] = block
|
||||
# Add this block's output names to active triggers if defined
|
||||
if hasattr(block, "outputs"):
|
||||
active_triggers.update(out.name for out in block.outputs)
|
||||
return result_blocks
|
||||
|
||||
# SequentialPipelineBlocks or LoopSequentialPipelineBlocks
|
||||
if block.sub_blocks:
|
||||
for sub_block_name, sub_block in block.sub_blocks.items():
|
||||
blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_inputs)
|
||||
blocks_to_update = {f"{block_name}.{k}": v for k, v in blocks_to_update.items()}
|
||||
result_blocks.update(blocks_to_update)
|
||||
# auto
|
||||
else:
|
||||
result_blocks[block_name] = block
|
||||
if hasattr(block, "outputs"):
|
||||
for out in block.outputs:
|
||||
active_inputs[out.name] = True
|
||||
# Find first block_trigger_input that matches any value in our active_triggers
|
||||
this_block = None
|
||||
for trigger_input in block.block_trigger_inputs:
|
||||
if trigger_input is not None and trigger_input in active_triggers:
|
||||
this_block = block.trigger_to_block_map[trigger_input]
|
||||
break
|
||||
|
||||
# If no matches found, try to get the default (None) block
|
||||
if this_block is None and None in block.block_trigger_inputs:
|
||||
this_block = block.trigger_to_block_map[None]
|
||||
|
||||
if this_block is not None:
|
||||
# sequential/auto (keep traversing)
|
||||
if this_block.sub_blocks:
|
||||
result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers))
|
||||
else:
|
||||
# PipelineBlock
|
||||
result_blocks[block_name] = this_block
|
||||
# Add this block's output names to active triggers if defined
|
||||
# YiYi TODO: do we need outputs here? can it just be intermediate_outputs? can we get rid of outputs attribute?
|
||||
if hasattr(this_block, "outputs"):
|
||||
active_triggers.update(out.name for out in this_block.outputs)
|
||||
|
||||
return result_blocks
|
||||
|
||||
all_blocks = OrderedDict()
|
||||
for block_name, block in self.sub_blocks.items():
|
||||
blocks_to_update = fn_recursive_traverse(block, block_name, active_inputs)
|
||||
blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers)
|
||||
all_blocks.update(blocks_to_update)
|
||||
return all_blocks
|
||||
|
||||
def get_execution_blocks(self, **kwargs):
|
||||
"""
|
||||
Get the blocks that would execute given the specified inputs.
|
||||
def get_execution_blocks(self, *trigger_inputs):
|
||||
trigger_inputs_all = self.trigger_inputs
|
||||
|
||||
Args:
|
||||
**kwargs: Input names and values. Only trigger inputs affect block selection.
|
||||
Pass any inputs that would be non-None at runtime.
|
||||
if trigger_inputs is not None:
|
||||
if not isinstance(trigger_inputs, (list, tuple, set)):
|
||||
trigger_inputs = [trigger_inputs]
|
||||
invalid_inputs = [x for x in trigger_inputs if x not in trigger_inputs_all]
|
||||
if invalid_inputs:
|
||||
logger.warning(
|
||||
f"The following trigger inputs will be ignored as they are not supported: {invalid_inputs}"
|
||||
)
|
||||
trigger_inputs = [x for x in trigger_inputs if x in trigger_inputs_all]
|
||||
|
||||
Returns:
|
||||
SequentialPipelineBlocks containing only the blocks that would execute
|
||||
|
||||
Example:
|
||||
# Get blocks for inpainting workflow
|
||||
blocks = pipeline.get_execution_blocks(prompt="a cat", mask=mask, image=image)
|
||||
|
||||
# Get blocks for text2image workflow
|
||||
blocks = pipeline.get_execution_blocks(prompt="a cat")
|
||||
"""
|
||||
# Filter out None values
|
||||
active_inputs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
blocks_triggered = self._traverse_trigger_blocks(active_inputs)
|
||||
if trigger_inputs is None:
|
||||
if None in trigger_inputs_all:
|
||||
trigger_inputs = [None]
|
||||
else:
|
||||
trigger_inputs = [trigger_inputs_all[0]]
|
||||
blocks_triggered = self._traverse_trigger_blocks(trigger_inputs)
|
||||
return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered)
|
||||
|
||||
def __repr__(self):
|
||||
@@ -1096,7 +1067,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
header += f" Trigger Inputs: {[inp for inp in self.trigger_inputs if inp is not None]}\n"
|
||||
# Get first trigger input as example
|
||||
example_input = next(t for t in self.trigger_inputs if t is not None)
|
||||
header += f" Use `get_execution_blocks()` to see selected blocks (e.g. `get_execution_blocks({example_input}=...)`).\n"
|
||||
header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n"
|
||||
header += " " + "=" * 100 + "\n\n"
|
||||
|
||||
# Format description with proper indentation
|
||||
@@ -1120,9 +1091,22 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
# Blocks section - moved to the end with simplified format
|
||||
blocks_str = " Sub-Blocks:\n"
|
||||
for i, (name, block) in enumerate(self.sub_blocks.items()):
|
||||
|
||||
# show execution order
|
||||
blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
|
||||
# Get trigger input for this block
|
||||
trigger = None
|
||||
if hasattr(self, "block_to_trigger_map"):
|
||||
trigger = self.block_to_trigger_map.get(name)
|
||||
# Format the trigger info
|
||||
if trigger is None:
|
||||
trigger_str = "[default]"
|
||||
elif isinstance(trigger, (list, tuple)):
|
||||
trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]"
|
||||
else:
|
||||
trigger_str = f"[trigger: {trigger}]"
|
||||
# For AutoPipelineBlocks, add bullet points
|
||||
blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n"
|
||||
else:
|
||||
# For SequentialPipelineBlocks, show execution order
|
||||
blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
|
||||
|
||||
# Add block description
|
||||
desc_lines = block.description.split("\n")
|
||||
@@ -1246,9 +1230,15 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
if inp.name not in outputs and inp not in inputs:
|
||||
inputs.append(inp)
|
||||
|
||||
# Add this block's outputs
|
||||
block_intermediate_outputs = [out.name for out in block.intermediate_outputs]
|
||||
outputs.update(block_intermediate_outputs)
|
||||
# Only add outputs if the block cannot be skipped
|
||||
should_add_outputs = True
|
||||
if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs:
|
||||
should_add_outputs = False
|
||||
|
||||
if should_add_outputs:
|
||||
# Add this block's outputs
|
||||
block_intermediate_outputs = [out.name for out in block.intermediate_outputs]
|
||||
outputs.update(block_intermediate_outputs)
|
||||
|
||||
for input_param in inputs:
|
||||
if input_param.name in self.required_inputs:
|
||||
@@ -1305,14 +1295,6 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
sub_blocks[block_name] = block
|
||||
self.sub_blocks = sub_blocks
|
||||
|
||||
# Validate that sub_blocks are only leaf blocks
|
||||
for block_name, block in self.sub_blocks.items():
|
||||
if block.sub_blocks:
|
||||
raise ValueError(
|
||||
f"In {self.__class__.__name__}, sub_blocks must be leaf blocks (no sub_blocks). "
|
||||
f"Block '{block_name}' ({block.__class__.__name__}) has sub_blocks."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "LoopSequentialPipelineBlocks":
|
||||
"""
|
||||
|
||||
@@ -21,16 +21,21 @@ except OptionalDependencyNotAvailable:

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["modular_blocks_qwenimage"] = [
    _import_structure["encoders"] = ["QwenImageTextEncoderStep"]
    _import_structure["modular_blocks"] = [
        "ALL_BLOCKS",
        "AUTO_BLOCKS",
        "QwenImageAutoBlocks",
    ]
    _import_structure["modular_blocks_qwenimage_edit"] = [
        "CONTROLNET_BLOCKS",
        "EDIT_AUTO_BLOCKS",
        "QwenImageEditAutoBlocks",
    ]
    _import_structure["modular_blocks_qwenimage_edit_plus"] = [
        "EDIT_BLOCKS",
        "EDIT_INPAINT_BLOCKS",
        "EDIT_PLUS_AUTO_BLOCKS",
        "EDIT_PLUS_BLOCKS",
        "IMAGE2IMAGE_BLOCKS",
        "INPAINT_BLOCKS",
        "TEXT2IMAGE_BLOCKS",
        "QwenImageAutoBlocks",
        "QwenImageEditAutoBlocks",
        "QwenImageEditPlusAutoBlocks",
    ]
    _import_structure["modular_pipeline"] = [
@@ -46,16 +51,23 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
    from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
    from .modular_blocks_qwenimage import (
    from .encoders import (
        QwenImageTextEncoderStep,
    )
    from .modular_blocks import (
        ALL_BLOCKS,
        AUTO_BLOCKS,
        QwenImageAutoBlocks,
    )
    from .modular_blocks_qwenimage_edit import (
        CONTROLNET_BLOCKS,
        EDIT_AUTO_BLOCKS,
        QwenImageEditAutoBlocks,
    )
    from .modular_blocks_qwenimage_edit_plus import (
        EDIT_BLOCKS,
        EDIT_INPAINT_BLOCKS,
        EDIT_PLUS_AUTO_BLOCKS,
        EDIT_PLUS_BLOCKS,
        IMAGE2IMAGE_BLOCKS,
        INPAINT_BLOCKS,
        TEXT2IMAGE_BLOCKS,
        QwenImageAutoBlocks,
        QwenImageEditAutoBlocks,
        QwenImageEditPlusAutoBlocks,
    )
    from .modular_pipeline import (
@@ -74,4 +86,4 @@ else:
    )

for name, value in _dummy_objects.items():
    setattr(sys.modules[__name__], name, value)
    setattr(sys.modules[__name__], name, value)
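Assuming the lazy-import table above is what ships, the new symbols should be importable from the subpackage roughly like this; the exact public path is an assumption (the blocks may also be re-exported at a higher level), and the no-argument constructor is assumed from how the auto blocks are used elsewhere in this diff.

```python
# Assumed import path; adjust if the blocks are re-exported elsewhere.
from diffusers.modular_pipelines.qwenimage import (
    QwenImageAutoBlocks,
    QwenImageEditAutoBlocks,
    QwenImageEditPlusAutoBlocks,
)

blocks = QwenImageEditPlusAutoBlocks()
print(blocks)  # the __repr__ shown elsewhere in this diff lists sub-blocks and trigger inputs
```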
|
||||
|
||||
@@ -639,65 +639,19 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageEditPlusRoPEInputsStep(ModularPipelineBlocks):
|
||||
"""RoPE inputs step for Edit Plus that handles lists of image heights/widths."""
|
||||
|
||||
class QwenImageEditPlusRoPEInputsStep(QwenImageEditRoPEInputsStep):
|
||||
model_name = "qwenimage-edit-plus"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit Plus.\n"
|
||||
"Unlike Edit, Edit Plus handles lists of image_height/image_width for multiple reference images.\n"
|
||||
"Should be placed after prepare_latents step."
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="batch_size", required=True),
|
||||
InputParam(name="image_height", required=True, type_hint=List[int]),
|
||||
InputParam(name="image_width", required=True, type_hint=List[int]),
|
||||
InputParam(name="height", required=True),
|
||||
InputParam(name="width", required=True),
|
||||
InputParam(name="prompt_embeds_mask"),
|
||||
InputParam(name="negative_prompt_embeds_mask"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="img_shapes",
|
||||
type_hint=List[List[Tuple[int, int, int]]],
|
||||
description="The shapes of the image latents, used for RoPE calculation",
|
||||
),
|
||||
OutputParam(
|
||||
name="txt_seq_lens",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=List[int],
|
||||
description="The sequence lengths of the prompt embeds, used for RoPE calculation",
|
||||
),
|
||||
OutputParam(
|
||||
name="negative_txt_seq_lens",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=List[int],
|
||||
description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
|
||||
),
|
||||
]
|
||||
|
||||
    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        vae_scale_factor = components.vae_scale_factor

        # Edit Plus: image_height and image_width are lists
        block_state.img_shapes = [
            [
                (1, block_state.height // vae_scale_factor // 2, block_state.width // vae_scale_factor // 2),
                *[
                    (1, img_height // vae_scale_factor // 2, img_width // vae_scale_factor // 2)
                    for img_height, img_width in zip(block_state.image_height, block_state.image_width)
                    (1, vae_height // vae_scale_factor // 2, vae_width // vae_scale_factor // 2)
                    for vae_height, vae_width in zip(block_state.image_height, block_state.image_width)
                ],
            ]
        ] * block_state.batch_size
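A small worked example of the grid sizes that end up in `img_shapes` above. The pixel sizes and `vae_scale_factor = 8` are illustrative assumptions, and the extra `// 2` reflects the 2x2 packing of latents into transformer tokens.

```python
vae_scale_factor = 8  # assumed value for illustration

def packed_grid(height, width):
    # (frames, latent_height // 2, latent_width // 2), as used for RoPE above
    return (1, height // vae_scale_factor // 2, width // vae_scale_factor // 2)

height, width = 1024, 1024               # target output size
image_sizes = [(1024, 768), (512, 512)]  # two reference images with different sizes

img_shapes = [
    [packed_grid(height, width), *[packed_grid(h, w) for h, w in image_sizes]]
] * 2  # batch_size = 2

print(img_shapes[0])  # [(1, 64, 64), (1, 64, 48), (1, 32, 32)]
```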
|
||||
|
||||
@@ -244,19 +244,18 @@ def encode_vae_image(
|
||||
class QwenImageEditResizeDynamicStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_name: str = "image",
|
||||
output_name: str = "resized_image",
|
||||
target_area: int = 1024 * 1024,
|
||||
):
|
||||
"""Create a configurable step for resizing images to the target area while maintaining the aspect ratio.
|
||||
def __init__(self, input_name: str = "image", output_name: str = "resized_image"):
|
||||
"""Create a configurable step for resizing images to the target area (1024 * 1024) while maintaining the aspect ratio.
|
||||
|
||||
This block resizes an input image tensor and exposes the resized result under configurable input and output
|
||||
names. Use this when you need to wire the resize step to different image fields (e.g., "image",
|
||||
"control_image")
|
||||
|
||||
Args:
|
||||
input_name (str, optional): Name of the image field to read from the
|
||||
pipeline state. Defaults to "image".
|
||||
output_name (str, optional): Name of the resized image field to write
|
||||
back to the pipeline state. Defaults to "resized_image".
|
||||
target_area (int, optional): Target area in pixels. Defaults to 1024*1024.
|
||||
"""
|
||||
if not isinstance(input_name, str) or not isinstance(output_name, str):
|
||||
raise ValueError(
|
||||
@@ -264,12 +263,11 @@ class QwenImageEditResizeDynamicStep(ModularPipelineBlocks):
|
||||
)
|
||||
self._image_input_name = input_name
|
||||
self._resized_image_output_name = output_name
|
||||
self._target_area = target_area
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return f"Image Resize step that resize the {self._image_input_name} to the target area {self._target_area} while maintaining the aspect ratio."
|
||||
return f"Image Resize step that resize the {self._image_input_name} to the target area (1024 * 1024) while maintaining the aspect ratio."
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
@@ -322,67 +320,48 @@ class QwenImageEditResizeDynamicStep(ModularPipelineBlocks):
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
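The resize steps above lean on `calculate_dimensions(target_area, aspect_ratio)` to pick a size with roughly the requested area at the same aspect ratio. Below is a rough sketch of that idea, assuming rounding to multiples of 32; the real helper may snap to a different multiple and also returns a third value, so treat this as illustration only.

```python
import math

def sketch_calculate_dimensions(target_area, ratio, multiple=32):
    # Choose width, height with width / height == ratio and width * height close to target_area.
    width = math.sqrt(target_area * ratio)
    height = width / ratio
    width = round(width / multiple) * multiple
    height = round(height / multiple) * multiple
    return width, height

print(sketch_calculate_dimensions(1024 * 1024, 1536 / 1024))  # (1248, 832) for a 3:2 source
print(sketch_calculate_dimensions(384 * 384, 1.0))            # (384, 384) for the VL condition image
```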
|
||||
|
||||
class QwenImageEditPlusResizeDynamicStep(ModularPipelineBlocks):
|
||||
"""Resize each image independently based on its own aspect ratio. For QwenImage Edit Plus."""
|
||||
|
||||
model_name = "qwenimage-edit-plus"
|
||||
class QwenImageEditPlusResizeDynamicStep(QwenImageEditResizeDynamicStep):
|
||||
model_name = "qwenimage"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_name: str = "image",
|
||||
self,
|
||||
input_name: str = "image",
|
||||
output_name: str = "resized_image",
|
||||
target_area: int = 1024 * 1024,
|
||||
vae_image_output_name: str = "vae_image",
|
||||
):
|
||||
"""Create a step for resizing images to a target area.
|
||||
"""Create a configurable step for resizing images to the target area (384 * 384) while maintaining the aspect ratio.
|
||||
|
||||
Each image is resized independently based on its own aspect ratio.
|
||||
This is suitable for Edit Plus where multiple reference images can have different dimensions.
|
||||
This block resizes an input image or a list input images and exposes the resized result under configurable
|
||||
input and output names. Use this when you need to wire the resize step to different image fields (e.g.,
|
||||
"image", "control_image")
|
||||
|
||||
Args:
|
||||
input_name (str, optional): Name of the image field to read. Defaults to "image".
|
||||
output_name (str, optional): Name of the resized image field to write. Defaults to "resized_image".
|
||||
target_area (int, optional): Target area in pixels. Defaults to 1024*1024.
|
||||
input_name (str, optional): Name of the image field to read from the
|
||||
pipeline state. Defaults to "image".
|
||||
output_name (str, optional): Name of the resized image field to write
|
||||
back to the pipeline state. Defaults to "resized_image".
|
||||
vae_image_output_name (str, optional): Name of the image field
|
||||
to write back to the pipeline state. This is used by the VAE encoder step later on. QwenImage Edit Plus
|
||||
processes the input image(s) differently for the VL and the VAE.
|
||||
"""
|
||||
if not isinstance(input_name, str) or not isinstance(output_name, str):
|
||||
raise ValueError(
|
||||
f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}"
|
||||
)
|
||||
self.condition_image_size = 384 * 384
|
||||
self._image_input_name = input_name
|
||||
self._resized_image_output_name = output_name
|
||||
self._target_area = target_area
|
||||
self._vae_image_output_name = vae_image_output_name
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
f"Image Resize step that resizes {self._image_input_name} to target area {self._target_area}.\n"
|
||||
"Each image is resized independently based on its own aspect ratio."
|
||||
)
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec(
|
||||
"image_resize_processor",
|
||||
VaeImageProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 16}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
name=self._image_input_name, required=True, type_hint=torch.Tensor, description="The image(s) to resize"
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
return super().intermediate_outputs + [
|
||||
OutputParam(
|
||||
name=self._resized_image_output_name, type_hint=List[PIL.Image.Image], description="The resized images"
|
||||
name=self._vae_image_output_name,
|
||||
type_hint=List[PIL.Image.Image],
|
||||
description="The images to be processed which will be further used by the VAE encoder.",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -395,21 +374,26 @@ class QwenImageEditPlusResizeDynamicStep(ModularPipelineBlocks):
|
||||
if not is_valid_image_imagelist(images):
|
||||
raise ValueError(f"Images must be image or list of images but are {type(images)}")
|
||||
|
||||
if is_valid_image(images):
|
||||
if (
|
||||
not isinstance(images, torch.Tensor)
|
||||
and isinstance(images, PIL.Image.Image)
|
||||
and not isinstance(images, list)
|
||||
):
|
||||
images = [images]
|
||||
|
||||
# Resize each image independently based on its own aspect ratio
|
||||
resized_images = []
|
||||
for image in images:
|
||||
image_width, image_height = image.size
|
||||
calculated_width, calculated_height, _ = calculate_dimensions(
|
||||
self._target_area, image_width / image_height
|
||||
)
|
||||
resized_images.append(
|
||||
components.image_resize_processor.resize(image, height=calculated_height, width=calculated_width)
|
||||
# TODO (sayakpaul): revisit this when the inputs are `torch.Tensor`s
|
||||
condition_images = []
|
||||
vae_images = []
|
||||
for img in images:
|
||||
image_width, image_height = img.size
|
||||
condition_width, condition_height, _ = calculate_dimensions(
|
||||
self.condition_image_size, image_width / image_height
|
||||
)
|
||||
condition_images.append(components.image_resize_processor.resize(img, condition_height, condition_width))
|
||||
vae_images.append(img)
|
||||
|
||||
setattr(block_state, self._resized_image_output_name, resized_images)
|
||||
setattr(block_state, self._resized_image_output_name, condition_images)
|
||||
setattr(block_state, self._vae_image_output_name, vae_images)
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
@@ -663,30 +647,8 @@ class QwenImageEditTextEncoderStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageEditPlusTextEncoderStep(ModularPipelineBlocks):
|
||||
"""Text encoder for QwenImage Edit Plus that handles multiple reference images."""
|
||||
|
||||
model_name = "qwenimage-edit-plus"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Text Encoder step for QwenImage Edit Plus that processes prompt and multiple images together "
|
||||
"to generate text embeddings for guiding image generation."
|
||||
)
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration),
|
||||
ComponentSpec("processor", Qwen2VLProcessor),
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 4.0}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
class QwenImageEditPlusTextEncoderStep(QwenImageEditTextEncoderStep):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def expected_configs(self) -> List[ConfigSpec]:
|
||||
@@ -702,60 +664,6 @@ class QwenImageEditPlusTextEncoderStep(ModularPipelineBlocks):
|
||||
ConfigSpec(name="prompt_template_encode_start_idx", default=64),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"),
|
||||
InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"),
|
||||
InputParam(
|
||||
name="resized_cond_image",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The image(s) to encode, can be a single image or list of images, should be resized to 384x384 using resize step",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="prompt_embeds",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="The prompt embeddings",
|
||||
),
|
||||
OutputParam(
|
||||
name="prompt_embeds_mask",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="The encoder attention mask",
|
||||
),
|
||||
OutputParam(
|
||||
name="negative_prompt_embeds",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="The negative prompt embeddings",
|
||||
),
|
||||
OutputParam(
|
||||
name="negative_prompt_embeds_mask",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="The negative prompt embeddings mask",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(prompt, negative_prompt):
|
||||
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if (
|
||||
negative_prompt is not None
|
||||
and not isinstance(negative_prompt, str)
|
||||
and not isinstance(negative_prompt, list)
|
||||
):
|
||||
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
|
||||
block_state = self.get_block_state(state)
|
||||
@@ -768,7 +676,7 @@ class QwenImageEditPlusTextEncoderStep(ModularPipelineBlocks):
|
||||
components.text_encoder,
|
||||
components.processor,
|
||||
prompt=block_state.prompt,
|
||||
image=block_state.resized_cond_image,
|
||||
image=block_state.resized_image,
|
||||
prompt_template_encode=components.config.prompt_template_encode,
|
||||
img_template_encode=components.config.img_template_encode,
|
||||
prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
|
||||
@@ -784,7 +692,7 @@ class QwenImageEditPlusTextEncoderStep(ModularPipelineBlocks):
|
||||
components.text_encoder,
|
||||
components.processor,
|
||||
prompt=negative_prompt,
|
||||
image=block_state.resized_cond_image,
|
||||
image=block_state.resized_image,
|
||||
prompt_template_encode=components.config.prompt_template_encode,
|
||||
img_template_encode=components.config.img_template_encode,
|
||||
prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
|
||||
@@ -938,60 +846,60 @@ class QwenImageProcessImagesInputStep(ModularPipelineBlocks):
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
class QwenImageEditPlusProcessImagesInputStep(ModularPipelineBlocks):
|
||||
|
||||
class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep):
|
||||
model_name = "qwenimage-edit-plus"
|
||||
|
||||
def __init__(self):
|
||||
self.vae_image_size = 1024 * 1024
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Image Preprocess step. Images can be resized first using QwenImageEditResizeDynamicStep."
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec(
|
||||
"image_processor",
|
||||
VaeImageProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 16}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
return "Image Preprocess step for QwenImage Edit Plus. Unlike QwenImage Edit, QwenImage Edit Plus doesn't use the same resized image for further preprocessing."
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [InputParam("resized_image")]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [OutputParam(name="processed_image")]
|
||||
return [InputParam("vae_image"), InputParam("image"), InputParam("height"), InputParam("width")]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
if block_state.vae_image is None and block_state.image is None:
|
||||
raise ValueError("`vae_image` and `image` cannot be None at the same time")
|
||||
|
||||
|
||||
image = block_state.resized_image
|
||||
|
||||
is_image_list = isinstance(image, list)
|
||||
if not is_image_list:
|
||||
image = [image]
|
||||
|
||||
processed_images = []
|
||||
for img in image:
|
||||
img_width, img_height = img.size
|
||||
processed_images.append(components.image_processor.preprocess(image=img, height=img_height, width=img_width))
|
||||
block_state.processed_image = processed_images
|
||||
if is_image_list:
|
||||
block_state.processed_image = processed_images
|
||||
vae_image_sizes = None
|
||||
if block_state.vae_image is None:
|
||||
image = block_state.image
|
||||
self.check_inputs(
|
||||
height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
|
||||
)
|
||||
height = block_state.height or components.default_height
|
||||
width = block_state.width or components.default_width
|
||||
block_state.processed_image = components.image_processor.preprocess(
|
||||
image=image, height=height, width=width
|
||||
)
|
||||
else:
|
||||
block_state.processed_image = processed_images[0]
|
||||
# QwenImage Edit Plus can allow multiple input images with varied resolutions
|
||||
processed_images = []
|
||||
vae_image_sizes = []
|
||||
for img in block_state.vae_image:
|
||||
width, height = img.size
|
||||
vae_width, vae_height, _ = calculate_dimensions(self.vae_image_size, width / height)
|
||||
vae_image_sizes.append((vae_width, vae_height))
|
||||
processed_images.append(
|
||||
components.image_processor.preprocess(image=img, height=vae_height, width=vae_width)
|
||||
)
|
||||
block_state.processed_image = processed_images
|
||||
|
||||
block_state.vae_image_sizes = vae_image_sizes
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
|
||||
"""VAE encoder that handles both single images and lists of images with varied resolutions."""
|
||||
|
||||
class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
def __init__(
|
||||
@@ -1001,12 +909,21 @@ class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
|
||||
):
|
||||
"""Initialize a VAE encoder step for converting images to latent representations.
|
||||
|
||||
Handles both single images and lists of images. When input is a list, outputs a list of latents.
|
||||
When input is a single tensor, outputs a single latent tensor.
|
||||
Both the input and output names are configurable so this block can be configured to process to different image
|
||||
inputs (e.g., "processed_image" -> "image_latents", "processed_control_image" -> "control_image_latents").
|
||||
|
||||
Args:
|
||||
input_name (str, optional): Name of the input image tensor or list. Defaults to "processed_image".
|
||||
output_name (str, optional): Name of the output latent tensor or list. Defaults to "image_latents".
|
||||
input_name (str, optional): Name of the input image tensor. Defaults to "processed_image".
|
||||
Examples: "processed_image" or "processed_control_image"
|
||||
output_name (str, optional): Name of the output latent tensor. Defaults to "image_latents".
|
||||
Examples: "image_latents" or "control_image_latents"
|
||||
|
||||
Examples:
|
||||
# Basic usage with default settings (includes image processor) QwenImageVaeEncoderDynamicStep()
|
||||
|
||||
# Custom input/output names for control image QwenImageVaeEncoderDynamicStep(
|
||||
input_name="processed_control_image", output_name="control_image_latents"
|
||||
)
|
||||
"""
|
||||
self._image_input_name = input_name
|
||||
self._image_latents_output_name = output_name
|
||||
@@ -1014,18 +931,17 @@ class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
f"VAE Encoder step that converts {self._image_input_name} into latent representations {self._image_latents_output_name}.\n"
|
||||
"Handles both single images and lists of images with varied resolutions."
|
||||
)
|
||||
return f"Dynamic VAE Encoder step that converts {self._image_input_name} into latent representations {self._image_latents_output_name}.\n"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [ComponentSpec("vae", AutoencoderKLQwenImage)]
|
||||
components = [ComponentSpec("vae", AutoencoderKLQwenImage)]
|
||||
return components
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [InputParam(self._image_input_name, required=True), InputParam("generator")]
|
||||
inputs = [InputParam(self._image_input_name, required=True), InputParam("generator")]
|
||||
return inputs
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
@@ -1033,7 +949,7 @@ class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
|
||||
OutputParam(
|
||||
self._image_latents_output_name,
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents representing the reference image(s). Single tensor or list depending on input.",
|
||||
description="The latents representing the reference image",
|
||||
)
|
||||
]
|
||||
|
||||
@@ -1045,11 +961,47 @@ class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
|
||||
dtype = components.vae.dtype
|
||||
|
||||
image = getattr(block_state, self._image_input_name)
|
||||
is_image_list = isinstance(image, list)
|
||||
if not is_image_list:
|
||||
image = [image]
|
||||
|
||||
# Handle both single image and list of images
|
||||
# Encode image into latents
|
||||
image_latents = encode_vae_image(
|
||||
image=image,
|
||||
vae=components.vae,
|
||||
generator=block_state.generator,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
latent_channels=components.num_channels_latents,
|
||||
)
|
||||
setattr(block_state, self._image_latents_output_name, image_latents)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
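For readers unfamiliar with `encode_vae_image`, the gist is: move pixels to the VAE's device and dtype, encode, sample the latent distribution, and normalize. The sketch below is a simplification under those assumptions; it is not the actual helper, and QwenImage's video-shaped latents and mean/std normalization are glossed over.

```python
import torch

def sketch_encode_vae_image(image, vae, generator=None, device="cpu", dtype=torch.float32):
    # Encode pixels into a sampled latent (simplified; real helper handles more cases).
    image = image.to(device=device, dtype=dtype)
    posterior = vae.encode(image).latent_dist
    latents = posterior.sample(generator=generator)
    # Many diffusers VAEs normalize latents with config scaling/shift factors.
    scaling = getattr(vae.config, "scaling_factor", 1.0)
    shift = getattr(vae.config, "shift_factor", None)
    if shift is not None:
        latents = (latents - shift) * scaling
    else:
        latents = latents * scaling
    return latents
```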
|
||||
|
||||
|
||||
class QwenImageEditPlusVaeEncoderDynamicStep(QwenImageVaeEncoderDynamicStep):
|
||||
model_name = "qwenimage-edit-plus"
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
# Each reference image latent can have varied resolutions hence we return this as a list.
|
||||
return [
|
||||
OutputParam(
|
||||
self._image_latents_output_name,
|
||||
type_hint=List[torch.Tensor],
|
||||
description="The latents representing the reference image(s).",
|
||||
)
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
device = components._execution_device
|
||||
dtype = components.vae.dtype
|
||||
|
||||
image = getattr(block_state, self._image_input_name)
|
||||
|
||||
# Encode image into latents
|
||||
image_latents = []
|
||||
for img in image:
|
||||
image_latents.append(
|
||||
@@ -1062,12 +1014,9 @@ class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
|
||||
latent_channels=components.num_channels_latents,
|
||||
)
|
||||
)
|
||||
if not is_image_list:
|
||||
image_latents = image_latents[0]
|
||||
|
||||
setattr(block_state, self._image_latents_output_name, image_latents)
|
||||
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
@@ -222,15 +222,36 @@ class QwenImageTextInputsStep(ModularPipelineBlocks):
|
||||
|
||||
|
||||
class QwenImageInputsDynamicStep(ModularPipelineBlocks):
|
||||
"""Input step for QwenImage: update height/width, expand batch, patchify."""
|
||||
|
||||
model_name = "qwenimage"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_latent_inputs: List[str] = ["image_latents"],
|
||||
additional_batch_inputs: List[str] = [],
|
||||
):
|
||||
def __init__(self, image_latent_inputs: List[str] = ["image_latents"], additional_batch_inputs: List[str] = []):
|
||||
"""Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
|
||||
|
||||
This step handles multiple common tasks to prepare inputs for the denoising step:
|
||||
1. For encoded image latents, use it update height/width if None, patchifies, and expands batch size
|
||||
2. For additional_batch_inputs: Only expands batch dimensions to match final batch size
|
||||
|
||||
This is a dynamic block that allows you to configure which inputs to process.
|
||||
|
||||
Args:
|
||||
image_latent_inputs (List[str], optional): Names of image latent tensors to process.
|
||||
These will be used to determine height/width, patchified, and batch-expanded. Can be a single string or
|
||||
list of strings. Defaults to ["image_latents"]. Examples: ["image_latents"], ["control_image_latents"]
|
||||
additional_batch_inputs (List[str], optional):
|
||||
Names of additional conditional input tensors to expand batch size. These tensors will only have their
|
||||
batch dimensions adjusted to match the final batch size. Can be a single string or list of strings.
|
||||
Defaults to []. Examples: ["processed_mask_image"]
|
||||
|
||||
Examples:
|
||||
# Configure to process image_latents (default behavior) QwenImageInputsDynamicStep()
|
||||
|
||||
# Configure to process multiple image latent inputs
|
||||
QwenImageInputsDynamicStep(image_latent_inputs=["image_latents", "control_image_latents"])
|
||||
|
||||
# Configure to process image latents and additional batch inputs QwenImageInputsDynamicStep(
|
||||
image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
|
||||
)
|
||||
"""
|
||||
if not isinstance(image_latent_inputs, list):
|
||||
image_latent_inputs = [image_latent_inputs]
|
||||
if not isinstance(additional_batch_inputs, list):
|
||||
@@ -242,12 +263,14 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):

@property
def description(self) -> str:
# Functionality section
summary_section = (
"Input processing step that:\n"
" 1. For image latent inputs: Updates height/width if None, patchifies, and expands batch size\n"
" 1. For image latent inputs: Updates height/width if None, patchifies latents, and expands batch size\n"
" 2. For additional batch inputs: Expands batch dimensions to match final batch size"
)

# Inputs info
inputs_info = ""
if self._image_latent_inputs or self._additional_batch_inputs:
inputs_info = "\n\nConfigured inputs:"
@@ -256,16 +279,11 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
if self._additional_batch_inputs:
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"

# Placement guidance
placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."

return summary_section + inputs_info + placement_section

@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
]

@property
def inputs(self) -> List[InputParam]:
inputs = [
@@ -275,9 +293,11 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
InputParam(name="width"),
]

# Add image latent inputs
for image_latent_input_name in self._image_latent_inputs:
inputs.append(InputParam(name=image_latent_input_name))

# Add additional batch inputs
for input_name in self._additional_batch_inputs:
inputs.append(InputParam(name=input_name))

@@ -290,16 +310,22 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
OutputParam(name="image_width", type_hint=int, description="The width of the image latents"),
]

@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
]

def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)

# Process image latent inputs
# Process image latent inputs (height/width calculation, patchify, and batch expansion)
for image_latent_input_name in self._image_latent_inputs:
image_latent_tensor = getattr(block_state, image_latent_input_name)
if image_latent_tensor is None:
continue

# 1. Calculate height/width from latents and update if not provided
# 1. Calculate height/width from latents
height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor)
block_state.height = block_state.height or height
block_state.width = block_state.width or width
@@ -309,7 +335,7 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
if not hasattr(block_state, "image_width"):
block_state.image_width = width

# 2. Patchify
# 2. Patchify the image latent tensor
image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor)

# 3. Expand batch size
@@ -328,6 +354,7 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
if input_tensor is None:
continue

# Only expand batch size
input_tensor = repeat_tensor_to_batch_size(
input_name=input_name,
input_tensor=input_tensor,
@@ -341,130 +368,63 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
return components, state
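For reference, a minimal sketch of configuring this dynamic block; the import path is an assumption, and the arguments simply mirror the docstring above:

```python
# A minimal sketch, assuming QwenImageInputsDynamicStep is exported from the
# qwenimage modular subpackage (the exact import path may differ).
from diffusers.modular_pipelines.qwenimage import QwenImageInputsDynamicStep

# Process two image latent inputs and batch-expand a processed mask tensor.
step = QwenImageInputsDynamicStep(
    image_latent_inputs=["image_latents", "control_image_latents"],
    additional_batch_inputs=["processed_mask_image"],
)
print(step.description)                    # lists the configured inputs
print([inp.name for inp in step.inputs])   # batch_size, height, width plus the configured tensors
```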
class QwenImageEditPlusInputsDynamicStep(ModularPipelineBlocks):
|
||||
"""Input step for QwenImage Edit Plus: handles list of latents with different sizes."""
|
||||
|
||||
class QwenImageEditPlusInputsDynamicStep(QwenImageInputsDynamicStep):
|
||||
model_name = "qwenimage-edit-plus"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_latent_inputs: List[str] = ["image_latents"],
|
||||
additional_batch_inputs: List[str] = [],
|
||||
):
|
||||
if not isinstance(image_latent_inputs, list):
|
||||
image_latent_inputs = [image_latent_inputs]
|
||||
if not isinstance(additional_batch_inputs, list):
|
||||
additional_batch_inputs = [additional_batch_inputs]
|
||||
|
||||
self._image_latent_inputs = image_latent_inputs
|
||||
self._additional_batch_inputs = additional_batch_inputs
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
summary_section = (
|
||||
"Input processing step for Edit Plus that:\n"
|
||||
" 1. For image latent inputs (list): Collects heights/widths, patchifies each, concatenates, expands batch\n"
|
||||
" 2. For additional batch inputs: Expands batch dimensions to match final batch size\n"
|
||||
" Height/width defaults to last image in the list."
|
||||
)
|
||||
|
||||
inputs_info = ""
|
||||
if self._image_latent_inputs or self._additional_batch_inputs:
|
||||
inputs_info = "\n\nConfigured inputs:"
|
||||
if self._image_latent_inputs:
|
||||
inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
|
||||
if self._additional_batch_inputs:
|
||||
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
|
||||
|
||||
placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
|
||||
|
||||
return summary_section + inputs_info + placement_section
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
inputs = [
|
||||
InputParam(name="num_images_per_prompt", default=1),
|
||||
InputParam(name="batch_size", required=True),
|
||||
InputParam(name="height"),
|
||||
InputParam(name="width"),
|
||||
]
|
||||
|
||||
for image_latent_input_name in self._image_latent_inputs:
|
||||
inputs.append(InputParam(name=image_latent_input_name))
|
||||
|
||||
for input_name in self._additional_batch_inputs:
|
||||
inputs.append(InputParam(name=input_name))
|
||||
|
||||
return inputs
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(name="image_height", type_hint=List[int], description="The heights of the image latents"),
|
||||
OutputParam(name="image_width", type_hint=List[int], description="The widths of the image latents"),
|
||||
OutputParam(name="image_height", type_hint=List[int], description="The height of the image latents"),
|
||||
OutputParam(name="image_width", type_hint=List[int], description="The width of the image latents"),
|
||||
]
|
||||
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# Process image latent inputs
|
||||
# Process image latent inputs (height/width calculation, patchify, and batch expansion)
|
||||
for image_latent_input_name in self._image_latent_inputs:
|
||||
image_latent_tensor = getattr(block_state, image_latent_input_name)
|
||||
if image_latent_tensor is None:
|
||||
continue
|
||||
|
||||
is_list = isinstance(image_latent_tensor, list)
|
||||
if not is_list:
|
||||
image_latent_tensor = [image_latent_tensor]
|
||||
|
||||
# Each image latent can have different size in QwenImage Edit Plus.
|
||||
image_heights = []
|
||||
image_widths = []
|
||||
packed_image_latent_tensors = []
|
||||
|
||||
for i, img_latent_tensor in enumerate(image_latent_tensor):
|
||||
for img_latent_tensor in image_latent_tensor:
|
||||
# 1. Calculate height/width from latents
|
||||
height, width = calculate_dimension_from_latents(img_latent_tensor, components.vae_scale_factor)
|
||||
image_heights.append(height)
|
||||
image_widths.append(width)
|
||||
|
||||
# 2. Patchify
|
||||
# 2. Patchify the image latent tensor
|
||||
img_latent_tensor = components.pachifier.pack_latents(img_latent_tensor)
|
||||
|
||||
# 3. Expand batch size
|
||||
img_latent_tensor = repeat_tensor_to_batch_size(
|
||||
input_name=f"{image_latent_input_name}[{i}]",
|
||||
input_name=image_latent_input_name,
|
||||
input_tensor=img_latent_tensor,
|
||||
num_images_per_prompt=block_state.num_images_per_prompt,
|
||||
batch_size=block_state.batch_size,
|
||||
)
|
||||
packed_image_latent_tensors.append(img_latent_tensor)
|
||||
|
||||
# Concatenate all packed latents along dim=1
|
||||
packed_image_latent_tensors = torch.cat(packed_image_latent_tensors, dim=1)
|
||||
|
||||
# Output lists of heights/widths
|
||||
block_state.image_height = image_heights
|
||||
block_state.image_width = image_widths
|
||||
setattr(block_state, image_latent_input_name, packed_image_latent_tensors)
|
||||
|
||||
# Default height/width from last image
|
||||
block_state.height = block_state.height or image_heights[-1]
|
||||
block_state.width = block_state.width or image_widths[-1]
|
||||
|
||||
setattr(block_state, image_latent_input_name, packed_image_latent_tensors)
|
||||
|
||||
# Process additional batch inputs (only batch expansion)
|
||||
for input_name in self._additional_batch_inputs:
|
||||
input_tensor = getattr(block_state, input_name)
|
||||
if input_tensor is None:
|
||||
continue
|
||||
|
||||
# Only expand batch size
|
||||
input_tensor = repeat_tensor_to_batch_size(
|
||||
input_name=input_name,
|
||||
input_tensor=input_tensor,
|
||||
@@ -476,6 +436,8 @@ class QwenImageEditPlusInputsDynamicStep(ModularPipelineBlocks):
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
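Compared to the base input step, the Edit Plus variant accepts a list of image latents with different spatial sizes, packs each one, and concatenates the packed sequences along dim=1. A rough shape sketch; the 2x2 packing layout, 16 latent channels, and VAE scale factor of 8 are assumptions carried over from the base pipeline:

```python
import torch

# Two reference images of different sizes, already encoded to latents.
latents_a = torch.randn(1, 16, 64, 64)  # e.g. a 512x512 image at vae_scale_factor=8
latents_b = torch.randn(1, 16, 96, 64)  # e.g. a 768x512 image

# Assumed pack_latents layout: (B, C, H, W) -> (B, H/2 * W/2, C * 4)
packed_a = latents_a.view(1, 16, 32, 2, 32, 2).permute(0, 2, 4, 1, 3, 5).reshape(1, 32 * 32, 64)
packed_b = latents_b.view(1, 16, 48, 2, 32, 2).permute(0, 2, 4, 1, 3, 5).reshape(1, 48 * 32, 64)

# The step concatenates the packed sequences along dim=1, so differently sized
# reference images can be conditioned on together.
packed = torch.cat([packed_a, packed_b], dim=1)
print(packed.shape)  # torch.Size([1, 2560, 64])
```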
class QwenImageControlNetInputsStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
|
||||
src/diffusers/modular_pipelines/qwenimage/modular_blocks.py (new file, 1113 lines): diff suppressed because it is too large.
@@ -1,465 +0,0 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks, ConditionalPipelineBlocks
|
||||
from ..modular_pipeline_utils import InsertableDict
|
||||
from .before_denoise import (
|
||||
QwenImageControlNetBeforeDenoiserStep,
|
||||
QwenImageCreateMaskLatentsStep,
|
||||
QwenImagePrepareLatentsStep,
|
||||
QwenImagePrepareLatentsWithStrengthStep,
|
||||
QwenImageRoPEInputsStep,
|
||||
QwenImageSetTimestepsStep,
|
||||
QwenImageSetTimestepsWithStrengthStep,
|
||||
)
|
||||
from .decoders import (
|
||||
QwenImageAfterDenoiseStep,
|
||||
QwenImageDecoderStep,
|
||||
QwenImageInpaintProcessImagesOutputStep,
|
||||
QwenImageProcessImagesOutputStep,
|
||||
)
|
||||
from .denoise import (
|
||||
QwenImageControlNetDenoiseStep,
|
||||
QwenImageDenoiseStep,
|
||||
QwenImageInpaintControlNetDenoiseStep,
|
||||
QwenImageInpaintDenoiseStep,
|
||||
QwenImageLoopBeforeDenoiserControlNet,
|
||||
)
|
||||
from .encoders import (
|
||||
QwenImageControlNetVaeEncoderStep,
|
||||
QwenImageInpaintProcessImagesInputStep,
|
||||
QwenImageProcessImagesInputStep,
|
||||
QwenImageTextEncoderStep,
|
||||
QwenImageVaeEncoderDynamicStep,
|
||||
)
|
||||
from .inputs import (
|
||||
QwenImageControlNetInputsStep,
|
||||
QwenImageInputsDynamicStep,
|
||||
QwenImageTextInputsStep,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
|
||||
# 1. VAE ENCODER
|
||||
|
||||
# inpaint vae encoder
|
||||
class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [QwenImageInpaintProcessImagesInputStep(), QwenImageVaeEncoderDynamicStep()]
|
||||
block_names = ["preprocess", "encode"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"This step is used for processing image and mask inputs for inpainting tasks. It:\n"
|
||||
" - Resizes the image to the target size, based on `height` and `width`.\n"
|
||||
" - Processes and updates `image` and `mask_image`.\n"
|
||||
" - Creates `image_latents`."
|
||||
)
|
||||
|
||||
|
||||
# img2img vae encoder
|
||||
class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
block_classes = [QwenImageProcessImagesInputStep(), QwenImageVaeEncoderDynamicStep()]
|
||||
block_names = ["preprocess", "encode"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Vae encoder step that preprocess andencode the image inputs into their latent representations."
|
||||
|
||||
|
||||
# auto vae encoder
|
||||
class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageInpaintVaeEncoderStep, QwenImageImg2ImgVaeEncoderStep]
|
||||
block_names = ["inpaint", "img2img"]
|
||||
block_trigger_inputs = ["mask_image", "image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Vae encoder step that encode the image inputs into their latent representations.\n"
|
||||
+ "This is an auto pipeline block.\n"
|
||||
+ " - `QwenImageInpaintVaeEncoderStep` (inpaint) is used when `mask_image` is provided.\n"
|
||||
+ " - `QwenImageImg2ImgVaeEncoderStep` (img2img) is used when `image` is provided.\n"
|
||||
+ " - if `mask_image` or `image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
# optional controlnet vae encoder
|
||||
class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageControlNetVaeEncoderStep]
|
||||
block_names = ["controlnet"]
|
||||
block_trigger_inputs = ["control_image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Vae encoder step that encode the image inputs into their latent representations.\n"
|
||||
+ "This is an auto pipeline block.\n"
|
||||
+ " - `QwenImageControlNetVaeEncoderStep` (controlnet) is used when `control_image` is provided.\n"
|
||||
+ " - if `control_image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
# 2. DENOISE
|
||||
# input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise
|
||||
|
||||
# img2img input
|
||||
class QwenImageImg2ImgInputStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [QwenImageTextInputsStep(), QwenImageInputsDynamicStep(image_latent_inputs=["image_latents"])]
|
||||
block_names = ["text_inputs", "additional_inputs"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Input step that prepares the inputs for the img2img denoising step. It:\n"
|
||||
" - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n"
|
||||
" - update height/width based `image_latents`, patchify `image_latents`."
|
||||
|
||||
|
||||
# inpaint input
|
||||
class QwenImageInpaintInputStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [QwenImageTextInputsStep(), QwenImageInputsDynamicStep(image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"])]
|
||||
block_names = ["text_inputs", "additional_inputs"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Input step that prepares the inputs for the inpainting denoising step. It:\n"
|
||||
" - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents` and `processed_mask_image`).\n"
|
||||
" - update height/width based `image_latents`, patchify `image_latents`."
|
||||
|
||||
# inpaint prepare latents
|
||||
class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()]
|
||||
block_names = ["add_noise_to_latents", "create_mask_latents"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. It:\n"
|
||||
" - Add noise to the image latents to create the latents input for the denoiser.\n"
|
||||
" - Create the pachified latents `mask` based on the processedmask image.\n"
|
||||
)
|
||||
|
||||
# CoreDenoiseStep:
|
||||
# (input + prepare_latents + set_timesteps + prepare_rope_inputs + denoise + after_denoise)
|
||||
|
||||
# 1. text2image
|
||||
class QwenImageCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageTextInputsStep(),
|
||||
QwenImagePrepareLatentsStep(),
|
||||
QwenImageSetTimestepsStep(),
|
||||
QwenImageRoPEInputsStep(),
|
||||
QwenImageDenoiseStep(),
|
||||
QwenImageAfterDenoiseStep(),
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"prepare_latents",
|
||||
"set_timesteps",
|
||||
"prepare_rope_inputs",
|
||||
"denoise",
|
||||
"after_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.)."
|
||||
|
||||
|
||||
# 2.inpaint
|
||||
class QwenImageInpaintCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageInpaintInputStep(),
|
||||
QwenImagePrepareLatentsStep(),
|
||||
QwenImageSetTimestepsWithStrengthStep(),
|
||||
QwenImageInpaintPrepareLatentsStep(),
|
||||
QwenImageRoPEInputsStep(),
|
||||
QwenImageInpaintDenoiseStep(),
|
||||
QwenImageAfterDenoiseStep(),
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"prepare_latents",
|
||||
"set_timesteps",
|
||||
"prepare_inpaint_latents",
|
||||
"prepare_rope_inputs",
|
||||
"denoise",
|
||||
"after_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task."
|
||||
|
||||
|
||||
# 3. img2img
|
||||
class QwenImageImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageImg2ImgInputStep(),
|
||||
QwenImagePrepareLatentsStep(),
|
||||
QwenImageSetTimestepsWithStrengthStep(),
|
||||
QwenImagePrepareLatentsWithStrengthStep(),
|
||||
QwenImageRoPEInputsStep(),
|
||||
QwenImageDenoiseStep(),
|
||||
QwenImageAfterDenoiseStep(),
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"prepare_latents",
|
||||
"set_timesteps",
|
||||
"prepare_img2img_latents",
|
||||
"prepare_rope_inputs",
|
||||
"denoise",
|
||||
"after_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task."
|
||||
|
||||
|
||||
|
||||
# 4. text2image + controlnet
|
||||
class QwenImageControlNetCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageTextInputsStep(),
|
||||
QwenImageControlNetInputsStep(),
|
||||
QwenImagePrepareLatentsStep(),
|
||||
QwenImageSetTimestepsStep(),
|
||||
QwenImageRoPEInputsStep(),
|
||||
QwenImageControlNetBeforeDenoiserStep(),
|
||||
QwenImageControlNetDenoiseStep(),
|
||||
QwenImageAfterDenoiseStep(),
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"controlnet_input",
|
||||
"prepare_latents",
|
||||
"set_timesteps",
|
||||
"prepare_rope_inputs",
|
||||
"controlnet_before_denoise",
|
||||
"controlnet_denoise",
|
||||
"after_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.)."
|
||||
|
||||
|
||||
# 5. inpaint + controlnet
|
||||
class QwenImageControlNetInpaintCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageInpaintInputStep(),
|
||||
QwenImageControlNetInputsStep(),
|
||||
QwenImagePrepareLatentsStep(),
|
||||
QwenImageSetTimestepsWithStrengthStep(),
|
||||
QwenImageInpaintPrepareLatentsStep(),
|
||||
QwenImageRoPEInputsStep(),
|
||||
QwenImageControlNetBeforeDenoiserStep(),
|
||||
QwenImageInpaintControlNetDenoiseStep(),
|
||||
QwenImageAfterDenoiseStep(),
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"controlnet_input",
|
||||
"prepare_latents",
|
||||
"set_timesteps",
|
||||
"prepare_inpaint_latents",
|
||||
"prepare_rope_inputs",
|
||||
"controlnet_before_denoise",
|
||||
"controlnet_denoise",
|
||||
"after_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task."
|
||||
|
||||
|
||||
# 6. img2img + controlnet
|
||||
class QwenImageControlNetImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageImg2ImgInputStep(),
|
||||
QwenImageControlNetInputsStep(),
|
||||
QwenImagePrepareLatentsStep(),
|
||||
QwenImageSetTimestepsWithStrengthStep(),
|
||||
QwenImagePrepareLatentsWithStrengthStep(),
|
||||
QwenImageRoPEInputsStep(),
|
||||
QwenImageControlNetBeforeDenoiserStep(),
|
||||
QwenImageControlNetDenoiseStep(),
|
||||
QwenImageAfterDenoiseStep(),
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"controlnet_input",
|
||||
"prepare_latents",
|
||||
"set_timesteps",
|
||||
"prepare_img2img_latents",
|
||||
"prepare_rope_inputs",
|
||||
"controlnet_before_denoise",
|
||||
"controlnet_denoise",
|
||||
"after_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task."
|
||||
|
||||
|
||||
# auto denoise
|
||||
# auto denoise step for controlnet tasks: works for all tasks with controlnet
|
||||
class QwenImageAutoCoreDenoiseStep(ConditionalPipelineBlocks):
|
||||
block_classes = [
|
||||
QwenImageCoreDenoiseStep,
|
||||
QwenImageInpaintCoreDenoiseStep,
|
||||
QwenImageImg2ImgCoreDenoiseStep,
|
||||
QwenImageControlNetCoreDenoiseStep,
|
||||
QwenImageControlNetInpaintCoreDenoiseStep,
|
||||
QwenImageControlNetImg2ImgCoreDenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"text2image",
|
||||
"inpaint",
|
||||
"img2img",
|
||||
"controlnet_text2image",
|
||||
"controlnet_inpaint",
|
||||
"controlnet_img2img"]
|
||||
block_trigger_inputs = ["control_image_latents", "processed_mask_image", "image_latents"]
|
||||
default_block_name = "text2image"
|
||||
|
||||
def select_block(self, control_image_latents=None, processed_mask_image=None, image_latents=None):
|
||||
|
||||
if control_image_latents is not None:
|
||||
if processed_mask_image is not None:
|
||||
return "controlnet_inpaint"
|
||||
elif image_latents is not None:
|
||||
return "controlnet_img2img"
|
||||
else:
|
||||
return "controlnet_text2image"
|
||||
else:
|
||||
if processed_mask_image is not None:
|
||||
return "inpaint"
|
||||
elif image_latents is not None:
|
||||
return "img2img"
|
||||
else:
|
||||
return "text2image"
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Core step that performs the denoising process. \n"
|
||||
+ " - `QwenImageCoreDenoiseStep` (text2image) for text2image tasks.\n"
|
||||
+ " - `QwenImageInpaintCoreDenoiseStep` (inpaint) for inpaint tasks.\n"
|
||||
+ " - `QwenImageImg2ImgCoreDenoiseStep` (img2img) for img2img tasks.\n"
|
||||
+ " - `QwenImageControlNetCoreDenoiseStep` (controlnet_text2image) for text2image tasks with controlnet.\n"
|
||||
+ " - `QwenImageControlNetInpaintCoreDenoiseStep` (controlnet_inpaint) for inpaint tasks with controlnet.\n"
|
||||
+ " - `QwenImageControlNetImg2ImgCoreDenoiseStep` (controlnet_img2img) for img2img tasks with controlnet.\n"
|
||||
+ "This step support text-to-image, image-to-image, inpainting, and controlnet tasks for QwenImage:\n"
|
||||
+ " - for image-to-image generation, you need to provide `image_latents`\n"
|
||||
+ " - for inpainting, you need to provide `processed_mask_image` and `image_latents`\n"
|
||||
+ " - to run the controlnet workflow, you need to provide `control_image_latents`\n"
|
||||
+ " - for text-to-image generation, all you need to provide is prompt embeddings"
|
||||
)
|
||||
|
||||
|
||||
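As a sanity check on the routing above, the trigger inputs map to sub-blocks as follows. This is a sketch only; it assumes the class is exported from the qwenimage modular subpackage and that `select_block` can be called directly with dummy tensors:

```python
import torch
from diffusers.modular_pipelines.qwenimage import QwenImageAutoCoreDenoiseStep  # import path is an assumption

dummy = torch.zeros(1)
step = QwenImageAutoCoreDenoiseStep()

assert step.select_block() == "text2image"                                        # no conditioning inputs
assert step.select_block(image_latents=dummy) == "img2img"                        # image latents only
assert step.select_block(processed_mask_image=dummy, image_latents=dummy) == "inpaint"
assert step.select_block(control_image_latents=dummy, image_latents=dummy) == "controlnet_img2img"
```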
# 4. DECODE
|
||||
|
||||
## 1.1 text2image
|
||||
|
||||
#### decode
|
||||
#### (standard decode step works for most tasks except for inpaint)
|
||||
|
||||
class QwenImageDecodeStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
|
||||
block_names = ["decode", "postprocess"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Decode step that decodes the latents to images and postprocess the generated image."
|
||||
|
||||
|
||||
|
||||
#### inpaint decode
|
||||
|
||||
class QwenImageInpaintDecodeStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()]
|
||||
block_names = ["decode", "postprocess"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Decode step that decodes the latents to images and postprocess the generated image, optional apply the mask overally to the original image."
|
||||
|
||||
|
||||
# auto decode step for inpaint and text2image tasks
|
||||
class QwenImageAutoDecodeStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageInpaintDecodeStep, QwenImageDecodeStep]
|
||||
block_names = ["inpaint_decode", "decode"]
|
||||
block_trigger_inputs = ["mask", None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Decode step that decode the latents into images. \n"
|
||||
" This is an auto pipeline block that works for inpaint/text2image/img2img tasks, for both QwenImage and QwenImage-Edit.\n"
|
||||
+ " - `QwenImageInpaintDecodeStep` (inpaint) is used when `mask` is provided.\n"
|
||||
+ " - `QwenImageDecodeStep` (text2image/img2img) is used when `mask` is not provided.\n"
|
||||
)
|
||||
|
||||
|
||||
|
||||
## 1.10 QwenImage/auto block & presets
|
||||
AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", QwenImageTextEncoderStep()),
|
||||
("vae_encoder", QwenImageAutoVaeEncoderStep()),
|
||||
("controlnet_vae_encoder", QwenImageOptionalControlNetVaeEncoderStep()),
|
||||
("denoise", QwenImageAutoCoreDenoiseStep()),
|
||||
("decode", QwenImageAutoDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageAutoBlocks(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
block_classes = AUTO_BLOCKS.values()
|
||||
block_names = AUTO_BLOCKS.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.\n"
|
||||
+ "- for image-to-image generation, you need to provide `image`\n"
|
||||
+ "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n"
|
||||
+ "- to run the controlnet workflow, you need to provide `control_image`\n"
|
||||
+ "- for text-to-image generation, all you need to provide is `prompt`"
|
||||
)
|
||||
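To make the preset concrete, here is a minimal usage sketch. The `init_pipeline` / `load_default_components` / `output="images"` calls follow the modular-pipelines API, and the import path and `Qwen/Qwen-Image` repo id are assumptions; substitute whatever modular checkpoint you actually use:

```python
import torch
from diffusers.modular_pipelines.qwenimage import QwenImageAutoBlocks  # import path is an assumption

blocks = QwenImageAutoBlocks()
pipe = blocks.init_pipeline("Qwen/Qwen-Image")  # repo id is an assumption
pipe.load_default_components(torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Text-to-image: only `prompt` is needed; passing `image` (and optionally `mask_image`
# or `control_image`) switches the auto blocks to the img2img / inpaint / controlnet paths.
image = pipe(prompt="a red bicycle leaning against a brick wall", output="images")[0]
image.save("qwenimage_t2i.png")
```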
@@ -1,329 +0,0 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InsertableDict
|
||||
from .before_denoise import (
|
||||
QwenImageCreateMaskLatentsStep,
|
||||
QwenImageEditRoPEInputsStep,
|
||||
QwenImagePrepareLatentsStep,
|
||||
QwenImagePrepareLatentsWithStrengthStep,
|
||||
QwenImageSetTimestepsStep,
|
||||
QwenImageSetTimestepsWithStrengthStep,
|
||||
)
|
||||
from .decoders import (
|
||||
QwenImageAfterDenoiseStep,
|
||||
QwenImageDecoderStep,
|
||||
QwenImageInpaintProcessImagesOutputStep,
|
||||
QwenImageProcessImagesOutputStep,
|
||||
)
|
||||
from .denoise import (
|
||||
QwenImageEditDenoiseStep,
|
||||
QwenImageEditInpaintDenoiseStep,
|
||||
)
|
||||
from .encoders import (
|
||||
QwenImageEditResizeDynamicStep,
|
||||
QwenImageEditTextEncoderStep,
|
||||
QwenImageInpaintProcessImagesInputStep,
|
||||
QwenImageProcessImagesInputStep,
|
||||
QwenImageVaeEncoderDynamicStep,
|
||||
)
|
||||
from .inputs import (
|
||||
QwenImageInputsDynamicStep,
|
||||
QwenImageTextInputsStep,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# ====================
|
||||
# 1. TEXT ENCODER
|
||||
# ====================
|
||||
|
||||
class QwenImageEditVLEncoderStep(SequentialPipelineBlocks):
|
||||
"""VL encoder that takes both image and text prompts."""
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageEditResizeDynamicStep(),
|
||||
QwenImageEditTextEncoderStep(),
|
||||
]
|
||||
block_names = ["resize", "encode"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "QwenImage-Edit VL encoder step that encode the image and text prompts together."
|
||||
|
||||
|
||||
# ====================
|
||||
# 2. VAE ENCODER
|
||||
# ====================
|
||||
|
||||
# Edit VAE encoder
|
||||
class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageEditResizeDynamicStep(),
|
||||
QwenImageProcessImagesInputStep(),
|
||||
QwenImageVaeEncoderDynamicStep(),
|
||||
]
|
||||
block_names = ["resize", "preprocess", "encode"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Vae encoder step that encode the image inputs into their latent representations."
|
||||
|
||||
|
||||
# Edit Inpaint VAE encoder
|
||||
class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageEditResizeDynamicStep(),
|
||||
QwenImageInpaintProcessImagesInputStep(),
|
||||
QwenImageVaeEncoderDynamicStep(input_name="processed_image", output_name="image_latents"),
|
||||
]
|
||||
block_names = ["resize", "preprocess", "encode"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"This step is used for processing image and mask inputs for QwenImage-Edit inpaint tasks. It:\n"
|
||||
" - resize the image for target area (1024 * 1024) while maintaining the aspect ratio.\n"
|
||||
" - process the resized image and mask image.\n"
|
||||
" - create image latents."
|
||||
)
|
||||
|
||||
|
||||
# Auto VAE encoder
|
||||
class QwenImageEditAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageEditInpaintVaeEncoderStep, QwenImageEditVaeEncoderStep]
|
||||
block_names = ["edit_inpaint", "edit"]
|
||||
block_trigger_inputs = ["mask_image", "image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Vae encoder step that encode the image inputs into their latent representations.\n"
|
||||
"This is an auto pipeline block.\n"
|
||||
" - `QwenImageEditInpaintVaeEncoderStep` (edit_inpaint) is used when `mask_image` is provided.\n"
|
||||
" - `QwenImageEditVaeEncoderStep` (edit) is used when `image` is provided.\n"
|
||||
" - if `mask_image` or `image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 3. DENOISE - input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise
|
||||
# ====================
|
||||
|
||||
# Edit input step
|
||||
class QwenImageEditInputStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageTextInputsStep(),
|
||||
QwenImageInputsDynamicStep(image_latent_inputs=["image_latents"]),
|
||||
]
|
||||
block_names = ["text_inputs", "additional_inputs"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Input step that prepares the inputs for the edit denoising step. It:\n"
|
||||
" - make sure the text embeddings have consistent batch size as well as the additional inputs.\n"
|
||||
" - update height/width based `image_latents`, patchify `image_latents`."
|
||||
)
|
||||
|
||||
|
||||
# Edit Inpaint input step
|
||||
class QwenImageEditInpaintInputStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageTextInputsStep(),
|
||||
QwenImageInputsDynamicStep(image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]),
|
||||
]
|
||||
block_names = ["text_inputs", "additional_inputs"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Input step that prepares the inputs for the edit inpaint denoising step. It:\n"
|
||||
" - make sure the text embeddings have consistent batch size as well as the additional inputs.\n"
|
||||
" - update height/width based `image_latents`, patchify `image_latents`."
|
||||
)
|
||||
|
||||
|
||||
# Edit Inpaint prepare latents step
|
||||
class QwenImageEditInpaintPrepareLatentsStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()]
|
||||
block_names = ["add_noise_to_latents", "create_mask_latents"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"This step prepares the latents/image_latents and mask inputs for the edit inpainting denoising step. It:\n"
|
||||
" - Add noise to the image latents to create the latents input for the denoiser.\n"
|
||||
" - Create the patchified latents `mask` based on the processed mask image.\n"
|
||||
)
|
||||
|
||||
|
||||
# 1. Edit (img2img) core denoise
|
||||
class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageEditInputStep(),
|
||||
QwenImagePrepareLatentsStep(),
|
||||
QwenImageSetTimestepsStep(),
|
||||
QwenImageEditRoPEInputsStep(),
|
||||
QwenImageEditDenoiseStep(),
|
||||
QwenImageAfterDenoiseStep(),
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"prepare_latents",
|
||||
"set_timesteps",
|
||||
"prepare_rope_inputs",
|
||||
"denoise",
|
||||
"after_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Core denoising workflow for QwenImage-Edit edit (img2img) task."
|
||||
|
||||
|
||||
# 2. Edit Inpaint core denoise
|
||||
class QwenImageEditInpaintCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageEditInpaintInputStep(),
|
||||
QwenImagePrepareLatentsStep(),
|
||||
QwenImageSetTimestepsWithStrengthStep(),
|
||||
QwenImageEditInpaintPrepareLatentsStep(),
|
||||
QwenImageEditRoPEInputsStep(),
|
||||
QwenImageEditInpaintDenoiseStep(),
|
||||
QwenImageAfterDenoiseStep(),
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"prepare_latents",
|
||||
"set_timesteps",
|
||||
"prepare_inpaint_latents",
|
||||
"prepare_rope_inputs",
|
||||
"denoise",
|
||||
"after_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Core denoising workflow for QwenImage-Edit edit inpaint task."
|
||||
|
||||
|
||||
# Auto core denoise step
|
||||
class QwenImageEditAutoCoreDenoiseStep(ConditionalPipelineBlocks):
|
||||
block_classes = [
|
||||
QwenImageEditInpaintCoreDenoiseStep,
|
||||
QwenImageEditCoreDenoiseStep,
|
||||
]
|
||||
block_names = ["edit_inpaint", "edit"]
|
||||
block_trigger_inputs = ["processed_mask_image", "image_latents"]
|
||||
default_block_name = "edit"
|
||||
|
||||
def select_block(self, processed_mask_image=None, image_latents=None) -> Optional[str]:
|
||||
if processed_mask_image is not None:
|
||||
return "edit_inpaint"
|
||||
elif image_latents is not None:
|
||||
return "edit"
|
||||
return None
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto core denoising step that selects the appropriate workflow based on inputs.\n"
|
||||
" - `QwenImageEditInpaintCoreDenoiseStep` when `processed_mask_image` is provided\n"
|
||||
" - `QwenImageEditCoreDenoiseStep` when `image_latents` is provided\n"
|
||||
"Supports edit (img2img) and edit inpainting tasks for QwenImage-Edit."
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 4. DECODE
|
||||
# ====================
|
||||
|
||||
# Decode step (standard)
|
||||
class QwenImageEditDecodeStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
|
||||
block_names = ["decode", "postprocess"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Decode step that decodes the latents to images and postprocess the generated image."
|
||||
|
||||
|
||||
# Inpaint decode step
|
||||
class QwenImageEditInpaintDecodeStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()]
|
||||
block_names = ["decode", "postprocess"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Decode step that decodes the latents to images and postprocess the generated image, optionally apply the mask overlay to the original image."
|
||||
|
||||
|
||||
# Auto decode step
|
||||
class QwenImageEditAutoDecodeStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageEditInpaintDecodeStep, QwenImageEditDecodeStep]
|
||||
block_names = ["inpaint_decode", "decode"]
|
||||
block_trigger_inputs = ["mask", None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Decode step that decode the latents into images.\n"
|
||||
"This is an auto pipeline block.\n"
|
||||
" - `QwenImageEditInpaintDecodeStep` (inpaint) is used when `mask` is provided.\n"
|
||||
" - `QwenImageEditDecodeStep` (edit) is used when `mask` is not provided.\n"
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 5. AUTO BLOCKS & PRESETS
|
||||
# ====================
|
||||
|
||||
EDIT_AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", QwenImageEditVLEncoderStep()),
|
||||
("vae_encoder", QwenImageEditAutoVaeEncoderStep()),
|
||||
("denoise", QwenImageEditAutoCoreDenoiseStep()),
|
||||
("decode", QwenImageEditAutoDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = EDIT_AUTO_BLOCKS.values()
|
||||
block_names = EDIT_AUTO_BLOCKS.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto Modular pipeline for edit (img2img) and edit inpaint tasks using QwenImage-Edit.\n"
|
||||
"- for edit (img2img) generation, you need to provide `image`\n"
|
||||
"- for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`\n"
|
||||
)
|
||||
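Because the preset is an `InsertableDict`, individual steps can be dropped or swapped before building the workflow. A sketch; the `EDIT_AUTO_BLOCKS` import path and the `from_blocks_dict` helper are assumptions based on the modular-pipelines API:

```python
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.qwenimage import EDIT_AUTO_BLOCKS  # import path is an assumption

blocks_dict = EDIT_AUTO_BLOCKS.copy()
blocks_dict.pop("vae_encoder")  # e.g. drop the VAE encoder when image latents are precomputed
edit_blocks = SequentialPipelineBlocks.from_blocks_dict(blocks_dict)
print(edit_blocks.block_names)  # ['text_encoder', 'denoise', 'decode']
```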
@@ -1,175 +0,0 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InsertableDict
|
||||
from .before_denoise import (
|
||||
QwenImageEditPlusRoPEInputsStep,
|
||||
QwenImagePrepareLatentsStep,
|
||||
QwenImageSetTimestepsStep,
|
||||
)
|
||||
from .decoders import (
|
||||
QwenImageAfterDenoiseStep,
|
||||
QwenImageDecoderStep,
|
||||
QwenImageProcessImagesOutputStep,
|
||||
)
|
||||
from .denoise import (
|
||||
QwenImageEditDenoiseStep,
|
||||
)
|
||||
from .encoders import (
|
||||
QwenImageEditPlusResizeDynamicStep,
|
||||
QwenImageEditPlusTextEncoderStep,
|
||||
QwenImageEditPlusProcessImagesInputStep,
|
||||
QwenImageVaeEncoderDynamicStep,
|
||||
)
|
||||
from .inputs import (
|
||||
QwenImageEditPlusInputsDynamicStep,
|
||||
QwenImageTextInputsStep,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# ====================
|
||||
# 1. TEXT ENCODER
|
||||
# ====================
|
||||
|
||||
class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks):
|
||||
"""VL encoder that takes both image and text prompts. Uses 384x384 target area."""
|
||||
model_name = "qwenimage-edit-plus"
|
||||
block_classes = [
|
||||
QwenImageEditPlusResizeDynamicStep(target_area=384 * 384, output_name="resized_cond_image"),
|
||||
QwenImageEditPlusTextEncoderStep(),
|
||||
]
|
||||
block_names = ["resize", "encode"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "QwenImage-Edit Plus VL encoder step that encodes the image and text prompts together."
|
||||
|
||||
|
||||
# ====================
|
||||
# 2. VAE ENCODER
|
||||
# ====================
|
||||
|
||||
class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):
|
||||
"""VAE encoder that handles multiple images with different sizes. Uses 1024x1024 target area."""
|
||||
model_name = "qwenimage-edit-plus"
|
||||
block_classes = [
|
||||
QwenImageEditPlusResizeDynamicStep(target_area=1024 * 1024, output_name="resized_image"),
|
||||
QwenImageEditPlusProcessImagesInputStep(),
|
||||
QwenImageVaeEncoderDynamicStep(),
|
||||
]
|
||||
block_names = ["resize", "preprocess", "encode"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"VAE encoder step that encodes image inputs into latent representations.\n"
|
||||
"Each image is resized independently based on its own aspect ratio to 1024x1024 target area."
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 3. DENOISE - input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise
|
||||
# ====================
|
||||
|
||||
# Edit Plus input step
|
||||
class QwenImageEditPlusInputStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit-plus"
|
||||
block_classes = [
|
||||
QwenImageTextInputsStep(),
|
||||
QwenImageEditPlusInputsDynamicStep(image_latent_inputs=["image_latents"]),
|
||||
]
|
||||
block_names = ["text_inputs", "additional_inputs"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Input step that prepares the inputs for the Edit Plus denoising step. It:\n"
|
||||
" - Standardizes text embeddings batch size.\n"
|
||||
" - Processes list of image latents: patchifies, concatenates along dim=1, expands batch.\n"
|
||||
" - Outputs lists of image_height/image_width for RoPE calculation.\n"
|
||||
" - Defaults height/width from last image in the list."
|
||||
)
|
||||
|
||||
|
||||
# Edit Plus core denoise
|
||||
class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit-plus"
|
||||
block_classes = [
|
||||
QwenImageEditPlusInputStep(),
|
||||
QwenImagePrepareLatentsStep(),
|
||||
QwenImageSetTimestepsStep(),
|
||||
QwenImageEditPlusRoPEInputsStep(),
|
||||
QwenImageEditDenoiseStep(),
|
||||
QwenImageAfterDenoiseStep(),
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"prepare_latents",
|
||||
"set_timesteps",
|
||||
"prepare_rope_inputs",
|
||||
"denoise",
|
||||
"after_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Core denoising workflow for QwenImage-Edit Plus edit (img2img) task."
|
||||
|
||||
|
||||
# ====================
|
||||
# 4. DECODE
|
||||
# ====================
|
||||
|
||||
class QwenImageEditPlusDecodeStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit-plus"
|
||||
block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
|
||||
block_names = ["decode", "postprocess"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Decode step that decodes the latents to images and postprocesses the generated image."
|
||||
|
||||
|
||||
# ====================
|
||||
# 5. AUTO BLOCKS & PRESETS
|
||||
# ====================
|
||||
|
||||
EDIT_PLUS_AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", QwenImageEditPlusVLEncoderStep()),
|
||||
("vae_encoder", QwenImageEditPlusVaeEncoderStep()),
|
||||
("denoise", QwenImageEditPlusCoreDenoiseStep()),
|
||||
("decode", QwenImageEditPlusDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit-plus"
|
||||
block_classes = EDIT_PLUS_AUTO_BLOCKS.values()
|
||||
block_names = EDIT_PLUS_AUTO_BLOCKS.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto Modular pipeline for edit (img2img) tasks using QwenImage-Edit Plus.\n"
|
||||
"- `image` is required input (can be single image or list of images).\n"
|
||||
"- Each image is resized independently based on its own aspect ratio.\n"
|
||||
"- VL encoder uses 384x384 target area, VAE encoder uses 1024x1024 target area."
|
||||
)
|
||||
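A short usage sketch for the Edit Plus preset with multiple reference images; the import path, repo id, and the `init_pipeline`/`load_default_components` calls are assumptions based on the modular-pipelines API:

```python
import torch
from diffusers.modular_pipelines.qwenimage import QwenImageEditPlusAutoBlocks  # import path is an assumption
from diffusers.utils import load_image

pipe = QwenImageEditPlusAutoBlocks().init_pipeline("Qwen/Qwen-Image-Edit-2509")  # repo id is an assumption
pipe.load_default_components(torch_dtype=torch.bfloat16)
pipe.to("cuda")

# `image` may be a single image or a list; each one is resized on its own aspect ratio.
refs = [load_image("person.png"), load_image("sofa.png")]  # placeholder local files
out = pipe(prompt="the person sitting on the sofa", image=refs, output="images")[0]
```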
@@ -411,6 +411,7 @@ else:
|
||||
"ZImagePipeline",
|
||||
"ZImageControlNetPipeline",
|
||||
"ZImageControlNetInpaintPipeline",
|
||||
"ZImageOmniPipeline",
|
||||
]
|
||||
_import_structure["skyreels_v2"] = [
|
||||
"SkyReelsV2DiffusionForcingPipeline",
|
||||
@@ -856,6 +857,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ZImageControlNetInpaintPipeline,
|
||||
ZImageControlNetPipeline,
|
||||
ZImageImg2ImgPipeline,
|
||||
ZImageOmniPipeline,
|
||||
ZImagePipeline,
|
||||
)
|
||||
|
||||
|
||||
@@ -73,6 +73,7 @@ from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
|
||||
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
|
||||
from .lumina import LuminaPipeline
|
||||
from .lumina2 import Lumina2Pipeline
|
||||
from .ovis_image import OvisImagePipeline
|
||||
from .pag import (
|
||||
HunyuanDiTPAGPipeline,
|
||||
PixArtSigmaPAGPipeline,
|
||||
@@ -119,7 +120,13 @@ from .stable_diffusion_xl import (
|
||||
)
|
||||
from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline
|
||||
from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline
|
||||
from .z_image import ZImageImg2ImgPipeline, ZImagePipeline
|
||||
from .z_image import (
|
||||
ZImageControlNetInpaintPipeline,
|
||||
ZImageControlNetPipeline,
|
||||
ZImageImg2ImgPipeline,
|
||||
ZImageOmniPipeline,
|
||||
ZImagePipeline,
|
||||
)
|
||||
|
||||
|
||||
AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
@@ -164,6 +171,10 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("qwenimage", QwenImagePipeline),
|
||||
("qwenimage-controlnet", QwenImageControlNetPipeline),
|
||||
("z-image", ZImagePipeline),
|
||||
("z-image-controlnet", ZImageControlNetPipeline),
|
||||
("z-image-controlnet-inpaint", ZImageControlNetInpaintPipeline),
|
||||
("z-image-omni", ZImageOmniPipeline),
|
||||
("ovis", OvisImagePipeline),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
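With these mapping entries, `AutoPipelineForText2Image` should also resolve Z-Image checkpoints. A hedged sketch, reusing the `Z-a-o/Z-Image-Turbo` repo id from the example docstring further below (substitute a real checkpoint if that id differs):

```python
import torch
from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
pipe.to("cuda")
image = pipe("a lighthouse at dusk", num_inference_steps=9, guidance_scale=0.0).images[0]
```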
@@ -17,7 +17,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel
|
||||
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, T5EncoderModel, T5Tokenizer
|
||||
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
|
||||
@@ -185,7 +185,7 @@ class HunyuanDiTControlNetPipeline(DiffusionPipeline):
|
||||
The HunyuanDiT model designed by Tencent Hunyuan.
|
||||
text_encoder_2 (`T5EncoderModel`):
|
||||
The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
|
||||
tokenizer_2 (`MT5Tokenizer`):
|
||||
tokenizer_2 (`T5Tokenizer`):
|
||||
The tokenizer for the mT5 embedder.
|
||||
scheduler ([`DDPMScheduler`]):
|
||||
A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
|
||||
@@ -229,7 +229,7 @@ class HunyuanDiTControlNetPipeline(DiffusionPipeline):
|
||||
HunyuanDiT2DMultiControlNetModel,
|
||||
],
|
||||
text_encoder_2: Optional[T5EncoderModel] = None,
|
||||
tokenizer_2: Optional[MT5Tokenizer] = None,
|
||||
tokenizer_2: Optional[T5Tokenizer] = None,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -133,7 +133,7 @@ EXAMPLE_DOC_STRING = """
|
||||
... num_frames=93,
|
||||
... generator=torch.Generator().manual_seed(1),
|
||||
... ).frames[0]
|
||||
>>> # export_to_video(video, "image2world.mp4", fps=16)
|
||||
>>> export_to_video(video, "image2world.mp4", fps=16)
|
||||
|
||||
>>> # Video2World: condition on an input clip and predict a 93-frame world video.
|
||||
>>> prompt = (
|
||||
|
||||
@@ -17,7 +17,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel
|
||||
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, T5EncoderModel, T5Tokenizer
|
||||
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
|
||||
@@ -169,7 +169,7 @@ class HunyuanDiTPipeline(DiffusionPipeline):
|
||||
The HunyuanDiT model designed by Tencent Hunyuan.
|
||||
text_encoder_2 (`T5EncoderModel`):
|
||||
The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
|
||||
tokenizer_2 (`MT5Tokenizer`):
|
||||
tokenizer_2 (`T5Tokenizer`):
|
||||
The tokenizer for the mT5 embedder.
|
||||
scheduler ([`DDPMScheduler`]):
|
||||
A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
|
||||
@@ -204,7 +204,7 @@ class HunyuanDiTPipeline(DiffusionPipeline):
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
requires_safety_checker: bool = True,
|
||||
text_encoder_2: Optional[T5EncoderModel] = None,
|
||||
tokenizer_2: Optional[MT5Tokenizer] = None,
|
||||
tokenizer_2: Optional[T5Tokenizer] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel
|
||||
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, T5EncoderModel, T5Tokenizer
|
||||
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
|
||||
@@ -173,7 +173,7 @@ class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin):
|
||||
The HunyuanDiT model designed by Tencent Hunyuan.
|
||||
text_encoder_2 (`T5EncoderModel`):
|
||||
The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
|
||||
tokenizer_2 (`MT5Tokenizer`):
|
||||
tokenizer_2 (`T5Tokenizer`):
|
||||
The tokenizer for the mT5 embedder.
|
||||
scheduler ([`DDPMScheduler`]):
|
||||
A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
|
||||
@@ -208,7 +208,7 @@ class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin):
|
||||
feature_extractor: Optional[CLIPImageProcessor] = None,
|
||||
requires_safety_checker: bool = True,
|
||||
text_encoder_2: Optional[T5EncoderModel] = None,
|
||||
tokenizer_2: Optional[MT5Tokenizer] = None,
|
||||
tokenizer_2: Optional[T5Tokenizer] = None,
|
||||
pag_applied_layers: Union[str, List[str]] = "blocks.1", # "blocks.16.attn1", "blocks.16", "16", 16
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
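The HunyuanDiT hunks above only swap the tokenizer type hints and docstrings from `MT5Tokenizer` to `T5Tokenizer`, presumably because `MT5Tokenizer` is deprecated upstream and the mT5 checkpoint loads through the plain SentencePiece `T5Tokenizer` anyway. Loading is unchanged; a quick check, assuming the `Tencent-Hunyuan/HunyuanDiT-Diffusers` repo id:

```python
import torch
from diffusers import HunyuanDiTPipeline

pipe = HunyuanDiTPipeline.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-Diffusers",  # repo id is an assumption
    torch_dtype=torch.float16,
)
print(type(pipe.tokenizer_2))  # expected to be a T5Tokenizer (or subclass), no longer MT5Tokenizer
```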
@@ -26,6 +26,7 @@ else:
|
||||
_import_structure["pipeline_z_image_controlnet"] = ["ZImageControlNetPipeline"]
|
||||
_import_structure["pipeline_z_image_controlnet_inpaint"] = ["ZImageControlNetInpaintPipeline"]
|
||||
_import_structure["pipeline_z_image_img2img"] = ["ZImageImg2ImgPipeline"]
|
||||
_import_structure["pipeline_z_image_omni"] = ["ZImageOmniPipeline"]
|
||||
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
@@ -41,7 +42,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_z_image_controlnet import ZImageControlNetPipeline
|
||||
from .pipeline_z_image_controlnet_inpaint import ZImageControlNetInpaintPipeline
|
||||
from .pipeline_z_image_img2img import ZImageImg2ImgPipeline
|
||||
|
||||
from .pipeline_z_image_omni import ZImageOmniPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
src/diffusers/pipelines/z_image/pipeline_z_image_omni.py (new file, 742 lines)
@@ -0,0 +1,742 @@
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import PIL
import torch
from transformers import AutoTokenizer, PreTrainedModel, Siglip2ImageProcessorFast, Siglip2VisionModel

from ...loaders import FromSingleFileMixin, ZImageLoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import ZImageTransformer2DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..flux2.image_processor import Flux2ImageProcessor
from .pipeline_output import ZImagePipelineOutput


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import ZImageOmniPipeline

        >>> pipe = ZImageOmniPipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
        >>> pipe.to("cuda")

        >>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch.
        >>> # (1) Use flash attention 2
        >>> # pipe.transformer.set_attention_backend("flash")
        >>> # (2) Use flash attention 3
        >>> # pipe.transformer.set_attention_backend("_flash_3")

        >>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。"
        >>> image = pipe(
        ...     prompt,
        ...     height=1024,
        ...     width=1024,
        ...     num_inference_steps=9,
        ...     guidance_scale=0.0,
        ...     generator=torch.Generator("cuda").manual_seed(42),
        ... ).images[0]
        >>> image.save("zimage.png")
        ```
"""


# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu

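# Worked example (informal, assuming the 8x VAE downsampling and 2x2 latent patchification used
# further below): a 1024x1024 image gives image_seq_len = (1024 / 16) ** 2 = 4096, so mu reaches
# max_shift = 1.15, while a 512x512 image gives image_seq_len = 1024 and mu of roughly 0.63.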

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is
            passed, `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and
        the second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps

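# Below, this helper is driven with a resolution-dependent shift passed through kwargs, roughly:
#     mu = calculate_shift(image_seq_len, ...)
#     timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)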

class ZImageOmniPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin):
    model_cpu_offload_seq = "text_encoder->transformer->vae"
    _optional_components = []
    _callback_tensor_inputs = ["latents", "prompt_embeds"]

    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: PreTrainedModel,
        tokenizer: AutoTokenizer,
        transformer: ZImageTransformer2DModel,
        siglip: Siglip2VisionModel,
        siglip_processor: Siglip2ImageProcessorFast,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            scheduler=scheduler,
            transformer=transformer,
            siglip=siglip,
            siglip_processor=siglip_processor,
        )
        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
        )
        # self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
        self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        device: Optional[torch.device] = None,
        do_classifier_free_guidance: bool = True,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        prompt_embeds: Optional[List[torch.FloatTensor]] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        max_sequence_length: int = 512,
        num_condition_images: int = 0,
    ):
        prompt = [prompt] if isinstance(prompt, str) else prompt
        prompt_embeds = self._encode_prompt(
            prompt=prompt,
            device=device,
            prompt_embeds=prompt_embeds,
            max_sequence_length=max_sequence_length,
            num_condition_images=num_condition_images,
        )

        if do_classifier_free_guidance:
            if negative_prompt is None:
                negative_prompt = ["" for _ in prompt]
            else:
                negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
                assert len(prompt) == len(negative_prompt)
            negative_prompt_embeds = self._encode_prompt(
                prompt=negative_prompt,
                device=device,
                prompt_embeds=negative_prompt_embeds,
                max_sequence_length=max_sequence_length,
                num_condition_images=num_condition_images,
            )
        else:
            negative_prompt_embeds = []
        return prompt_embeds, negative_prompt_embeds

    def _encode_prompt(
        self,
        prompt: Union[str, List[str]],
        device: Optional[torch.device] = None,
        prompt_embeds: Optional[List[torch.FloatTensor]] = None,
        max_sequence_length: int = 512,
        num_condition_images: int = 0,
    ) -> List[torch.FloatTensor]:
        device = device or self._execution_device

        if prompt_embeds is not None:
            return prompt_embeds

        if isinstance(prompt, str):
            prompt = [prompt]

        for i, prompt_item in enumerate(prompt):
            if num_condition_images == 0:
                prompt[i] = ["<|im_start|>user\n" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n"]
            elif num_condition_images > 0:
                prompt_list = ["<|im_start|>user\n<|vision_start|>"]
                prompt_list += ["<|vision_end|><|vision_start|>"] * (num_condition_images - 1)
                prompt_list += ["<|vision_end|>" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"]
                prompt_list += ["<|vision_end|><|im_end|>"]
                prompt[i] = prompt_list
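        # At this point each prompt is a list of chat-template segments: with N condition images the
        # text is split into N + 2 pieces delimited by <|vision_start|>/<|vision_end|> markers, so the
        # per-segment text embeddings can later be interleaved with the image inputs slot by slot.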

        flattened_prompt = []
        prompt_list_lengths = []

        for i in range(len(prompt)):
            prompt_list_lengths.append(len(prompt[i]))
            flattened_prompt.extend(prompt[i])

        text_inputs = self.tokenizer(
            flattened_prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids.to(device)
        prompt_masks = text_inputs.attention_mask.to(device).bool()

        prompt_embeds = self.text_encoder(
            input_ids=text_input_ids,
            attention_mask=prompt_masks,
            output_hidden_states=True,
        ).hidden_states[-2]

        embeddings_list = []
        start_idx = 0
        for i in range(len(prompt_list_lengths)):
            batch_embeddings = []
            end_idx = start_idx + prompt_list_lengths[i]
            for j in range(start_idx, end_idx):
                batch_embeddings.append(prompt_embeds[j][prompt_masks[j]])
            embeddings_list.append(batch_embeddings)
            start_idx = end_idx

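        # embeddings_list has one entry per prompt; each entry is a list of un-padded embedding
        # tensors (one per chat-template segment), with padding stripped via the attention mask.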
        return embeddings_list

    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
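        # Latent height/width: divide the pixel size by the VAE downsampling factor and round to an
        # even number, so the latent grid stays compatible with the 2x2 patchification assumed when
        # image_seq_len is computed later in __call__.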
        height = 2 * (int(height) // (self.vae_scale_factor * 2))
        width = 2 * (int(width) // (self.vae_scale_factor * 2))

        shape = (batch_size, num_channels_latents, height, width)

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)
        return latents

    def prepare_image_latents(
        self,
        images: List[torch.Tensor],
        batch_size,
        device,
        dtype,
    ):
        image_latents = []
        for image in images:
            image = image.to(device=device, dtype=dtype)
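            # Encode each condition image with the VAE and take the mode of the latent distribution
            # (deterministic), then apply the usual shift/scale normalization.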
            image_latent = (
                self.vae.encode(image.bfloat16()).latent_dist.mode()[0] - self.vae.config.shift_factor
            ) * self.vae.config.scaling_factor
            image_latent = image_latent.unsqueeze(1).to(dtype)
            image_latents.append(image_latent)  # (16, 128, 128)

        # image_latents = [image_latents] * batch_size
        image_latents = [image_latents.copy() for _ in range(batch_size)]

        return image_latents

    def prepare_siglip_embeds(
        self,
        images: List[torch.Tensor],
        batch_size,
        device,
        dtype,
    ):
        siglip_embeds = []
        for image in images:
            siglip_inputs = self.siglip_processor(images=[image], return_tensors="pt").to(device)
            shape = siglip_inputs.spatial_shapes[0]
            hidden_state = self.siglip(**siglip_inputs).last_hidden_state
            B, N, C = hidden_state.shape
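            # Siglip2 packs/pads the patch tokens to a fixed length; keep only the real spatial
            # tokens and reshape them into the (height, width) grid reported by spatial_shapes.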
            hidden_state = hidden_state[:, : shape[0] * shape[1]]
            hidden_state = hidden_state.view(shape[0], shape[1], C)
            siglip_embeds.append(hidden_state.to(dtype))

        # siglip_embeds = [siglip_embeds] * batch_size
        siglip_embeds = [siglip_embeds.copy() for _ in range(batch_size)]

        return siglip_embeds

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    @property
    def joint_attention_kwargs(self):
        return self._joint_attention_kwargs

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 5.0,
        cfg_normalization: bool = False,
        cfg_truncation: float = 1.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[List[torch.FloatTensor]] = None,
        negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image`, numpy array or tensor representing an image batch to be used as conditioning input. For both
                numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a
                list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or
                a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
                latents as `image`, but if passing latents directly it is not encoded again.
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            height (`int`, *optional*, defaults to 1024):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 1024):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used.
            guidance_scale (`float`, *optional*, defaults to 5.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages the model to generate images that are closely linked to the text
                `prompt`, usually at the expense of lower image quality.
            cfg_normalization (`bool`, *optional*, defaults to False):
                Whether to renormalize the classifier-free-guided prediction so that its norm does not exceed the
                norm of the conditional prediction.
            cfg_truncation (`float`, *optional*, defaults to 1.0):
                Normalized time threshold in `[0, 1]` after which classifier-free guidance is disabled; `1.0` keeps
                guidance enabled for every step.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`List[torch.FloatTensor]`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.z_image.ZImagePipelineOutput`] instead of a plain tuple.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep:
                int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in
                the `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int`, *optional*, defaults to 512):
                Maximum sequence length to use with the `prompt`.

        Examples:

        Returns:
            [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if
            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
            generated images.
        """

        if image is not None and not isinstance(image, list):
            image = [image]
        num_condition_images = len(image) if image is not None else 0

        device = self._execution_device

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False
        self._cfg_normalization = cfg_normalization
        self._cfg_truncation = cfg_truncation

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = len(prompt_embeds)

        # If prompt_embeds is provided and prompt is None, skip encoding
        if prompt_embeds is not None and prompt is None:
            if self.do_classifier_free_guidance and negative_prompt_embeds is None:
                raise ValueError(
                    "When `prompt_embeds` is provided without `prompt`, "
                    "`negative_prompt_embeds` must also be provided for classifier-free guidance."
                )
        else:
            (
                prompt_embeds,
                negative_prompt_embeds,
            ) = self.encode_prompt(
                prompt=prompt,
                negative_prompt=negative_prompt,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                device=device,
                max_sequence_length=max_sequence_length,
                num_condition_images=num_condition_images,
            )

        # 3. Process condition images. Copied from diffusers.pipelines.flux2.pipeline_flux2
        condition_images = []
        resized_images = []
        if image is not None:
            for img in image:
                self.image_processor.check_image_input(img)
            for img in image:
                image_width, image_height = img.size
                if image_width * image_height > 1024 * 1024:
                    if height is not None and width is not None:
                        img = self.image_processor._resize_to_target_area(img, height * width)
                    else:
                        img = self.image_processor._resize_to_target_area(img, 1024 * 1024)
                    image_width, image_height = img.size
                resized_images.append(img)

                multiple_of = self.vae_scale_factor * 2
                image_width = (image_width // multiple_of) * multiple_of
                image_height = (image_height // multiple_of) * multiple_of
                img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop")
                condition_images.append(img)

        if len(condition_images) > 0:
            height = height or image_height
            width = width or image_width

        else:
            height = height or 1024
            width = width or 1024

        vae_scale = self.vae_scale_factor * 2
        if height % vae_scale != 0:
            raise ValueError(
                f"Height must be divisible by {vae_scale} (got {height}). "
                f"Please adjust the height to a multiple of {vae_scale}."
            )
        if width % vae_scale != 0:
            raise ValueError(
                f"Width must be divisible by {vae_scale} (got {width}). "
                f"Please adjust the width to a multiple of {vae_scale}."
            )

        # 4. Prepare latent variables
        num_channels_latents = self.transformer.in_channels

        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            torch.float32,
            device,
            generator,
            latents,
        )

        condition_latents = self.prepare_image_latents(
            images=condition_images,
            batch_size=batch_size * num_images_per_prompt,
            device=device,
            dtype=torch.float32,
        )
        condition_latents = [[lat.to(self.transformer.dtype) for lat in lats] for lats in condition_latents]
        if self.do_classifier_free_guidance:
            negative_condition_latents = [[lat.clone() for lat in batch] for batch in condition_latents]

        condition_siglip_embeds = self.prepare_siglip_embeds(
            images=resized_images,
            batch_size=batch_size * num_images_per_prompt,
            device=device,
            dtype=torch.float32,
        )
        condition_siglip_embeds = [[se.to(self.transformer.dtype) for se in sels] for sels in condition_siglip_embeds]
        if self.do_classifier_free_guidance:
            negative_condition_siglip_embeds = [[se.clone() for se in batch] for batch in condition_siglip_embeds]

        # Repeat prompt_embeds for num_images_per_prompt
        if num_images_per_prompt > 1:
            prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
            if self.do_classifier_free_guidance and negative_prompt_embeds:
                negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]

        condition_siglip_embeds = [None if sels == [] else sels + [None] for sels in condition_siglip_embeds]
        if self.do_classifier_free_guidance:
            negative_condition_siglip_embeds = [
                None if sels == [] else sels + [None] for sels in negative_condition_siglip_embeds
            ]

        actual_batch_size = batch_size * num_images_per_prompt
        image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)

        # 5. Prepare timesteps
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.get("base_image_seq_len", 256),
            self.scheduler.config.get("max_image_seq_len", 4096),
            self.scheduler.config.get("base_shift", 0.5),
            self.scheduler.config.get("max_shift", 1.15),
        )
        self.scheduler.sigma_min = 0.0
        scheduler_kwargs = {"mu": mu}
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            **scheduler_kwargs,
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0])
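                # The scheduler counts timesteps down from ~1000, while the transformer expects a
                # flow-matching time that increases over the course of sampling, so flip and rescale
                # the value to [0, 1] here.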
                timestep = (1000 - timestep) / 1000
                # Normalized time for time-aware config (0 at start, 1 at end)
                t_norm = timestep[0].item()

                # Handle cfg truncation
                current_guidance_scale = self.guidance_scale
                if (
                    self.do_classifier_free_guidance
                    and self._cfg_truncation is not None
                    and float(self._cfg_truncation) <= 1
                ):
                    if t_norm > self._cfg_truncation:
                        current_guidance_scale = 0.0

                # Run CFG only if configured AND scale is non-zero
                apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0

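                # For CFG, the conditional and unconditional branches are run in a single forward
                # pass: plain tensors are repeated along the batch dimension, while the
                # list-structured conditioning inputs are concatenated.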
                if apply_cfg:
                    latents_typed = latents.to(self.transformer.dtype)
                    latent_model_input = latents_typed.repeat(2, 1, 1, 1)
                    prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
                    condition_latents_model_input = condition_latents + negative_condition_latents
                    condition_siglip_embeds_model_input = condition_siglip_embeds + negative_condition_siglip_embeds
                    timestep_model_input = timestep.repeat(2)
                else:
                    latent_model_input = latents.to(self.transformer.dtype)
                    prompt_embeds_model_input = prompt_embeds
                    condition_latents_model_input = condition_latents
                    condition_siglip_embeds_model_input = condition_siglip_embeds
                    timestep_model_input = timestep

                latent_model_input = latent_model_input.unsqueeze(2)
                latent_model_input_list = list(latent_model_input.unbind(dim=0))

                # Combine condition latents with target latent
                current_batch_size = len(latent_model_input_list)
                x_combined = [
                    condition_latents_model_input[i] + [latent_model_input_list[i]] for i in range(current_batch_size)
                ]
                # Create noise mask: 0 for condition images (clean), 1 for target image (noisy)
                image_noise_mask = [
                    [0] * len(condition_latents_model_input[i]) + [1] for i in range(current_batch_size)
                ]

                model_out_list = self.transformer(
                    x=x_combined,
                    t=timestep_model_input,
                    cap_feats=prompt_embeds_model_input,
                    siglip_feats=condition_siglip_embeds_model_input,
                    image_noise_mask=image_noise_mask,
                    return_dict=False,
                )[0]

                if apply_cfg:
                    # Perform CFG
                    pos_out = model_out_list[:actual_batch_size]
                    neg_out = model_out_list[actual_batch_size:]

                    noise_pred = []
                    for j in range(actual_batch_size):
                        pos = pos_out[j].float()
                        neg = neg_out[j].float()

                        pred = pos + current_guidance_scale * (pos - neg)

                        # Renormalization
                        if self._cfg_normalization and float(self._cfg_normalization) > 0.0:
                            ori_pos_norm = torch.linalg.vector_norm(pos)
                            new_pos_norm = torch.linalg.vector_norm(pred)
                            max_new_norm = ori_pos_norm * float(self._cfg_normalization)
                            if new_pos_norm > max_new_norm:
                                pred = pred * (max_new_norm / new_pos_norm)

                        noise_pred.append(pred)

                    noise_pred = torch.stack(noise_pred, dim=0)
                else:
                    noise_pred = torch.stack([t.float() for t in model_out_list], dim=0)

                noise_pred = noise_pred.squeeze(2)
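                # Sign flip: the transformer is conditioned on the inverted time axis above, so its
                # prediction is negated to match the velocity convention expected by
                # FlowMatchEulerDiscreteScheduler.step (rationale inferred from the time flip).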
                noise_pred = -noise_pred

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
                assert latents.dtype == torch.float32

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        if output_type == "latent":
            image = latents

        else:
            latents = latents.to(self.vae.dtype)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor

            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return ZImagePipelineOutput(images=image)
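
For reference, a minimal image-conditioned usage sketch of the new `ZImageOmniPipeline` (not part of the diff): the checkpoint id is a placeholder, and the condition image is passed through the `image` argument added above.

```python
import torch
from PIL import Image

from diffusers import ZImageOmniPipeline

# Placeholder id; substitute a checkpoint saved in the ZImageOmniPipeline layout.
pipe = ZImageOmniPipeline.from_pretrained("path/or/repo-id", torch_dtype=torch.bfloat16)
pipe.to("cuda")

condition = Image.open("reference.png").convert("RGB")

image = pipe(
    image=[condition],  # condition image(s); omit for pure text-to-image
    prompt="Turn the scene into a snowy winter evening",
    height=1024,
    width=1024,
    num_inference_steps=50,
    guidance_scale=5.0,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("zimage_omni.png")
```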

@@ -3917,6 +3917,21 @@ class ZImageImg2ImgPipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class ZImageOmniPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class ZImagePipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]