Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-28 14:35:00 +08:00)

Compare commits: 4 commits, wan-layerw... ... modular-lo...

| Author | SHA1 | Date |
|---|---|---|
| | e77f1c56dc | |
| | 372222a4b6 | |
| | 1f57b175ae | |
| | 581a425130 | |
.github/workflows/build_docker_images.yml (vendored): 8 changes
@@ -25,7 +25,7 @@ jobs:
    if: github.event_name == 'pull_request'
    steps:
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
        uses: docker/setup-buildx-action@v1

      - name: Check out code
        uses: actions/checkout@v6
@@ -101,14 +101,14 @@ jobs:
      - name: Checkout repository
        uses: actions/checkout@v6
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
        uses: docker/setup-buildx-action@v1
      - name: Login to Docker Hub
        uses: docker/login-action@v3
        uses: docker/login-action@v2
        with:
          username: ${{ env.REGISTRY }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}
      - name: Build and push
        uses: docker/build-push-action@v6
        uses: docker/build-push-action@v3
        with:
          no-cache: true
          context: ./docker/${{ matrix.image-name }}
.github/workflows/pr_modular_tests.yml (vendored): 20 changes
@@ -75,27 +75,9 @@ jobs:
        if: ${{ failure() }}
        run: |
          echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
  check_auto_docs:
    runs-on: ubuntu-22.04
    steps:
      - uses: actions/checkout@v6
      - name: Set up Python
        uses: actions/setup-python@v6
        with:
          python-version: "3.10"
      - name: Install dependencies
        run: |
          pip install --upgrade pip
          pip install .[quality]
      - name: Check auto docs
        run: make modular-autodoctrings
      - name: Check if failure
        if: ${{ failure() }}
        run: |
          echo "Auto docstring checks failed. Please run `python utils/modular_auto_docstring.py --fix_and_overwrite`." >> $GITHUB_STEP_SUMMARY

  run_fast_tests:
    needs: [check_code_quality, check_repository_consistency, check_auto_docs]
    needs: [check_code_quality, check_repository_consistency]
    name: Fast PyTorch Modular Pipeline CPU tests

    runs-on:
.github/workflows/typos.yml (vendored): 2 changes
@@ -11,4 +11,4 @@ jobs:
      - uses: actions/checkout@v6

      - name: typos-action
        uses: crate-ci/typos@v1.42.1
        uses: crate-ci/typos@v1.12.4
Makefile: 4 changes
@@ -70,10 +70,6 @@ fix-copies:
	python utils/check_copies.py --fix_and_overwrite
	python utils/check_dummies.py --fix_and_overwrite

# Auto docstrings in modular blocks
modular-autodoctrings:
	python utils/modular_auto_docstring.py

# Run tests for the library

test:
@@ -413,9 +413,6 @@ else:
    _import_structure["modular_pipelines"].extend(
        [
            "Flux2AutoBlocks",
            "Flux2KleinAutoBlocks",
            "Flux2KleinBaseAutoBlocks",
            "Flux2KleinModularPipeline",
            "Flux2ModularPipeline",
            "FluxAutoBlocks",
            "FluxKontextAutoBlocks",
@@ -1149,9 +1146,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
    from .modular_pipelines import (
        Flux2AutoBlocks,
        Flux2KleinAutoBlocks,
        Flux2KleinBaseAutoBlocks,
        Flux2KleinModularPipeline,
        Flux2ModularPipeline,
        FluxAutoBlocks,
        FluxKontextAutoBlocks,
@@ -152,10 +152,6 @@ SINGLE_FILE_LOADABLE_CLASSES = {
        "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "WanAnimateTransformer3DModel": {
        "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "AutoencoderKLWan": {
        "checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
        "default_subfolder": "vae",

@@ -136,7 +136,6 @@ CHECKPOINT_KEY_NAMES = {
    "wan": ["model.diffusion_model.head.modulation", "head.modulation"],
    "wan_vae": "decoder.middle.0.residual.0.gamma",
    "wan_vace": "vace_blocks.0.after_proj.bias",
    "wan_animate": "motion_encoder.dec.direction.weight",
    "hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias",
    "cosmos-1.0": [
        "net.x_embedder.proj.1.weight",
@@ -220,7 +219,6 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
    "wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
    "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
    "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
    "wan-animate-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.2-Animate-14B-Diffusers"},
    "wan-vace-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-1.3B-diffusers"},
    "wan-vace-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-14B-diffusers"},
    "hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
@@ -761,9 +759,6 @@ def infer_diffusers_model_type(checkpoint):
        elif checkpoint[target_key].shape[0] == 5120:
            model_type = "wan-vace-14B"

        if CHECKPOINT_KEY_NAMES["wan_animate"] in checkpoint:
            model_type = "wan-animate-14B"

        elif checkpoint[target_key].shape[0] == 1536:
            model_type = "wan-t2v-1.3B"
        elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
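Note: the hunk above extends the key-and-shape probing that infer_diffusers_model_type uses to pick a model type for single-file checkpoints. A minimal, self-contained sketch of that pattern follows; it is illustrative only (the tensor shapes and the trimmed-down key table are assumptions, not the actual diffusers implementation).

```python
import torch

# Trimmed-down illustration of the probing shown above; the real function covers many more architectures.
CHECKPOINT_KEY_NAMES = {
    "wan": ["model.diffusion_model.head.modulation", "head.modulation"],
    "wan_vace": "vace_blocks.0.after_proj.bias",
    "wan_animate": "motion_encoder.dec.direction.weight",
}

def infer_wan_model_type(checkpoint: dict) -> str:
    # Pick whichever "wan" head key is present, then branch on marker keys and shapes.
    target_key = next(k for k in CHECKPOINT_KEY_NAMES["wan"] if k in checkpoint)
    if CHECKPOINT_KEY_NAMES["wan_animate"] in checkpoint:
        return "wan-animate-14B"
    if CHECKPOINT_KEY_NAMES["wan_vace"] in checkpoint:
        return "wan-vace-14B" if checkpoint[target_key].shape[0] == 5120 else "wan-vace-1.3B"
    return "wan-t2v-1.3B" if checkpoint[target_key].shape[0] == 1536 else "wan-t2v-14B"

# Toy state dict (shapes are made up for the example).
state_dict = {
    "head.modulation": torch.zeros(5120, 1),
    "motion_encoder.dec.direction.weight": torch.zeros(512, 20),
}
print(infer_wan_model_type(state_dict))  # wan-animate-14B
```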
@@ -3159,64 +3154,13 @@ def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):


def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
    def generate_motion_encoder_mappings():
        mappings = {
            "motion_encoder.dec.direction.weight": "motion_encoder.motion_synthesis_weight",
            "motion_encoder.enc.net_app.convs.0.0.weight": "motion_encoder.conv_in.weight",
            "motion_encoder.enc.net_app.convs.0.1.bias": "motion_encoder.conv_in.act_fn.bias",
            "motion_encoder.enc.net_app.convs.8.weight": "motion_encoder.conv_out.weight",
            "motion_encoder.enc.fc": "motion_encoder.motion_network",
        }

        for i in range(7):
            conv_idx = i + 1
            mappings.update(
                {
                    f"motion_encoder.enc.net_app.convs.{conv_idx}.conv1.0.weight": f"motion_encoder.res_blocks.{i}.conv1.weight",
                    f"motion_encoder.enc.net_app.convs.{conv_idx}.conv1.1.bias": f"motion_encoder.res_blocks.{i}.conv1.act_fn.bias",
                    f"motion_encoder.enc.net_app.convs.{conv_idx}.conv2.1.weight": f"motion_encoder.res_blocks.{i}.conv2.weight",
                    f"motion_encoder.enc.net_app.convs.{conv_idx}.conv2.2.bias": f"motion_encoder.res_blocks.{i}.conv2.act_fn.bias",
                    f"motion_encoder.enc.net_app.convs.{conv_idx}.skip.1.weight": f"motion_encoder.res_blocks.{i}.conv_skip.weight",
                }
            )

        return mappings

    def generate_face_adapter_mappings():
        return {
            "face_adapter.fuser_blocks": "face_adapter",
            ".k_norm.": ".norm_k.",
            ".q_norm.": ".norm_q.",
            ".linear1_q.": ".to_q.",
            ".linear2.": ".to_out.",
            "conv1_local.conv": "conv1_local",
            "conv2.conv": "conv2",
            "conv3.conv": "conv3",
        }

    def split_tensor_handler(key, state_dict, split_pattern, target_keys):
        tensor = state_dict.pop(key)
        split_idx = tensor.shape[0] // 2

        new_key_1 = key.replace(split_pattern, target_keys[0])
        new_key_2 = key.replace(split_pattern, target_keys[1])

        state_dict[new_key_1] = tensor[:split_idx]
        state_dict[new_key_2] = tensor[split_idx:]

    def reshape_bias_handler(key, state_dict):
        if "motion_encoder.enc.net_app.convs." in key and ".bias" in key:
            state_dict[key] = state_dict[key][0, :, 0, 0]

    converted_state_dict = {}

    # Strip model.diffusion_model prefix
    keys = list(checkpoint.keys())
    for k in keys:
        if "model.diffusion_model." in k:
            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

    # Base transformer mappings
    TRANSFORMER_KEYS_RENAME_DICT = {
        "time_embedding.0": "condition_embedder.time_embedder.linear_1",
        "time_embedding.2": "condition_embedder.time_embedder.linear_2",
@@ -3238,42 +3182,27 @@ def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
        "ffn.0": "ffn.net.0.proj",
        "ffn.2": "ffn.net.2",
        # Hack to swap the layer names
        # The original model calls the norms in following order: norm1, norm3, norm2
        # We convert it to: norm1, norm2, norm3
        "norm2": "norm__placeholder",
        "norm3": "norm2",
        "norm__placeholder": "norm3",
        # I2V model
        # For the I2V model
        "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
        "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
        "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
        "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
        # VACE model
        # For the VACE model
        "before_proj": "proj_in",
        "after_proj": "proj_out",
    }

    SPECIAL_KEYS_HANDLERS = {}
    if any("face_adapter" in k for k in checkpoint.keys()):
        TRANSFORMER_KEYS_RENAME_DICT.update(generate_face_adapter_mappings())
        SPECIAL_KEYS_HANDLERS[".linear1_kv."] = (split_tensor_handler, [".to_k.", ".to_v."])

    if any("motion_encoder" in k for k in checkpoint.keys()):
        TRANSFORMER_KEYS_RENAME_DICT.update(generate_motion_encoder_mappings())

    for key in list(checkpoint.keys()):
        reshape_bias_handler(key, checkpoint)

    for key in list(checkpoint.keys()):
        new_key = key
        new_key = key[:]
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        converted_state_dict[new_key] = checkpoint.pop(key)

    for key in list(converted_state_dict.keys()):
        for pattern, (handler_fn, target_keys) in SPECIAL_KEYS_HANDLERS.items():
            if pattern not in key:
                continue
            handler_fn(key, converted_state_dict, pattern, target_keys)
            break
        converted_state_dict[new_key] = checkpoint.pop(key)

    return converted_state_dict
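Note: the conversion above is driven by a substring-rename table plus optional special-key handlers. A short, self-contained sketch of that rename pattern follows; the keys in the toy state dict are invented for illustration and are not taken from a real Wan checkpoint.

```python
import torch

# A small subset of rename rules in the same style as TRANSFORMER_KEYS_RENAME_DICT above.
RENAME = {
    "time_embedding.0": "condition_embedder.time_embedder.linear_1",
    "ffn.0": "ffn.net.0.proj",
    ".k_norm.": ".norm_k.",
}

def rename_keys(state_dict: dict) -> dict:
    converted = {}
    for key in list(state_dict.keys()):
        new_key = key
        for old, new in RENAME.items():
            new_key = new_key.replace(old, new)  # substring replacement, applied rule by rule
        converted[new_key] = state_dict.pop(key)
    return converted

sd = {
    "time_embedding.0.weight": torch.zeros(4, 4),
    "blocks.0.attn.k_norm.weight": torch.zeros(4),
}
print(sorted(rename_keys(sd)))
# ['blocks.0.attn.norm_k.weight', 'condition_embedder.time_embedder.linear_1.weight']
```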
@@ -18,7 +18,7 @@ from typing import Optional, Union
from huggingface_hub.utils import validate_hf_hub_args

from ..configuration_utils import ConfigMixin
from ..utils import logging
from ..utils import DIFFUSERS_LOAD_ID_FIELDS, logging
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code


@@ -220,4 +220,11 @@ class AutoModel(ConfigMixin):
            raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")

        kwargs = {**load_config_kwargs, **kwargs}
        return model_cls.from_pretrained(pretrained_model_or_path, **kwargs)
        model = model_cls.from_pretrained(pretrained_model_or_path, **kwargs)

        load_id_kwargs = {"pretrained_model_name_or_path": pretrained_model_or_path, **kwargs}
        parts = [load_id_kwargs.get(field, "null") for field in DIFFUSERS_LOAD_ID_FIELDS]
        load_id = "|".join("null" if p is None else p for p in parts)
        model._diffusers_load_id = load_id

        return model
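Note: the hunk above tags the loaded model with a pipe-separated load id built from DIFFUSERS_LOAD_ID_FIELDS. A sketch of that composition follows; the exact field list is not shown in this diff, so the names used here are assumptions.

```python
# Hypothetical field list; the real one is DIFFUSERS_LOAD_ID_FIELDS in diffusers.utils.
LOAD_ID_FIELDS = ["pretrained_model_name_or_path", "subfolder", "variant", "revision"]

def build_load_id(**kwargs) -> str:
    # Missing or None fields are recorded as the literal string "null", as in the hunk above.
    parts = [kwargs.get(field, "null") for field in LOAD_ID_FIELDS]
    return "|".join("null" if p is None else str(p) for p in parts)

print(build_load_id(pretrained_model_name_or_path="Wan-AI/Wan2.2-Animate-14B-Diffusers", subfolder="transformer"))
# Wan-AI/Wan2.2-Animate-14B-Diffusers|transformer|null|null
```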
@@ -143,86 +143,41 @@ class GlmImageAdaLayerNormZero(nn.Module):
|
||||
|
||||
|
||||
class GlmImageLayerKVCache:
|
||||
"""KV cache for GlmImage model.
|
||||
Supports per-sample caching for batch processing where each sample may have different condition images.
|
||||
"""
|
||||
"""KV cache for GlmImage model."""
|
||||
|
||||
def __init__(self):
|
||||
self.k_caches: List[Optional[torch.Tensor]] = []
|
||||
self.v_caches: List[Optional[torch.Tensor]] = []
|
||||
self.k_cache = None
|
||||
self.v_cache = None
|
||||
self.mode: Optional[str] = None # "write", "read", "skip"
|
||||
self.current_sample_idx: int = 0 # Current sample index for writing
|
||||
|
||||
def store(self, k: torch.Tensor, v: torch.Tensor):
|
||||
"""Store KV cache for the current sample."""
|
||||
# k, v shape: (1, seq_len, num_heads, head_dim)
|
||||
if len(self.k_caches) <= self.current_sample_idx:
|
||||
# First time storing for this sample
|
||||
self.k_caches.append(k)
|
||||
self.v_caches.append(v)
|
||||
if self.k_cache is None:
|
||||
self.k_cache = k
|
||||
self.v_cache = v
|
||||
else:
|
||||
# Append to existing cache for this sample (multiple condition images)
|
||||
self.k_caches[self.current_sample_idx] = torch.cat([self.k_caches[self.current_sample_idx], k], dim=1)
|
||||
self.v_caches[self.current_sample_idx] = torch.cat([self.v_caches[self.current_sample_idx], v], dim=1)
|
||||
self.k_cache = torch.cat([self.k_cache, k], dim=1)
|
||||
self.v_cache = torch.cat([self.v_cache, v], dim=1)
|
||||
|
||||
def get(self, k: torch.Tensor, v: torch.Tensor):
|
||||
"""Get combined KV cache for all samples in the batch.
|
||||
|
||||
Args:
|
||||
k: Current key tensor, shape (batch_size, seq_len, num_heads, head_dim)
|
||||
v: Current value tensor, shape (batch_size, seq_len, num_heads, head_dim)
|
||||
Returns:
|
||||
Combined key and value tensors with cached values prepended.
|
||||
"""
|
||||
batch_size = k.shape[0]
|
||||
num_cached_samples = len(self.k_caches)
|
||||
if num_cached_samples == 0:
|
||||
return k, v
|
||||
if num_cached_samples == 1:
|
||||
# Single cache, expand for all batch samples (shared condition images)
|
||||
k_cache_expanded = self.k_caches[0].expand(batch_size, -1, -1, -1)
|
||||
v_cache_expanded = self.v_caches[0].expand(batch_size, -1, -1, -1)
|
||||
elif num_cached_samples == batch_size:
|
||||
# Per-sample cache, concatenate along batch dimension
|
||||
k_cache_expanded = torch.cat(self.k_caches, dim=0)
|
||||
v_cache_expanded = torch.cat(self.v_caches, dim=0)
|
||||
if self.k_cache.shape[0] != k.shape[0]:
|
||||
k_cache_expanded = self.k_cache.expand(k.shape[0], -1, -1, -1)
|
||||
v_cache_expanded = self.v_cache.expand(v.shape[0], -1, -1, -1)
|
||||
else:
|
||||
# Mismatch: try to handle by repeating the caches
|
||||
# This handles cases like num_images_per_prompt > 1
|
||||
repeat_factor = batch_size // num_cached_samples
|
||||
if batch_size % num_cached_samples == 0:
|
||||
k_cache_list = []
|
||||
v_cache_list = []
|
||||
for i in range(num_cached_samples):
|
||||
k_cache_list.append(self.k_caches[i].expand(repeat_factor, -1, -1, -1))
|
||||
v_cache_list.append(self.v_caches[i].expand(repeat_factor, -1, -1, -1))
|
||||
k_cache_expanded = torch.cat(k_cache_list, dim=0)
|
||||
v_cache_expanded = torch.cat(v_cache_list, dim=0)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot match {num_cached_samples} cached samples to batch size {batch_size}. "
|
||||
f"Batch size must be a multiple of the number of cached samples."
|
||||
)
|
||||
k_cache_expanded = self.k_cache
|
||||
v_cache_expanded = self.v_cache
|
||||
|
||||
k_combined = torch.cat([k_cache_expanded, k], dim=1)
|
||||
v_combined = torch.cat([v_cache_expanded, v], dim=1)
|
||||
return k_combined, v_combined
|
||||
k_cache = torch.cat([k_cache_expanded, k], dim=1)
|
||||
v_cache = torch.cat([v_cache_expanded, v], dim=1)
|
||||
return k_cache, v_cache
|
||||
|
||||
def clear(self):
|
||||
self.k_caches = []
|
||||
self.v_caches = []
|
||||
self.k_cache = None
|
||||
self.v_cache = None
|
||||
self.mode = None
|
||||
self.current_sample_idx = 0
|
||||
|
||||
def next_sample(self):
|
||||
"""Move to the next sample for writing."""
|
||||
self.current_sample_idx += 1
|
||||
|
||||
|
||||
class GlmImageKVCache:
|
||||
"""Container for all layers' KV caches.
|
||||
Supports per-sample caching for batch processing where each sample may have different condition images.
|
||||
"""
|
||||
"""Container for all layers' KV caches."""
|
||||
|
||||
def __init__(self, num_layers: int):
|
||||
self.num_layers = num_layers
|
||||
@@ -237,12 +192,6 @@ class GlmImageKVCache:
|
||||
for cache in self.caches:
|
||||
cache.mode = mode
|
||||
|
||||
def next_sample(self):
|
||||
"""Move to the next sample for writing. Call this after processing
|
||||
all condition images for one batch sample."""
|
||||
for cache in self.caches:
|
||||
cache.next_sample()
|
||||
|
||||
def clear(self):
|
||||
for cache in self.caches:
|
||||
cache.clear()
|
||||
|
||||
@@ -166,11 +166,9 @@ class MotionConv2d(nn.Module):
        # NOTE: the original implementation uses a 2D upfirdn operation with the upsampling and downsampling rates
        # set to 1, which should be equivalent to a 2D convolution
        expanded_kernel = self.blur_kernel[None, None, :, :].expand(self.in_channels, 1, -1, -1)
        x = x.to(expanded_kernel.dtype)
        x = F.conv2d(x, expanded_kernel, padding=self.blur_padding, groups=self.in_channels)

        # Main Conv2D with scaling
        x = x.to(self.weight.dtype)
        x = F.conv2d(x, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)

        # Activation with fused bias, if using
@@ -771,7 +769,7 @@ class WanImageEmbedding(torch.nn.Module):
        return hidden_states


# Modified from diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding
# Copied from diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding
class WanTimeTextImageEmbedding(nn.Module):
    def __init__(
        self,
@@ -805,12 +803,10 @@ class WanTimeTextImageEmbedding(nn.Module):
        if timestep_seq_len is not None:
            timestep = timestep.unflatten(0, (-1, timestep_seq_len))

        if self.time_embedder.linear_1.weight.dtype.is_floating_point:
            time_embedder_dtype = self.time_embedder.linear_1.weight.dtype
        else:
            time_embedder_dtype = encoder_hidden_states.dtype

        temb = self.time_embedder(timestep.to(time_embedder_dtype)).type_as(encoder_hidden_states)
        time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
        if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
            timestep = timestep.to(time_embedder_dtype)
        temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
        timestep_proj = self.time_proj(self.act_fn(temb))

        encoder_hidden_states = self.text_embedder(encoder_hidden_states)
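Note: both variants above are about one thing: only cast the timestep to the time embedder's parameter dtype when that dtype is a usable floating-point type, so quantized (for example int8) weights do not force an invalid cast. A minimal standalone sketch with toy module sizes:

```python
import torch
import torch.nn as nn

# Stand-in time embedder and sinusoidal timestep features (shapes are toy values).
time_embedder = nn.Sequential(nn.Linear(256, 512), nn.SiLU(), nn.Linear(512, 512))
timestep = torch.rand(2, 256)

embedder_dtype = next(iter(time_embedder.parameters())).dtype
if timestep.dtype != embedder_dtype and embedder_dtype != torch.int8:
    timestep = timestep.to(embedder_dtype)  # skipped when weights are int8-quantized
temb = time_embedder(timestep)
print(temb.shape, temb.dtype)  # torch.Size([2, 512]) torch.float32
```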
@@ -54,10 +54,7 @@ else:
|
||||
]
|
||||
_import_structure["flux2"] = [
|
||||
"Flux2AutoBlocks",
|
||||
"Flux2KleinAutoBlocks",
|
||||
"Flux2KleinBaseAutoBlocks",
|
||||
"Flux2ModularPipeline",
|
||||
"Flux2KleinModularPipeline",
|
||||
]
|
||||
_import_structure["qwenimage"] = [
|
||||
"QwenImageAutoBlocks",
|
||||
@@ -84,13 +81,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .components_manager import ComponentsManager
|
||||
from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
|
||||
from .flux2 import (
|
||||
Flux2AutoBlocks,
|
||||
Flux2KleinAutoBlocks,
|
||||
Flux2KleinBaseAutoBlocks,
|
||||
Flux2KleinModularPipeline,
|
||||
Flux2ModularPipeline,
|
||||
)
|
||||
from .flux2 import Flux2AutoBlocks, Flux2ModularPipeline
|
||||
from .modular_pipeline import (
|
||||
AutoPipelineBlocks,
|
||||
BlockState,
|
||||
|
||||
@@ -43,7 +43,7 @@ else:
|
||||
"Flux2ProcessImagesInputStep",
|
||||
"Flux2TextInputStep",
|
||||
]
|
||||
_import_structure["modular_blocks_flux2"] = [
|
||||
_import_structure["modular_blocks"] = [
|
||||
"ALL_BLOCKS",
|
||||
"AUTO_BLOCKS",
|
||||
"REMOTE_AUTO_BLOCKS",
|
||||
@@ -51,11 +51,10 @@ else:
|
||||
"IMAGE_CONDITIONED_BLOCKS",
|
||||
"Flux2AutoBlocks",
|
||||
"Flux2AutoVaeEncoderStep",
|
||||
"Flux2CoreDenoiseStep",
|
||||
"Flux2BeforeDenoiseStep",
|
||||
"Flux2VaeEncoderSequentialStep",
|
||||
]
|
||||
_import_structure["modular_blocks_flux2_klein"] = ["Flux2KleinAutoBlocks", "Flux2KleinBaseAutoBlocks"]
|
||||
_import_structure["modular_pipeline"] = ["Flux2ModularPipeline", "Flux2KleinModularPipeline"]
|
||||
_import_structure["modular_pipeline"] = ["Flux2ModularPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
@@ -86,7 +85,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
Flux2ProcessImagesInputStep,
|
||||
Flux2TextInputStep,
|
||||
)
|
||||
from .modular_blocks_flux2 import (
|
||||
from .modular_blocks import (
|
||||
ALL_BLOCKS,
|
||||
AUTO_BLOCKS,
|
||||
IMAGE_CONDITIONED_BLOCKS,
|
||||
@@ -94,14 +93,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
TEXT2IMAGE_BLOCKS,
|
||||
Flux2AutoBlocks,
|
||||
Flux2AutoVaeEncoderStep,
|
||||
Flux2CoreDenoiseStep,
|
||||
Flux2BeforeDenoiseStep,
|
||||
Flux2VaeEncoderSequentialStep,
|
||||
)
|
||||
from .modular_blocks_flux2_klein import (
|
||||
Flux2KleinAutoBlocks,
|
||||
Flux2KleinBaseAutoBlocks,
|
||||
)
|
||||
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
|
||||
from .modular_pipeline import Flux2ModularPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
@@ -129,9 +129,17 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
|
||||
InputParam("num_inference_steps", default=50),
|
||||
InputParam("timesteps"),
|
||||
InputParam("sigmas"),
|
||||
InputParam("guidance_scale", default=4.0),
|
||||
InputParam("latents", type_hint=torch.Tensor),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam("height", type_hint=int),
|
||||
InputParam("width", type_hint=int),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -143,12 +151,13 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
|
||||
type_hint=int,
|
||||
description="The number of denoising steps to perform at inference time",
|
||||
),
|
||||
OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
device = components._execution_device
|
||||
block_state.device = components._execution_device
|
||||
|
||||
scheduler = components.scheduler
|
||||
|
||||
@@ -174,7 +183,7 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
block_state.device,
|
||||
timesteps=timesteps,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
@@ -182,6 +191,11 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
|
||||
block_state.timesteps = timesteps
|
||||
block_state.num_inference_steps = num_inference_steps
|
||||
|
||||
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32)
|
||||
guidance = guidance.expand(batch_size)
|
||||
block_state.guidance = guidance
|
||||
|
||||
components.scheduler.set_begin_index(0)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
@@ -339,6 +353,7 @@ class Flux2RoPEInputsStep(ModularPipelineBlocks):
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="prompt_embeds", required=True),
|
||||
InputParam(name="latent_ids"),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -350,6 +365,12 @@ class Flux2RoPEInputsStep(ModularPipelineBlocks):
|
||||
type_hint=torch.Tensor,
|
||||
description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.",
|
||||
),
|
||||
OutputParam(
|
||||
name="latent_ids",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="4D position IDs (T, H, W, L) for image latents, used for RoPE calculation.",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
@@ -382,72 +403,6 @@ class Flux2RoPEInputsStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class Flux2KleinBaseRoPEInputsStep(ModularPipelineBlocks):
|
||||
model_name = "flux2-klein"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that prepares the 4D RoPE position IDs for Flux2-Klein base model denoising. Should be placed after text encoder and latent preparation steps."
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="prompt_embeds", required=True),
|
||||
InputParam(name="negative_prompt_embeds", required=False),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="txt_ids",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.",
|
||||
),
|
||||
OutputParam(
|
||||
name="negative_txt_ids",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="4D position IDs (T, H, W, L) for negative text tokens, used for RoPE calculation.",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _prepare_text_ids(x: torch.Tensor, t_coord: Optional[torch.Tensor] = None):
|
||||
"""Prepare 4D position IDs for text tokens."""
|
||||
B, L, _ = x.shape
|
||||
out_ids = []
|
||||
|
||||
for i in range(B):
|
||||
t = torch.arange(1) if t_coord is None else t_coord[i]
|
||||
h = torch.arange(1)
|
||||
w = torch.arange(1)
|
||||
seq_l = torch.arange(L)
|
||||
|
||||
coords = torch.cartesian_prod(t, h, w, seq_l)
|
||||
out_ids.append(coords)
|
||||
|
||||
return torch.stack(out_ids)
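Note: a small standalone illustration of the 4D text position IDs built by _prepare_text_ids above with torch.cartesian_prod. For text tokens, the T, H, and W axes collapse to a single coordinate and only the sequence axis L varies; the sequence length here is a toy value.

```python
import torch

L = 4  # toy sequence length
t, h, w = torch.arange(1), torch.arange(1), torch.arange(1)
seq_l = torch.arange(L)
coords = torch.cartesian_prod(t, h, w, seq_l)  # shape (L, 4)
print(coords)
# tensor([[0, 0, 0, 0],
#         [0, 0, 0, 1],
#         [0, 0, 0, 2],
#         [0, 0, 0, 3]])
```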
|
||||
|
||||
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
prompt_embeds = block_state.prompt_embeds
|
||||
device = prompt_embeds.device
|
||||
|
||||
block_state.txt_ids = self._prepare_text_ids(prompt_embeds)
|
||||
block_state.txt_ids = block_state.txt_ids.to(device)
|
||||
|
||||
block_state.negative_txt_ids = None
|
||||
if block_state.negative_prompt_embeds is not None:
|
||||
block_state.negative_txt_ids = self._prepare_text_ids(block_state.negative_prompt_embeds)
|
||||
block_state.negative_txt_ids = block_state.negative_txt_ids.to(device)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class Flux2PrepareImageLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "flux2"
|
||||
|
||||
@@ -551,42 +506,3 @@ class Flux2PrepareImageLatentsStep(ModularPipelineBlocks):
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class Flux2PrepareGuidanceStep(ModularPipelineBlocks):
|
||||
model_name = "flux2"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that prepares the guidance scale tensor for Flux2 inference"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("guidance_scale", default=4.0),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
device = components._execution_device
|
||||
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
guidance = torch.full([1], block_state.guidance_scale, device=device, dtype=torch.float32)
|
||||
guidance = guidance.expand(batch_size)
|
||||
block_state.guidance = guidance
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
@@ -29,16 +29,29 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class Flux2UnpackLatentsStep(ModularPipelineBlocks):
|
||||
class Flux2DecodeStep(ModularPipelineBlocks):
|
||||
model_name = "flux2"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKLFlux2),
|
||||
ComponentSpec(
|
||||
"image_processor",
|
||||
Flux2ImageProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that unpacks the latents from the denoising step"
|
||||
return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm denormalization"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("output_type", default="pil"),
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
@@ -57,9 +70,9 @@ class Flux2UnpackLatentsStep(ModularPipelineBlocks):
|
||||
def intermediate_outputs(self) -> List[str]:
|
||||
return [
|
||||
OutputParam(
|
||||
"latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="The denoise latents from denoising step, unpacked with position IDs.",
|
||||
"images",
|
||||
type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray],
|
||||
description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array",
|
||||
)
|
||||
]
|
||||
|
||||
@@ -94,62 +107,6 @@ class Flux2UnpackLatentsStep(ModularPipelineBlocks):
|
||||
|
||||
return torch.stack(x_list, dim=0)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
latents = block_state.latents
|
||||
latent_ids = block_state.latent_ids
|
||||
|
||||
latents = self._unpack_latents_with_ids(latents, latent_ids)
|
||||
|
||||
block_state.latents = latents
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class Flux2DecodeStep(ModularPipelineBlocks):
|
||||
model_name = "flux2"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKLFlux2),
|
||||
ComponentSpec(
|
||||
"image_processor",
|
||||
Flux2ImageProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm denormalization"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("output_type", default="pil"),
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The denoised latents from the denoising step",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[str]:
|
||||
return [
|
||||
OutputParam(
|
||||
"images",
|
||||
type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray],
|
||||
description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array",
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _unpatchify_latents(latents):
|
||||
"""Convert patchified latents back to regular format."""
|
||||
@@ -164,20 +121,26 @@ class Flux2DecodeStep(ModularPipelineBlocks):
|
||||
block_state = self.get_block_state(state)
|
||||
vae = components.vae
|
||||
|
||||
latents = block_state.latents
|
||||
if block_state.output_type == "latent":
|
||||
block_state.images = block_state.latents
|
||||
else:
|
||||
latents = block_state.latents
|
||||
latent_ids = block_state.latent_ids
|
||||
|
||||
latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = latents * latents_bn_std + latents_bn_mean
|
||||
latents = self._unpack_latents_with_ids(latents, latent_ids)
|
||||
|
||||
latents = self._unpatchify_latents(latents)
|
||||
latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = latents * latents_bn_std + latents_bn_mean
|
||||
|
||||
block_state.images = vae.decode(latents, return_dict=False)[0]
|
||||
block_state.images = components.image_processor.postprocess(
|
||||
block_state.images, output_type=block_state.output_type
|
||||
)
|
||||
latents = self._unpatchify_latents(latents)
|
||||
|
||||
block_state.images = vae.decode(latents, return_dict=False)[0]
|
||||
block_state.images = components.image_processor.postprocess(
|
||||
block_state.images, output_type=block_state.output_type
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
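Note: the decode step above maps the denoised latents back to the VAE's latent scale with the batch-norm layer's running statistics before calling vae.decode. A hedged, self-contained sketch of just that denormalization (channel count and eps are toy values standing in for the Flux2 VAE configuration):

```python
import torch

latents = torch.randn(1, 32, 64, 64)  # denoised, normalized latents
running_mean = torch.zeros(32)        # stands in for vae.bn.running_mean
running_var = torch.ones(32)          # stands in for vae.bn.running_var
eps = 1e-4                            # stands in for vae.config.batch_norm_eps

std = torch.sqrt(running_var.view(1, -1, 1, 1) + eps)
mean = running_mean.view(1, -1, 1, 1)
denormalized = latents * std + mean   # ready to be unpatchified and passed to vae.decode(...)
print(denormalized.shape)             # torch.Size([1, 32, 64, 64])
```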
|
||||
|
||||
@@ -16,8 +16,6 @@ from typing import Any, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...guiders import ClassifierFreeGuidance
|
||||
from ...models import Flux2Transformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging
|
||||
@@ -27,8 +25,8 @@ from ..modular_pipeline import (
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import Flux2ModularPipeline
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@@ -136,229 +134,6 @@ class Flux2LoopDenoiser(ModularPipelineBlocks):
|
||||
return components, block_state
|
||||
|
||||
|
||||
# same as Flux2LoopDenoiser but guidance=None
|
||||
class Flux2KleinLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "flux2-klein"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [ComponentSpec("transformer", Flux2Transformer2DModel)]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Step within the denoising loop that denoises the latents for Flux2. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `Flux2DenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("joint_attention_kwargs"),
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents to denoise. Shape: (B, seq_len, C)",
|
||||
),
|
||||
InputParam(
|
||||
"image_latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)",
|
||||
),
|
||||
InputParam(
|
||||
"image_latent_ids",
|
||||
type_hint=torch.Tensor,
|
||||
description="Position IDs for image latents. Shape: (B, img_seq_len, 4)",
|
||||
),
|
||||
InputParam(
|
||||
"prompt_embeds",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="Text embeddings from Qwen3",
|
||||
),
|
||||
InputParam(
|
||||
"txt_ids",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="4D position IDs for text tokens (T, H, W, L)",
|
||||
),
|
||||
InputParam(
|
||||
"latent_ids",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="4D position IDs for latent tokens (T, H, W, L)",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self, components: Flux2KleinModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
|
||||
) -> PipelineState:
|
||||
latents = block_state.latents
|
||||
latent_model_input = latents.to(components.transformer.dtype)
|
||||
img_ids = block_state.latent_ids
|
||||
|
||||
image_latents = getattr(block_state, "image_latents", None)
|
||||
if image_latents is not None:
|
||||
latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype)
|
||||
image_latent_ids = block_state.image_latent_ids
|
||||
img_ids = torch.cat([img_ids, image_latent_ids], dim=1)
|
||||
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
noise_pred = components.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep / 1000,
|
||||
guidance=None,
|
||||
encoder_hidden_states=block_state.prompt_embeds,
|
||||
txt_ids=block_state.txt_ids,
|
||||
img_ids=img_ids,
|
||||
joint_attention_kwargs=block_state.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
noise_pred = noise_pred[:, : latents.size(1)]
|
||||
block_state.noise_pred = noise_pred
|
||||
|
||||
return components, block_state
|
||||
|
||||
|
||||
# support CFG for Flux2-Klein base model
|
||||
class Flux2KleinBaseLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "flux2-klein"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("transformer", Flux2Transformer2DModel),
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 4.0}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def expected_configs(self) -> List[ConfigSpec]:
|
||||
return [
|
||||
ConfigSpec(name="is_distilled", default=False),
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Step within the denoising loop that denoises the latents for Flux2. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `Flux2DenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("joint_attention_kwargs"),
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents to denoise. Shape: (B, seq_len, C)",
|
||||
),
|
||||
InputParam(
|
||||
"image_latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)",
|
||||
),
|
||||
InputParam(
|
||||
"image_latent_ids",
|
||||
type_hint=torch.Tensor,
|
||||
description="Position IDs for image latents. Shape: (B, img_seq_len, 4)",
|
||||
),
|
||||
InputParam(
|
||||
"prompt_embeds",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="Text embeddings from Qwen3",
|
||||
),
|
||||
InputParam(
|
||||
"negative_prompt_embeds",
|
||||
required=False,
|
||||
type_hint=torch.Tensor,
|
||||
description="Negative text embeddings from Qwen3",
|
||||
),
|
||||
InputParam(
|
||||
"txt_ids",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="4D position IDs for text tokens (T, H, W, L)",
|
||||
),
|
||||
InputParam(
|
||||
"negative_txt_ids",
|
||||
required=False,
|
||||
type_hint=torch.Tensor,
|
||||
description="4D position IDs for negative text tokens (T, H, W, L)",
|
||||
),
|
||||
InputParam(
|
||||
"latent_ids",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="4D position IDs for latent tokens (T, H, W, L)",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self, components: Flux2KleinModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
|
||||
) -> PipelineState:
|
||||
latents = block_state.latents
|
||||
latent_model_input = latents.to(components.transformer.dtype)
|
||||
img_ids = block_state.latent_ids
|
||||
|
||||
image_latents = getattr(block_state, "image_latents", None)
|
||||
if image_latents is not None:
|
||||
latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype)
|
||||
image_latent_ids = block_state.image_latent_ids
|
||||
img_ids = torch.cat([img_ids, image_latent_ids], dim=1)
|
||||
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
guider_inputs = {
|
||||
"encoder_hidden_states": (
|
||||
getattr(block_state, "prompt_embeds", None),
|
||||
getattr(block_state, "negative_prompt_embeds", None),
|
||||
),
|
||||
"txt_ids": (
|
||||
getattr(block_state, "txt_ids", None),
|
||||
getattr(block_state, "negative_txt_ids", None),
|
||||
),
|
||||
}
|
||||
|
||||
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||
guider_state = components.guider.prepare_inputs(guider_inputs)
|
||||
|
||||
for guider_state_batch in guider_state:
|
||||
components.guider.prepare_models(components.transformer)
|
||||
cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
|
||||
|
||||
noise_pred = components.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep / 1000,
|
||||
guidance=None,
|
||||
img_ids=img_ids,
|
||||
joint_attention_kwargs=block_state.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
**cond_kwargs,
|
||||
)[0]
|
||||
guider_state_batch.noise_pred = noise_pred[:, : latents.size(1)]
|
||||
components.guider.cleanup_models(components.transformer)
|
||||
|
||||
# perform guidance
|
||||
block_state.noise_pred = components.guider(guider_state)[0]
|
||||
|
||||
return components, block_state
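Note: the guider call at the end of the loop above is diffusers-specific, but with a plain ClassifierFreeGuidance configuration it reduces to the standard classifier-free guidance combination of the conditional and unconditional predictions. A simplified sketch with toy shapes (not the diffusers guider API):

```python
import torch

guidance_scale = 4.0
noise_pred_cond = torch.randn(1, 4096, 128)    # prediction computed with prompt_embeds / txt_ids
noise_pred_uncond = torch.randn(1, 4096, 128)  # prediction computed with negative_prompt_embeds / negative_txt_ids

noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
print(noise_pred.shape)  # torch.Size([1, 4096, 128])
```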
|
||||
|
||||
|
||||
class Flux2LoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "flux2"
|
||||
|
||||
@@ -475,35 +250,3 @@ class Flux2DenoiseStep(Flux2DenoiseLoopWrapper):
|
||||
" - `Flux2LoopAfterDenoiser`\n"
|
||||
"This block supports both text-to-image and image-conditioned generation."
|
||||
)
|
||||
|
||||
|
||||
class Flux2KleinDenoiseStep(Flux2DenoiseLoopWrapper):
|
||||
block_classes = [Flux2KleinLoopDenoiser, Flux2LoopAfterDenoiser]
|
||||
block_names = ["denoiser", "after_denoiser"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoises the latents for Flux2. \n"
|
||||
"Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n"
|
||||
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
|
||||
" - `Flux2KleinLoopDenoiser`\n"
|
||||
" - `Flux2LoopAfterDenoiser`\n"
|
||||
"This block supports both text-to-image and image-conditioned generation."
|
||||
)
|
||||
|
||||
|
||||
class Flux2KleinBaseDenoiseStep(Flux2DenoiseLoopWrapper):
|
||||
block_classes = [Flux2KleinBaseLoopDenoiser, Flux2LoopAfterDenoiser]
|
||||
block_names = ["denoiser", "after_denoiser"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoises the latents for Flux2. \n"
|
||||
"Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n"
|
||||
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
|
||||
" - `Flux2KleinBaseLoopDenoiser`\n"
|
||||
" - `Flux2LoopAfterDenoiser`\n"
|
||||
"This block supports both text-to-image and image-conditioned generation."
|
||||
)
|
||||
|
||||
@@ -15,15 +15,13 @@
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import AutoProcessor, Mistral3ForConditionalGeneration, Qwen2TokenizerFast, Qwen3ForCausalLM
|
||||
from transformers import AutoProcessor, Mistral3ForConditionalGeneration
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...guiders import ClassifierFreeGuidance
|
||||
from ...models import AutoencoderKLFlux2
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import Flux2ModularPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -81,8 +79,10 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("prompt"),
|
||||
InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
|
||||
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
|
||||
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(10, 20, 30), required=False),
|
||||
InputParam("joint_attention_kwargs"),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -99,7 +99,14 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
|
||||
@staticmethod
|
||||
def check_inputs(block_state):
|
||||
prompt = block_state.prompt
|
||||
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
prompt_embeds = getattr(block_state, "prompt_embeds", None)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
|
||||
"Please make sure to only forward one of the two."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
@staticmethod
|
||||
@@ -158,6 +165,10 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
|
||||
|
||||
block_state.device = components._execution_device
|
||||
|
||||
if block_state.prompt_embeds is not None:
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
prompt = block_state.prompt
|
||||
if prompt is None:
|
||||
prompt = ""
|
||||
@@ -194,6 +205,7 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("prompt"),
|
||||
InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -210,8 +222,15 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
|
||||
@staticmethod
|
||||
def check_inputs(block_state):
|
||||
prompt = block_state.prompt
|
||||
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
|
||||
prompt_embeds = getattr(block_state, "prompt_embeds", None)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
|
||||
"Please make sure to only forward one of the two."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
|
||||
@@ -225,6 +244,10 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
|
||||
|
||||
block_state.device = components._execution_device
|
||||
|
||||
if block_state.prompt_embeds is not None:
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
prompt = block_state.prompt
|
||||
if prompt is None:
|
||||
prompt = ""
|
||||
@@ -247,289 +270,6 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class Flux2KleinTextEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "flux2-klein"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Text Encoder step that generates text embeddings using Qwen3 to guide the image generation"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("text_encoder", Qwen3ForCausalLM),
|
||||
ComponentSpec("tokenizer", Qwen2TokenizerFast),
|
||||
]
|
||||
|
||||
@property
|
||||
def expected_configs(self) -> List[ConfigSpec]:
|
||||
return [
|
||||
ConfigSpec(name="is_distilled", default=True),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("prompt"),
|
||||
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
|
||||
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(9, 18, 27), required=False),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"prompt_embeds",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="Text embeddings from qwen3 used to guide the image generation",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(block_state):
|
||||
prompt = block_state.prompt
|
||||
|
||||
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline._get_qwen3_prompt_embeds
|
||||
def _get_qwen3_prompt_embeds(
|
||||
text_encoder: Qwen3ForCausalLM,
|
||||
tokenizer: Qwen2TokenizerFast,
|
||||
prompt: Union[str, List[str]],
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
max_sequence_length: int = 512,
|
||||
hidden_states_layers: List[int] = (9, 18, 27),
|
||||
):
|
||||
dtype = text_encoder.dtype if dtype is None else dtype
|
||||
device = text_encoder.device if device is None else device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
all_input_ids = []
|
||||
all_attention_masks = []
|
||||
|
||||
for single_prompt in prompt:
|
||||
messages = [{"role": "user", "content": single_prompt}]
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False,
|
||||
)
|
||||
inputs = tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=max_sequence_length,
|
||||
)
|
||||
|
||||
all_input_ids.append(inputs["input_ids"])
|
||||
all_attention_masks.append(inputs["attention_mask"])
|
||||
|
||||
input_ids = torch.cat(all_input_ids, dim=0).to(device)
|
||||
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
|
||||
|
||||
# Forward pass through the model
|
||||
output = text_encoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
# Only use outputs from intermediate layers and stack them
|
||||
out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
|
||||
out = out.to(dtype=dtype, device=device)
|
||||
|
||||
batch_size, num_channels, seq_len, hidden_dim = out.shape
|
||||
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
|
||||
|
||||
return prompt_embeds
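Note: a standalone sketch of the reshaping at the end of _get_qwen3_prompt_embeds above: the selected hidden-state layers are stacked along a new axis and then folded into the feature dimension, so the prompt embedding width becomes len(layers) * hidden_dim. All sizes here are toy values.

```python
import torch

batch_size, seq_len, hidden_dim = 2, 8, 16
# Stand-in for output.hidden_states from the text encoder (33 layers would be typical; 32 here is arbitrary).
hidden_states = [torch.randn(batch_size, seq_len, hidden_dim) for _ in range(32)]

layers = (9, 18, 27)
out = torch.stack([hidden_states[k] for k in layers], dim=1)  # (B, num_layers, L, H)
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, len(layers) * hidden_dim)
print(prompt_embeds.shape)  # torch.Size([2, 8, 48])
```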
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: Flux2KleinModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
self.check_inputs(block_state)
|
||||
|
||||
device = components._execution_device
|
||||
|
||||
prompt = block_state.prompt
|
||||
if prompt is None:
|
||||
prompt = ""
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
block_state.prompt_embeds = self._get_qwen3_prompt_embeds(
|
||||
text_encoder=components.text_encoder,
|
||||
tokenizer=components.tokenizer,
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
max_sequence_length=block_state.max_sequence_length,
|
||||
hidden_states_layers=block_state.text_encoder_out_layers,
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class Flux2KleinBaseTextEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "flux2-klein"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Text Encoder step that generates text embeddings using Qwen3 to guide the image generation"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("text_encoder", Qwen3ForCausalLM),
|
||||
ComponentSpec("tokenizer", Qwen2TokenizerFast),
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 4.0}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def expected_configs(self) -> List[ConfigSpec]:
|
||||
return [
|
||||
ConfigSpec(name="is_distilled", default=False),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("prompt"),
|
||||
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
|
||||
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(9, 18, 27), required=False),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"prompt_embeds",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="Text embeddings from qwen3 used to guide the image generation",
|
||||
),
|
||||
OutputParam(
|
||||
"negative_prompt_embeds",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="Negative text embeddings from qwen3 used to guide the image generation",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(block_state):
|
||||
prompt = block_state.prompt
|
||||
|
||||
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||

    @staticmethod
    # Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline._get_qwen3_prompt_embeds
    def _get_qwen3_prompt_embeds(
        text_encoder: Qwen3ForCausalLM,
        tokenizer: Qwen2TokenizerFast,
        prompt: Union[str, List[str]],
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        max_sequence_length: int = 512,
        hidden_states_layers: List[int] = (9, 18, 27),
    ):
        dtype = text_encoder.dtype if dtype is None else dtype
        device = text_encoder.device if device is None else device

        prompt = [prompt] if isinstance(prompt, str) else prompt

        all_input_ids = []
        all_attention_masks = []

        for single_prompt in prompt:
            messages = [{"role": "user", "content": single_prompt}]
            text = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=False,
            )
            inputs = tokenizer(
                text,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=max_sequence_length,
            )

            all_input_ids.append(inputs["input_ids"])
            all_attention_masks.append(inputs["attention_mask"])

        input_ids = torch.cat(all_input_ids, dim=0).to(device)
        attention_mask = torch.cat(all_attention_masks, dim=0).to(device)

        # Forward pass through the model
        output = text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )

        # Only use outputs from intermediate layers and stack them
        out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
        out = out.to(dtype=dtype, device=device)

        batch_size, num_channels, seq_len, hidden_dim = out.shape
        prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)

        return prompt_embeds

    @torch.no_grad()
    def __call__(self, components: Flux2KleinModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        device = components._execution_device

        prompt = block_state.prompt
        if prompt is None:
            prompt = ""
        prompt = [prompt] if isinstance(prompt, str) else prompt

        block_state.prompt_embeds = self._get_qwen3_prompt_embeds(
            text_encoder=components.text_encoder,
            tokenizer=components.tokenizer,
            prompt=prompt,
            device=device,
            max_sequence_length=block_state.max_sequence_length,
            hidden_states_layers=block_state.text_encoder_out_layers,
        )

        if components.requires_unconditional_embeds:
            negative_prompt = [""] * len(prompt)
            block_state.negative_prompt_embeds = self._get_qwen3_prompt_embeds(
                text_encoder=components.text_encoder,
                tokenizer=components.tokenizer,
                prompt=negative_prompt,
                device=device,
                max_sequence_length=block_state.max_sequence_length,
                hidden_states_layers=block_state.text_encoder_out_layers,
            )
        else:
            block_state.negative_prompt_embeds = None

        self.set_block_state(state, block_state)
        return components, state
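

# Illustrative only (toy shapes, hypothetical helper not used elsewhere): a self-contained
# sketch of the layer stacking performed by `_get_qwen3_prompt_embeds` above. Stacking the
# selected hidden states gives (batch, num_layers, seq_len, hidden_dim); the permute/reshape
# merges the layer axis into the feature axis, i.e. (batch, seq_len, num_layers * hidden_dim).
def _sketch_qwen3_layer_stacking():
    import torch

    batch, seq_len, hidden_dim = 2, 512, 64
    hidden_states = {k: torch.randn(batch, seq_len, hidden_dim) for k in (9, 18, 27)}

    out = torch.stack([hidden_states[k] for k in (9, 18, 27)], dim=1)  # (2, 3, 512, 64)
    prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch, seq_len, 3 * hidden_dim)  # (2, 512, 192)
    assert prompt_embeds.shape == (batch, seq_len, 3 * hidden_dim)
    return prompt_embeds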


class Flux2VaeEncoderStep(ModularPipelineBlocks):
    model_name = "flux2"


@@ -47,7 +47,7 @@ class Flux2TextInputStep(ModularPipelineBlocks):
                required=True,
                kwargs_type="denoiser_input_fields",
                type_hint=torch.Tensor,
                description="Pre-generated text embeddings. Can be generated from text_encoder step.",
                description="Pre-generated text embeddings from Mistral3. Can be generated from text_encoder step.",
            ),
        ]

@@ -89,90 +89,6 @@ class Flux2TextInputStep(ModularPipelineBlocks):
        return components, state


class Flux2KleinBaseTextInputStep(ModularPipelineBlocks):
    model_name = "flux2-klein"

    @property
    def description(self) -> str:
        return (
            "This step:\n"
            " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
            " 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
        )

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("num_images_per_prompt", default=1),
            InputParam(
                "prompt_embeds",
                required=True,
                kwargs_type="denoiser_input_fields",
                type_hint=torch.Tensor,
                description="Pre-generated text embeddings. Can be generated from text_encoder step.",
            ),
            InputParam(
                "negative_prompt_embeds",
                required=False,
                kwargs_type="denoiser_input_fields",
                type_hint=torch.Tensor,
                description="Pre-generated negative text embeddings. Can be generated from text_encoder step.",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "batch_size",
                type_hint=int,
                description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
            ),
            OutputParam(
                "dtype",
                type_hint=torch.dtype,
                description="Data type of model tensor inputs (determined by `prompt_embeds`)",
            ),
            OutputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Text embeddings used to guide the image generation",
            ),
            OutputParam(
                "negative_prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Negative text embeddings used to guide the image generation",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        block_state.batch_size = block_state.prompt_embeds.shape[0]
        block_state.dtype = block_state.prompt_embeds.dtype

        _, seq_len, _ = block_state.prompt_embeds.shape
        block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
        block_state.prompt_embeds = block_state.prompt_embeds.view(
            block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
        )

        if block_state.negative_prompt_embeds is not None:
            _, seq_len, _ = block_state.negative_prompt_embeds.shape
            block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(
                1, block_state.num_images_per_prompt, 1
            )
            block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(
                block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
            )

        self.set_block_state(state, block_state)
        return components, state
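

# Illustrative only (toy shapes, hypothetical helper not used elsewhere): the batch expansion
# performed by `Flux2KleinBaseTextInputStep.__call__` above, going from `batch_size` prompts
# to `batch_size * num_images_per_prompt` rows via the same repeat/view pattern.
def _sketch_prompt_embeds_expansion():
    import torch

    batch_size, seq_len, dim, num_images_per_prompt = 2, 512, 192, 3
    prompt_embeds = torch.randn(batch_size, seq_len, dim)

    expanded = prompt_embeds.repeat(1, num_images_per_prompt, 1)  # (2, 1536, 192)
    expanded = expanded.view(batch_size * num_images_per_prompt, seq_len, -1)  # (6, 512, 192)
    assert expanded.shape == (batch_size * num_images_per_prompt, seq_len, dim)
    return expanded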


class Flux2ProcessImagesInputStep(ModularPipelineBlocks):
    model_name = "flux2"


@@ -12,22 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import PIL.Image
import torch

from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict, OutputParam
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
    Flux2PrepareGuidanceStep,
    Flux2PrepareImageLatentsStep,
    Flux2PrepareLatentsStep,
    Flux2RoPEInputsStep,
    Flux2SetTimestepsStep,
)
from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep
from .decoders import Flux2DecodeStep
from .denoise import Flux2DenoiseStep
from .encoders import (
    Flux2RemoteTextEncoderStep,
@@ -47,6 +41,7 @@ Flux2VaeEncoderBlocks = InsertableDict(
    [
        ("preprocess", Flux2ProcessImagesInputStep()),
        ("encode", Flux2VaeEncoderStep()),
        ("prepare_image_latents", Flux2PrepareImageLatentsStep()),
    ]
)

@@ -77,56 +72,33 @@ class Flux2AutoVaeEncoderStep(AutoPipelineBlocks):
    )


Flux2CoreDenoiseBlocks = InsertableDict(
Flux2BeforeDenoiseBlocks = InsertableDict(
    [
        ("input", Flux2TextInputStep()),
        ("prepare_image_latents", Flux2PrepareImageLatentsStep()),
        ("prepare_latents", Flux2PrepareLatentsStep()),
        ("set_timesteps", Flux2SetTimestepsStep()),
        ("prepare_guidance", Flux2PrepareGuidanceStep()),
        ("prepare_rope_inputs", Flux2RoPEInputsStep()),
        ("denoise", Flux2DenoiseStep()),
        ("after_denoise", Flux2UnpackLatentsStep()),
    ]
)


class Flux2CoreDenoiseStep(SequentialPipelineBlocks):
class Flux2BeforeDenoiseStep(SequentialPipelineBlocks):
    model_name = "flux2"

    block_classes = Flux2CoreDenoiseBlocks.values()
    block_names = Flux2CoreDenoiseBlocks.keys()
    block_classes = Flux2BeforeDenoiseBlocks.values()
    block_names = Flux2BeforeDenoiseBlocks.keys()

    @property
    def description(self):
        return (
            "Core denoise step that performs the denoising process for Flux2-dev.\n"
            " - `Flux2TextInputStep` (input) standardizes the text inputs (prompt_embeds) for the denoising step.\n"
            " - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
            " - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
            " - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
            " - `Flux2PrepareGuidanceStep` (prepare_guidance) prepares the guidance tensor for the denoising step.\n"
            " - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n"
            " - `Flux2DenoiseStep` (denoise) iteratively denoises the latents.\n"
            " - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
        )

    @property
    def outputs(self):
        return [
            OutputParam(
                name="latents",
                type_hint=torch.Tensor,
                description="The latents from the denoising step.",
            )
        ]
        return "Before denoise step that prepares the inputs for the denoise step in Flux2 generation."


AUTO_BLOCKS = InsertableDict(
    [
        ("text_encoder", Flux2TextEncoderStep()),
        ("vae_encoder", Flux2AutoVaeEncoderStep()),
        ("denoise", Flux2CoreDenoiseStep()),
        ("text_input", Flux2TextInputStep()),
        ("vae_image_encoder", Flux2AutoVaeEncoderStep()),
        ("before_denoise", Flux2BeforeDenoiseStep()),
        ("denoise", Flux2DenoiseStep()),
        ("decode", Flux2DecodeStep()),
    ]
)
@@ -135,8 +107,10 @@ AUTO_BLOCKS = InsertableDict(
REMOTE_AUTO_BLOCKS = InsertableDict(
    [
        ("text_encoder", Flux2RemoteTextEncoderStep()),
        ("vae_encoder", Flux2AutoVaeEncoderStep()),
        ("denoise", Flux2CoreDenoiseStep()),
        ("text_input", Flux2TextInputStep()),
        ("vae_image_encoder", Flux2AutoVaeEncoderStep()),
        ("before_denoise", Flux2BeforeDenoiseStep()),
        ("denoise", Flux2DenoiseStep()),
        ("decode", Flux2DecodeStep()),
    ]
)
@@ -156,16 +130,6 @@ class Flux2AutoBlocks(SequentialPipelineBlocks):
            "- For image-conditioned generation, you need to provide `image` (list of PIL images)."
        )

    @property
    def outputs(self):
        return [
            OutputParam(
                name="images",
                type_hint=List[PIL.Image.Image],
                description="The images from the decoding step.",
            )
        ]


TEXT2IMAGE_BLOCKS = InsertableDict(
    [
@@ -173,10 +137,8 @@ TEXT2IMAGE_BLOCKS = InsertableDict(
        ("text_input", Flux2TextInputStep()),
        ("prepare_latents", Flux2PrepareLatentsStep()),
        ("set_timesteps", Flux2SetTimestepsStep()),
        ("prepare_guidance", Flux2PrepareGuidanceStep()),
        ("prepare_rope_inputs", Flux2RoPEInputsStep()),
        ("denoise", Flux2DenoiseStep()),
        ("after_denoise", Flux2UnpackLatentsStep()),
        ("decode", Flux2DecodeStep()),
    ]
)
@@ -190,10 +152,8 @@ IMAGE_CONDITIONED_BLOCKS = InsertableDict(
        ("prepare_image_latents", Flux2PrepareImageLatentsStep()),
        ("prepare_latents", Flux2PrepareLatentsStep()),
        ("set_timesteps", Flux2SetTimestepsStep()),
        ("prepare_guidance", Flux2PrepareGuidanceStep()),
        ("prepare_rope_inputs", Flux2RoPEInputsStep()),
        ("denoise", Flux2DenoiseStep()),
        ("after_denoise", Flux2UnpackLatentsStep()),
        ("decode", Flux2DecodeStep()),
    ]
)
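

# Illustrative only (hypothetical names, mirrors the preset pattern used in this module):
# a custom block preset is just an InsertableDict of named block instances paired with a
# SequentialPipelineBlocks subclass. The exact composition below is an example, not a
# preset defined by this file.
MyFlux2Blocks = InsertableDict(
    [
        ("text_input", Flux2TextInputStep()),
        ("prepare_latents", Flux2PrepareLatentsStep()),
        ("set_timesteps", Flux2SetTimestepsStep()),
        ("prepare_rope_inputs", Flux2RoPEInputsStep()),
        ("denoise", Flux2DenoiseStep()),
        ("decode", Flux2DecodeStep()),
    ]
)


class MyFlux2Text2ImageStep(SequentialPipelineBlocks):
    model_name = "flux2"

    block_classes = MyFlux2Blocks.values()
    block_names = MyFlux2Blocks.keys()

    @property
    def description(self):
        return "Hypothetical text-to-image preset assembled from the blocks defined in this module."
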
@@ -1,232 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List
|
||||
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InsertableDict, OutputParam
|
||||
from .before_denoise import (
|
||||
Flux2KleinBaseRoPEInputsStep,
|
||||
Flux2PrepareImageLatentsStep,
|
||||
Flux2PrepareLatentsStep,
|
||||
Flux2RoPEInputsStep,
|
||||
Flux2SetTimestepsStep,
|
||||
)
|
||||
from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep
|
||||
from .denoise import Flux2KleinBaseDenoiseStep, Flux2KleinDenoiseStep
|
||||
from .encoders import (
|
||||
Flux2KleinBaseTextEncoderStep,
|
||||
Flux2KleinTextEncoderStep,
|
||||
Flux2VaeEncoderStep,
|
||||
)
|
||||
from .inputs import (
|
||||
Flux2KleinBaseTextInputStep,
|
||||
Flux2ProcessImagesInputStep,
|
||||
Flux2TextInputStep,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
################
|
||||
# VAE encoder
|
||||
################
|
||||
|
||||
Flux2KleinVaeEncoderBlocks = InsertableDict(
|
||||
[
|
||||
("preprocess", Flux2ProcessImagesInputStep()),
|
||||
("encode", Flux2VaeEncoderStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class Flux2KleinVaeEncoderSequentialStep(SequentialPipelineBlocks):
|
||||
model_name = "flux2"
|
||||
|
||||
block_classes = Flux2KleinVaeEncoderBlocks.values()
|
||||
block_names = Flux2KleinVaeEncoderBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "VAE encoder step that preprocesses and encodes the image inputs into their latent representations."
|
||||
|
||||
|
||||
class Flux2KleinAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [Flux2KleinVaeEncoderSequentialStep]
|
||||
block_names = ["img_conditioning"]
|
||||
block_trigger_inputs = ["image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"VAE encoder step that encodes the image inputs into their latent representations.\n"
|
||||
"This is an auto pipeline block that works for image conditioning tasks.\n"
|
||||
" - `Flux2KleinVaeEncoderSequentialStep` is used when `image` is provided.\n"
|
||||
" - If `image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
###
|
||||
### Core denoise
|
||||
###
|
||||
|
||||
Flux2KleinCoreDenoiseBlocks = InsertableDict(
|
||||
[
|
||||
("input", Flux2TextInputStep()),
|
||||
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
|
||||
("prepare_latents", Flux2PrepareLatentsStep()),
|
||||
("set_timesteps", Flux2SetTimestepsStep()),
|
||||
("prepare_rope_inputs", Flux2RoPEInputsStep()),
|
||||
("denoise", Flux2KleinDenoiseStep()),
|
||||
("after_denoise", Flux2UnpackLatentsStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "flux2-klein"
|
||||
|
||||
block_classes = Flux2KleinCoreDenoiseBlocks.values()
|
||||
block_names = Flux2KleinCoreDenoiseBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Core denoise step that performs the denoising process for Flux2-Klein (distilled model).\n"
|
||||
" - `Flux2KleinTextInputStep` (input) standardizes the text inputs (prompt_embeds) for the denoising step.\n"
|
||||
" - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
|
||||
" - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
|
||||
" - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
|
||||
" - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n"
|
||||
" - `Flux2KleinDenoiseStep` (denoise) iteratively denoises the latents.\n"
|
||||
" - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
|
||||
)
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam(
|
||||
name="latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents from the denoising step.",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
Flux2KleinBaseCoreDenoiseBlocks = InsertableDict(
|
||||
[
|
||||
("input", Flux2KleinBaseTextInputStep()),
|
||||
("prepare_latents", Flux2PrepareLatentsStep()),
|
||||
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
|
||||
("set_timesteps", Flux2SetTimestepsStep()),
|
||||
("prepare_rope_inputs", Flux2KleinBaseRoPEInputsStep()),
|
||||
("denoise", Flux2KleinBaseDenoiseStep()),
|
||||
("after_denoise", Flux2UnpackLatentsStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class Flux2KleinBaseCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "flux2-klein"
|
||||
block_classes = Flux2KleinBaseCoreDenoiseBlocks.values()
|
||||
block_names = Flux2KleinBaseCoreDenoiseBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Core denoise step that performs the denoising process for Flux2-Klein (base model)."
|
||||
return (
|
||||
"Core denoise step that performs the denoising process for Flux2-Klein (base model).\n"
|
||||
" - `Flux2KleinBaseTextInputStep` (input) standardizes the text inputs (prompt_embeds + negative_prompt_embeds) for the denoising step.\n"
|
||||
" - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
|
||||
" - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
|
||||
" - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
|
||||
" - `Flux2KleinBaseRoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids + negative_txt_ids) for the denoising step.\n"
|
||||
" - `Flux2KleinBaseDenoiseStep` (denoise) iteratively denoises the latents using Classifier-Free Guidance.\n"
|
||||
" - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
|
||||
)
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam(
|
||||
name="latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents from the denoising step.",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
###
|
||||
### Auto blocks
|
||||
###
|
||||
class Flux2KleinAutoBlocks(SequentialPipelineBlocks):
|
||||
model_name = "flux2-klein"
|
||||
block_classes = [
|
||||
Flux2KleinTextEncoderStep(),
|
||||
Flux2KleinAutoVaeEncoderStep(),
|
||||
Flux2KleinCoreDenoiseStep(),
|
||||
Flux2DecodeStep(),
|
||||
]
|
||||
block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein.\n"
|
||||
+ " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n"
|
||||
+ " - for text-to-image generation, all you need to provide is `prompt`.\n"
|
||||
)
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam(
|
||||
name="images",
|
||||
type_hint=List[PIL.Image.Image],
|
||||
description="The images from the decoding step.",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks):
|
||||
model_name = "flux2-klein"
|
||||
block_classes = [
|
||||
Flux2KleinBaseTextEncoderStep(),
|
||||
Flux2KleinAutoVaeEncoderStep(),
|
||||
Flux2KleinBaseCoreDenoiseStep(),
|
||||
Flux2DecodeStep(),
|
||||
]
|
||||
block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein (base model).\n"
|
||||
+ " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n"
|
||||
+ " - for text-to-image generation, all you need to provide is `prompt`.\n"
|
||||
)
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam(
|
||||
name="images",
|
||||
type_hint=List[PIL.Image.Image],
|
||||
description="The images from the decoding step.",
|
||||
)
|
||||
]
|
||||
@@ -13,8 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from ...loaders import Flux2LoraLoaderMixin
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import ModularPipeline
|
||||
@@ -57,56 +55,3 @@ class Flux2ModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
|
||||
if getattr(self, "transformer", None):
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
return num_channels_latents
|
||||
|
||||
|
||||
class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
    """
    A ModularPipeline for Flux2-Klein.

    > [!WARNING]
    > This is an experimental feature and is likely to change in the future.
    """

    default_blocks_name = "Flux2KleinBaseAutoBlocks"

    def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
        if config_dict is not None and "is_distilled" in config_dict and config_dict["is_distilled"]:
            return "Flux2KleinAutoBlocks"
        else:
            return "Flux2KleinBaseAutoBlocks"

    @property
    def default_height(self):
        return self.default_sample_size * self.vae_scale_factor

    @property
    def default_width(self):
        return self.default_sample_size * self.vae_scale_factor

    @property
    def default_sample_size(self):
        return 128

    @property
    def vae_scale_factor(self):
        vae_scale_factor = 8
        if getattr(self, "vae", None) is not None:
            vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        return vae_scale_factor

    @property
    def num_channels_latents(self):
        num_channels_latents = 32
        if getattr(self, "transformer", None):
            num_channels_latents = self.transformer.config.in_channels // 4
        return num_channels_latents

    @property
    def requires_unconditional_embeds(self):
        if hasattr(self.config, "is_distilled") and self.config.is_distilled:
            return False

        requires_unconditional_embeds = False
        if hasattr(self, "guider") and self.guider is not None:
            requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1

        return requires_unconditional_embeds
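

# Illustrative only (toy tensors, hypothetical helper): why `requires_unconditional_embeds`
# above gates the negative embeddings. When the guider is enabled with more than one
# condition, the denoiser is evaluated on both the conditional and unconditional embeddings
# and the two predictions are combined with the guidance scale; a guidance-distilled
# checkpoint (`is_distilled=True`) skips the unconditional pass entirely.
def _sketch_classifier_free_guidance(guidance_scale: float = 4.0):
    import torch

    noise_pred_cond = torch.randn(1, 8)    # prediction obtained with prompt_embeds
    noise_pred_uncond = torch.randn(1, 8)  # prediction obtained with negative_prompt_embeds

    return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)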
@@ -59,7 +59,6 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
|
||||
("flux", "FluxModularPipeline"),
|
||||
("flux-kontext", "FluxKontextModularPipeline"),
|
||||
("flux2", "Flux2ModularPipeline"),
|
||||
("flux2-klein", "Flux2KleinModularPipeline"),
|
||||
("qwenimage", "QwenImageModularPipeline"),
|
||||
("qwenimage-edit", "QwenImageEditModularPipeline"),
|
||||
("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),
|
||||
@@ -2143,6 +2142,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
name
|
||||
for name in self._component_specs.keys()
|
||||
if self._component_specs[name].default_creation_method == "from_pretrained"
|
||||
and self._component_specs[name].pretrained_model_name_or_path is not None
|
||||
and getattr(self, name, None) is None
|
||||
]
|
||||
elif isinstance(names, str):
|
||||
names = [names]
|
||||
|
||||
@@ -15,15 +15,14 @@
|
||||
import inspect
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field, fields
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Literal, Optional, Type, Union
|
||||
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, FrozenDict
|
||||
from ..loaders.single_file_utils import _is_single_file_path_or_url
|
||||
from ..utils import is_torch_available, logging
|
||||
from ..utils import DIFFUSERS_LOAD_ID_FIELDS, is_torch_available, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -186,7 +185,7 @@ class ComponentSpec:
|
||||
"""
|
||||
Return the names of all loading‐related fields (i.e. those whose field.metadata["loading"] is True).
|
||||
"""
|
||||
return [f.name for f in fields(cls) if f.metadata.get("loading", False)]
|
||||
return DIFFUSERS_LOAD_ID_FIELDS.copy()
|
||||
|
||||
@property
|
||||
def load_id(self) -> str:
|
||||
@@ -198,7 +197,7 @@ class ComponentSpec:
|
||||
return "null"
|
||||
parts = [getattr(self, k) for k in self.loading_fields()]
|
||||
parts = ["null" if p is None else p for p in parts]
|
||||
return "|".join(p for p in parts if p)
|
||||
return "|".join(parts)
|
||||
|
||||
@classmethod
|
||||
def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]:
|
||||
@@ -324,192 +323,11 @@ class ConfigSpec:
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
# ======================================================
|
||||
# InputParam and OutputParam templates
|
||||
# ======================================================
|
||||
|
||||
INPUT_PARAM_TEMPLATES = {
|
||||
"prompt": {
|
||||
"type_hint": str,
|
||||
"required": True,
|
||||
"description": "The prompt or prompts to guide image generation.",
|
||||
},
|
||||
"negative_prompt": {
|
||||
"type_hint": str,
|
||||
"description": "The prompt or prompts not to guide the image generation.",
|
||||
},
|
||||
"max_sequence_length": {
|
||||
"type_hint": int,
|
||||
"default": 512,
|
||||
"description": "Maximum sequence length for prompt encoding.",
|
||||
},
|
||||
"height": {
|
||||
"type_hint": int,
|
||||
"description": "The height in pixels of the generated image.",
|
||||
},
|
||||
"width": {
|
||||
"type_hint": int,
|
||||
"description": "The width in pixels of the generated image.",
|
||||
},
|
||||
"num_inference_steps": {
|
||||
"type_hint": int,
|
||||
"default": 50,
|
||||
"description": "The number of denoising steps.",
|
||||
},
|
||||
"num_images_per_prompt": {
|
||||
"type_hint": int,
|
||||
"default": 1,
|
||||
"description": "The number of images to generate per prompt.",
|
||||
},
|
||||
"generator": {
|
||||
"type_hint": torch.Generator,
|
||||
"description": "Torch generator for deterministic generation.",
|
||||
},
|
||||
"sigmas": {
|
||||
"type_hint": List[float],
|
||||
"description": "Custom sigmas for the denoising process.",
|
||||
},
|
||||
"strength": {
|
||||
"type_hint": float,
|
||||
"default": 0.9,
|
||||
"description": "Strength for img2img/inpainting.",
|
||||
},
|
||||
"image": {
|
||||
"type_hint": Union[PIL.Image.Image, List[PIL.Image.Image]],
|
||||
"required": True,
|
||||
"description": "Reference image(s) for denoising. Can be a single image or list of images.",
|
||||
},
|
||||
"latents": {
|
||||
"type_hint": torch.Tensor,
|
||||
"description": "Pre-generated noisy latents for image generation.",
|
||||
},
|
||||
"timesteps": {
|
||||
"type_hint": torch.Tensor,
|
||||
"description": "Timesteps for the denoising process.",
|
||||
},
|
||||
"output_type": {
|
||||
"type_hint": str,
|
||||
"default": "pil",
|
||||
"description": "Output format: 'pil', 'np', 'pt'.",
|
||||
},
|
||||
"attention_kwargs": {
|
||||
"type_hint": Dict[str, Any],
|
||||
"description": "Additional kwargs for attention processors.",
|
||||
},
|
||||
"denoiser_input_fields": {
|
||||
"name": None,
|
||||
"kwargs_type": "denoiser_input_fields",
|
||||
"description": "conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
|
||||
},
|
||||
# inpainting
|
||||
"mask_image": {
|
||||
"type_hint": PIL.Image.Image,
|
||||
"required": True,
|
||||
"description": "Mask image for inpainting.",
|
||||
},
|
||||
"padding_mask_crop": {
|
||||
"type_hint": int,
|
||||
"description": "Padding for mask cropping in inpainting.",
|
||||
},
|
||||
# controlnet
|
||||
"control_image": {
|
||||
"type_hint": PIL.Image.Image,
|
||||
"required": True,
|
||||
"description": "Control image for ControlNet conditioning.",
|
||||
},
|
||||
"control_guidance_start": {
|
||||
"type_hint": float,
|
||||
"default": 0.0,
|
||||
"description": "When to start applying ControlNet.",
|
||||
},
|
||||
"control_guidance_end": {
|
||||
"type_hint": float,
|
||||
"default": 1.0,
|
||||
"description": "When to stop applying ControlNet.",
|
||||
},
|
||||
"controlnet_conditioning_scale": {
|
||||
"type_hint": float,
|
||||
"default": 1.0,
|
||||
"description": "Scale for ControlNet conditioning.",
|
||||
},
|
||||
"layers": {
|
||||
"type_hint": int,
|
||||
"default": 4,
|
||||
"description": "Number of layers to extract from the image",
|
||||
},
|
||||
# common intermediate inputs
|
||||
"prompt_embeds": {
|
||||
"type_hint": torch.Tensor,
|
||||
"required": True,
|
||||
"description": "text embeddings used to guide the image generation. Can be generated from text_encoder step.",
|
||||
},
|
||||
"prompt_embeds_mask": {
|
||||
"type_hint": torch.Tensor,
|
||||
"required": True,
|
||||
"description": "mask for the text embeddings. Can be generated from text_encoder step.",
|
||||
},
|
||||
"negative_prompt_embeds": {
|
||||
"type_hint": torch.Tensor,
|
||||
"description": "negative text embeddings used to guide the image generation. Can be generated from text_encoder step.",
|
||||
},
|
||||
"negative_prompt_embeds_mask": {
|
||||
"type_hint": torch.Tensor,
|
||||
"description": "mask for the negative text embeddings. Can be generated from text_encoder step.",
|
||||
},
|
||||
"image_latents": {
|
||||
"type_hint": torch.Tensor,
|
||||
"required": True,
|
||||
"description": "image latents used to guide the image generation. Can be generated from vae_encoder step.",
|
||||
},
|
||||
"batch_size": {
|
||||
"type_hint": int,
|
||||
"default": 1,
|
||||
"description": "Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
|
||||
},
|
||||
"dtype": {
|
||||
"type_hint": torch.dtype,
|
||||
"default": torch.float32,
|
||||
"description": "The dtype of the model inputs, can be generated in input step.",
|
||||
},
|
||||
}
|
||||
|
||||
OUTPUT_PARAM_TEMPLATES = {
|
||||
"images": {
|
||||
"type_hint": List[PIL.Image.Image],
|
||||
"description": "Generated images.",
|
||||
},
|
||||
"latents": {
|
||||
"type_hint": torch.Tensor,
|
||||
"description": "Denoised latents.",
|
||||
},
|
||||
# intermediate outputs
|
||||
"prompt_embeds": {
|
||||
"type_hint": torch.Tensor,
|
||||
"kwargs_type": "denoiser_input_fields",
|
||||
"description": "The prompt embeddings.",
|
||||
},
|
||||
"prompt_embeds_mask": {
|
||||
"type_hint": torch.Tensor,
|
||||
"kwargs_type": "denoiser_input_fields",
|
||||
"description": "The encoder attention mask.",
|
||||
},
|
||||
"negative_prompt_embeds": {
|
||||
"type_hint": torch.Tensor,
|
||||
"kwargs_type": "denoiser_input_fields",
|
||||
"description": "The negative prompt embeddings.",
|
||||
},
|
||||
"negative_prompt_embeds_mask": {
|
||||
"type_hint": torch.Tensor,
|
||||
"kwargs_type": "denoiser_input_fields",
|
||||
"description": "The negative prompt embeddings mask.",
|
||||
},
|
||||
"image_latents": {
|
||||
"type_hint": torch.Tensor,
|
||||
"description": "The latent representation of the input image.",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# YiYi Notes: both inputs and intermediate_inputs are InputParam objects
|
||||
# however some fields are not relevant for intermediate_inputs
|
||||
# e.g. unlike inputs, required only used in docstring for intermediate_inputs, we do not check if a required intermediate inputs is passed
|
||||
# default is not used for intermediate_inputs, we only use default from inputs, so it is ignored if it is set for intermediate_inputs
|
||||
# -> should we use different class for inputs and intermediate_inputs?
|
||||
@dataclass
|
||||
class InputParam:
|
||||
"""Specification for an input parameter."""
|
||||
@@ -519,31 +337,11 @@ class InputParam:
|
||||
default: Any = None
|
||||
required: bool = False
|
||||
description: str = ""
|
||||
kwargs_type: str = None
|
||||
kwargs_type: str = None # YiYi Notes: remove this feature (maybe)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
|
||||
|
||||
@classmethod
|
||||
def template(cls, template_name: str, note: str = None, **overrides) -> "InputParam":
|
||||
"""Get template for name if exists, otherwise raise ValueError."""
|
||||
if template_name not in INPUT_PARAM_TEMPLATES:
|
||||
raise ValueError(f"InputParam template for {template_name} not found")
|
||||
|
||||
template_kwargs = INPUT_PARAM_TEMPLATES[template_name].copy()
|
||||
|
||||
# Determine the actual param name:
|
||||
# 1. From overrides if provided
|
||||
# 2. From template if present
|
||||
# 3. Fall back to template_name
|
||||
name = overrides.pop("name", template_kwargs.pop("name", template_name))
|
||||
|
||||
if note and "description" in template_kwargs:
|
||||
template_kwargs["description"] = f"{template_kwargs['description']} ({note})"
|
||||
|
||||
template_kwargs.update(overrides)
|
||||
return cls(name=name, **template_kwargs)
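
# Illustrative only (hypothetical helper, assumes this module's INPUT_PARAM_TEMPLATES):
# how `InputParam.template` above is meant to be used — pick a template by name, override
# fields, and optionally append a note to the template description.
def _sketch_input_param_templates():
    height = InputParam.template("height", required=True)
    image_latents = InputParam.template(
        "image_latents", note="Can be generated from vae encoder and updated in input step."
    )
    renamed = InputParam.template("strength", name="denoise_strength", default=0.5)
    return height, image_latents, renamed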
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutputParam:
|
||||
@@ -552,33 +350,13 @@ class OutputParam:
|
||||
name: str
|
||||
type_hint: Any = None
|
||||
description: str = ""
|
||||
kwargs_type: str = None
|
||||
kwargs_type: str = None # YiYi notes: remove this feature (maybe)
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def template(cls, template_name: str, note: str = None, **overrides) -> "OutputParam":
|
||||
"""Get template for name if exists, otherwise raise ValueError."""
|
||||
if template_name not in OUTPUT_PARAM_TEMPLATES:
|
||||
raise ValueError(f"OutputParam template for {template_name} not found")
|
||||
|
||||
template_kwargs = OUTPUT_PARAM_TEMPLATES[template_name].copy()
|
||||
|
||||
# Determine the actual param name:
|
||||
# 1. From overrides if provided
|
||||
# 2. From template if present
|
||||
# 3. Fall back to template_name
|
||||
name = overrides.pop("name", template_kwargs.pop("name", template_name))
|
||||
|
||||
if note and "description" in template_kwargs:
|
||||
template_kwargs["description"] = f"{template_kwargs['description']} ({note})"
|
||||
|
||||
template_kwargs.update(overrides)
|
||||
return cls(name=name, **template_kwargs)
|
||||
|
||||
|
||||
def format_inputs_short(inputs):
|
||||
"""
|
||||
@@ -731,12 +509,10 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115):
|
||||
desc = re.sub(r"\[(.*?)\]\((https?://[^\s\)]+)\)", r"[\1](\2)", param.description)
|
||||
wrapped_desc = wrap_text(desc, desc_indent, max_line_length)
|
||||
param_str += f"\n{desc_indent}{wrapped_desc}"
|
||||
else:
|
||||
param_str += f"\n{desc_indent}TODO: Add description."
|
||||
|
||||
formatted_params.append(param_str)
|
||||
|
||||
return "\n".join(formatted_params)
|
||||
return "\n\n".join(formatted_params)
|
||||
|
||||
|
||||
def format_input_params(input_params, indent_level=4, max_line_length=115):
|
||||
@@ -806,7 +582,7 @@ def format_components(components, indent_level=4, max_line_length=115, add_empty
|
||||
loading_field_values = []
|
||||
for field_name in component.loading_fields():
|
||||
field_value = getattr(component, field_name)
|
||||
if field_value:
|
||||
if field_value is not None:
|
||||
loading_field_values.append(f"{field_name}={field_value}")
|
||||
|
||||
# Add loading field information if available
|
||||
@@ -893,17 +669,17 @@ def make_doc_string(
|
||||
# Add description
|
||||
if description:
|
||||
desc_lines = description.strip().split("\n")
|
||||
aligned_desc = "\n".join(" " + line.rstrip() for line in desc_lines)
|
||||
aligned_desc = "\n".join(" " + line for line in desc_lines)
|
||||
output += aligned_desc + "\n\n"
|
||||
|
||||
# Add components section if provided
|
||||
if expected_components and len(expected_components) > 0:
|
||||
components_str = format_components(expected_components, indent_level=2, add_empty_lines=False)
|
||||
components_str = format_components(expected_components, indent_level=2)
|
||||
output += components_str + "\n\n"
|
||||
|
||||
# Add configs section if provided
|
||||
if expected_configs and len(expected_configs) > 0:
|
||||
configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
|
||||
configs_str = format_configs(expected_configs, indent_level=2)
|
||||
output += configs_str + "\n\n"
|
||||
|
||||
# Add inputs section
|
||||
|
||||
@@ -118,40 +118,7 @@ def get_timesteps(scheduler, num_inference_steps, strength):
|
||||
# ====================
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
|
||||
"""
|
||||
Prepare initial random noise for the generation process
|
||||
|
||||
Components:
|
||||
pachifier (`QwenImagePachifier`)
|
||||
|
||||
Inputs:
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
batch_size (`int`, *optional*, defaults to 1):
|
||||
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
|
||||
be generated in input step.
|
||||
dtype (`dtype`, *optional*, defaults to torch.float32):
|
||||
The dtype of the model inputs, can be generated in input step.
|
||||
|
||||
Outputs:
|
||||
height (`int`):
|
||||
if not set, updated to default value
|
||||
width (`int`):
|
||||
if not set, updated to default value
|
||||
latents (`Tensor`):
|
||||
The initial latents to use for the denoising process
|
||||
"""
|
||||
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
@@ -167,20 +134,28 @@ class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam.template("latents"),
|
||||
InputParam.template("height"),
|
||||
InputParam.template("width"),
|
||||
InputParam.template("num_images_per_prompt"),
|
||||
InputParam.template("generator"),
|
||||
InputParam.template("batch_size"),
|
||||
InputParam.template("dtype"),
|
||||
InputParam("latents"),
|
||||
InputParam(name="height"),
|
||||
InputParam(name="width"),
|
||||
InputParam(name="num_images_per_prompt", default=1),
|
||||
InputParam(name="generator"),
|
||||
InputParam(
|
||||
name="batch_size",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
|
||||
),
|
||||
InputParam(
|
||||
name="dtype",
|
||||
required=True,
|
||||
type_hint=torch.dtype,
|
||||
description="The dtype of the model inputs, can be generated in input step.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(name="height", type_hint=int, description="if not set, updated to default value"),
|
||||
OutputParam(name="width", type_hint=int, description="if not set, updated to default value"),
|
||||
OutputParam(
|
||||
name="latents",
|
||||
type_hint=torch.Tensor,
|
||||
@@ -234,42 +209,7 @@ class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class QwenImageLayeredPrepareLatentsStep(ModularPipelineBlocks):
|
||||
"""
|
||||
Prepare initial random noise (B, layers+1, C, H, W) for the generation process
|
||||
|
||||
Components:
|
||||
pachifier (`QwenImageLayeredPachifier`)
|
||||
|
||||
Inputs:
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image.
|
||||
layers (`int`, *optional*, defaults to 4):
|
||||
Number of layers to extract from the image
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
batch_size (`int`, *optional*, defaults to 1):
|
||||
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
|
||||
be generated in input step.
|
||||
dtype (`dtype`, *optional*, defaults to torch.float32):
|
||||
The dtype of the model inputs, can be generated in input step.
|
||||
|
||||
Outputs:
|
||||
height (`int`):
|
||||
if not set, updated to default value
|
||||
width (`int`):
|
||||
if not set, updated to default value
|
||||
latents (`Tensor`):
|
||||
The initial latents to use for the denoising process
|
||||
"""
|
||||
|
||||
model_name = "qwenimage-layered"
|
||||
|
||||
@property
|
||||
@@ -285,21 +225,29 @@ class QwenImageLayeredPrepareLatentsStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam.template("latents"),
|
||||
InputParam.template("height"),
|
||||
InputParam.template("width"),
|
||||
InputParam.template("layers"),
|
||||
InputParam.template("num_images_per_prompt"),
|
||||
InputParam.template("generator"),
|
||||
InputParam.template("batch_size"),
|
||||
InputParam.template("dtype"),
|
||||
InputParam("latents"),
|
||||
InputParam(name="height"),
|
||||
InputParam(name="width"),
|
||||
InputParam(name="layers", default=4),
|
||||
InputParam(name="num_images_per_prompt", default=1),
|
||||
InputParam(name="generator"),
|
||||
InputParam(
|
||||
name="batch_size",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
|
||||
),
|
||||
InputParam(
|
||||
name="dtype",
|
||||
required=True,
|
||||
type_hint=torch.dtype,
|
||||
description="The dtype of the model inputs, can be generated in input step.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(name="height", type_hint=int, description="if not set, updated to default value"),
|
||||
OutputParam(name="width", type_hint=int, description="if not set, updated to default value"),
|
||||
OutputParam(
|
||||
name="latents",
|
||||
type_hint=torch.Tensor,
|
||||
@@ -353,31 +301,7 @@ class QwenImageLayeredPrepareLatentsStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks):
|
||||
"""
|
||||
Step that adds noise to image latents for image-to-image/inpainting. Should be run after set_timesteps,
|
||||
prepare_latents. Both noise and image latents should alreadybe patchified.
|
||||
|
||||
Components:
|
||||
scheduler (`FlowMatchEulerDiscreteScheduler`)
|
||||
|
||||
Inputs:
|
||||
latents (`Tensor`):
|
||||
The initial random noised, can be generated in prepare latent step.
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step. (Can be
|
||||
generated from vae encoder and updated in input step.)
|
||||
timesteps (`Tensor`):
|
||||
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
|
||||
|
||||
Outputs:
|
||||
initial_noise (`Tensor`):
|
||||
The initial random noised used for inpainting denoising.
|
||||
latents (`Tensor`):
|
||||
The scaled noisy latents to use for inpainting/image-to-image denoising.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
@@ -399,7 +323,12 @@ class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks):
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial random noised, can be generated in prepare latent step.",
|
||||
),
|
||||
InputParam.template("image_latents", note="Can be generated from vae encoder and updated in input step."),
|
||||
InputParam(
|
||||
name="image_latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The image latents to use for the denoising process. Can be generated in vae encoder and packed in input step.",
|
||||
),
|
||||
InputParam(
|
||||
name="timesteps",
|
||||
required=True,
|
||||
@@ -416,11 +345,6 @@ class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks):
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial random noised used for inpainting denoising.",
|
||||
),
|
||||
OutputParam(
|
||||
name="latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="The scaled noisy latents to use for inpainting/image-to-image denoising.",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
@@ -458,29 +382,7 @@ class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks):
|
||||
"""
|
||||
Step that creates mask latents from preprocessed mask_image by interpolating to latent space.
|
||||
|
||||
Components:
|
||||
pachifier (`QwenImagePachifier`)
|
||||
|
||||
Inputs:
|
||||
processed_mask_image (`Tensor`):
|
||||
The processed mask to use for the inpainting process.
|
||||
height (`int`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`):
|
||||
The width in pixels of the generated image.
|
||||
dtype (`dtype`, *optional*, defaults to torch.float32):
|
||||
The dtype of the model inputs, can be generated in input step.
|
||||
|
||||
Outputs:
|
||||
mask (`Tensor`):
|
||||
The mask to use for the inpainting process.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
@@ -502,9 +404,9 @@ class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks):
|
||||
type_hint=torch.Tensor,
|
||||
description="The processed mask to use for the inpainting process.",
|
||||
),
|
||||
InputParam.template("height", required=True),
|
||||
InputParam.template("width", required=True),
|
||||
InputParam.template("dtype"),
|
||||
InputParam(name="height", required=True),
|
||||
InputParam(name="width", required=True),
|
||||
InputParam(name="dtype", required=True),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -548,27 +450,7 @@ class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks):
|
||||
# ====================
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class QwenImageSetTimestepsStep(ModularPipelineBlocks):
|
||||
"""
|
||||
Step that sets the scheduler's timesteps for text-to-image generation. Should be run after prepare latents step.
|
||||
|
||||
Components:
|
||||
scheduler (`FlowMatchEulerDiscreteScheduler`)
|
||||
|
||||
Inputs:
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
latents (`Tensor`):
|
||||
The initial random noised latents for the denoising process. Can be generated in prepare latents step.
|
||||
|
||||
Outputs:
|
||||
timesteps (`Tensor`):
|
||||
The timesteps to use for the denoising process
|
||||
"""
|
||||
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
@@ -584,13 +466,13 @@ class QwenImageSetTimestepsStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam.template("num_inference_steps"),
|
||||
InputParam.template("sigmas"),
|
||||
InputParam(name="num_inference_steps", default=50),
|
||||
InputParam(name="sigmas"),
|
||||
InputParam(
|
||||
name="latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial random noised latents for the denoising process. Can be generated in prepare latents step.",
|
||||
description="The latents to use for the denoising process, used to calculate the image sequence length.",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -634,27 +516,7 @@ class QwenImageSetTimestepsStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class QwenImageLayeredSetTimestepsStep(ModularPipelineBlocks):
|
||||
"""
|
||||
Set timesteps step for QwenImage Layered with custom mu calculation based on image_latents.
|
||||
|
||||
Components:
|
||||
scheduler (`FlowMatchEulerDiscreteScheduler`)
|
||||
|
||||
Inputs:
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
|
||||
Outputs:
|
||||
timesteps (`Tensor`):
|
||||
The timesteps to use for the denoising process.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage-layered"
|
||||
|
||||
@property
|
||||
@@ -670,17 +532,15 @@ class QwenImageLayeredSetTimestepsStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam.template("num_inference_steps"),
|
||||
InputParam.template("sigmas"),
|
||||
InputParam.template("image_latents"),
|
||||
InputParam("num_inference_steps", default=50, type_hint=int),
|
||||
InputParam("sigmas", type_hint=List[float]),
|
||||
InputParam("image_latents", required=True, type_hint=torch.Tensor),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="timesteps", type_hint=torch.Tensor, description="The timesteps to use for the denoising process."
|
||||
),
|
||||
OutputParam(name="timesteps", type_hint=torch.Tensor),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -714,32 +574,7 @@ class QwenImageLayeredSetTimestepsStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
|
||||
"""
|
||||
Step that sets the scheduler's timesteps for image-to-image generation, and inpainting. Should be run after prepare
|
||||
latents step.
|
||||
|
||||
Components:
|
||||
scheduler (`FlowMatchEulerDiscreteScheduler`)
|
||||
|
||||
Inputs:
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
latents (`Tensor`):
|
||||
The latents to use for the denoising process. Can be generated in prepare latents step.
|
||||
strength (`float`, *optional*, defaults to 0.9):
|
||||
Strength for img2img/inpainting.
|
||||
|
||||
Outputs:
|
||||
timesteps (`Tensor`):
|
||||
The timesteps to use for the denoising process.
|
||||
num_inference_steps (`int`):
|
||||
The number of denoising steps to perform at inference time. Updated based on strength.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
@@ -755,15 +590,15 @@ class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam.template("num_inference_steps"),
|
||||
InputParam.template("sigmas"),
|
||||
InputParam(name="num_inference_steps", default=50),
|
||||
InputParam(name="sigmas"),
|
||||
InputParam(
|
||||
"latents",
|
||||
name="latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents to use for the denoising process. Can be generated in prepare latents step.",
|
||||
description="The latents to use for the denoising process, used to calculate the image sequence length.",
|
||||
),
|
||||
InputParam.template("strength", default=0.9),
|
||||
InputParam(name="strength", default=0.9),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -772,12 +607,7 @@ class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
|
||||
OutputParam(
|
||||
name="timesteps",
|
||||
type_hint=torch.Tensor,
|
||||
description="The timesteps to use for the denoising process.",
|
||||
),
|
||||
OutputParam(
|
||||
name="num_inference_steps",
|
||||
type_hint=int,
|
||||
description="The number of denoising steps to perform at inference time. Updated based on strength.",
|
||||
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -824,29 +654,7 @@ class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
|
||||
## RoPE inputs for denoiser
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class QwenImageRoPEInputsStep(ModularPipelineBlocks):
|
||||
"""
|
||||
Step that prepares the RoPE inputs for the denoising process. Should be place after prepare_latents step
|
||||
|
||||
Inputs:
|
||||
batch_size (`int`, *optional*, defaults to 1):
|
||||
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
|
||||
be generated in input step.
|
||||
height (`int`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`):
|
||||
The width in pixels of the generated image.
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
mask for the text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds_mask (`Tensor`, *optional*):
|
||||
mask for the negative text embeddings. Can be generated from text_encoder step.
|
||||
|
||||
Outputs:
|
||||
img_shapes (`List`):
|
||||
The shapes of the images latents, used for RoPE calculation
|
||||
"""
|
||||
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
@@ -858,11 +666,11 @@ class QwenImageRoPEInputsStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam.template("batch_size"),
|
||||
InputParam.template("height", required=True),
|
||||
InputParam.template("width", required=True),
|
||||
InputParam.template("prompt_embeds_mask"),
|
||||
InputParam.template("negative_prompt_embeds_mask"),
|
||||
InputParam(name="batch_size", required=True),
|
||||
InputParam(name="height", required=True),
|
||||
InputParam(name="width", required=True),
|
||||
InputParam(name="prompt_embeds_mask"),
|
||||
InputParam(name="negative_prompt_embeds_mask"),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -894,34 +702,7 @@ class QwenImageRoPEInputsStep(ModularPipelineBlocks):
|
||||
return components, state


# auto_docstring
class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
"""
Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit. Should be placed after
prepare_latents step

Inputs:
batch_size (`int`, *optional*, defaults to 1):
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
be generated in input step.
image_height (`int`):
The height of the reference image. Can be generated in input step.
image_width (`int`):
The width of the reference image. Can be generated in input step.
height (`int`):
The height in pixels of the generated image.
width (`int`):
The width in pixels of the generated image.
prompt_embeds_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.

Outputs:
img_shapes (`List`):
The shapes of the image latents, used for RoPE calculation
"""

model_name = "qwenimage"

@property
@@ -931,23 +712,13 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam.template("batch_size"),
InputParam(
name="image_height",
required=True,
type_hint=int,
description="The height of the reference image. Can be generated in input step.",
),
InputParam(
name="image_width",
required=True,
type_hint=int,
description="The width of the reference image. Can be generated in input step.",
),
InputParam.template("height", required=True),
InputParam.template("width", required=True),
InputParam.template("prompt_embeds_mask"),
InputParam.template("negative_prompt_embeds_mask"),
InputParam(name="batch_size", required=True),
InputParam(name="image_height", required=True),
InputParam(name="image_width", required=True),
InputParam(name="height", required=True),
InputParam(name="width", required=True),
InputParam(name="prompt_embeds_mask"),
InputParam(name="negative_prompt_embeds_mask"),
]

@property
@@ -985,39 +756,7 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
return components, state


# auto_docstring
class QwenImageEditPlusRoPEInputsStep(ModularPipelineBlocks):
"""
Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit Plus.
Unlike Edit, Edit Plus handles lists of image_height/image_width for multiple reference images. Should be placed
after prepare_latents step.

Inputs:
batch_size (`int`, *optional*, defaults to 1):
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
be generated in input step.
image_height (`List`):
The heights of the reference images. Can be generated in input step.
image_width (`List`):
The widths of the reference images. Can be generated in input step.
height (`int`):
The height in pixels of the generated image.
width (`int`):
The width in pixels of the generated image.
prompt_embeds_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.

Outputs:
img_shapes (`List`):
The shapes of the image latents, used for RoPE calculation
txt_seq_lens (`List`):
The sequence lengths of the prompt embeds, used for RoPE calculation
negative_txt_seq_lens (`List`):
The sequence lengths of the negative prompt embeds, used for RoPE calculation
"""

model_name = "qwenimage-edit-plus"

@property
@@ -1031,23 +770,13 @@ class QwenImageEditPlusRoPEInputsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam.template("batch_size"),
InputParam(
name="image_height",
required=True,
type_hint=List[int],
description="The heights of the reference images. Can be generated in input step.",
),
InputParam(
name="image_width",
required=True,
type_hint=List[int],
description="The widths of the reference images. Can be generated in input step.",
),
InputParam.template("height", required=True),
InputParam.template("width", required=True),
InputParam.template("prompt_embeds_mask"),
InputParam.template("negative_prompt_embeds_mask"),
InputParam(name="batch_size", required=True),
InputParam(name="image_height", required=True, type_hint=List[int]),
InputParam(name="image_width", required=True, type_hint=List[int]),
InputParam(name="height", required=True),
InputParam(name="width", required=True),
InputParam(name="prompt_embeds_mask"),
InputParam(name="negative_prompt_embeds_mask"),
]

@property
@@ -1103,37 +832,7 @@ class QwenImageEditPlusRoPEInputsStep(ModularPipelineBlocks):
return components, state


# auto_docstring
class QwenImageLayeredRoPEInputsStep(ModularPipelineBlocks):
"""
Step that prepares the RoPE inputs for the denoising process. Should be placed after prepare_latents step

Inputs:
batch_size (`int`, *optional*, defaults to 1):
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
be generated in input step.
layers (`int`, *optional*, defaults to 4):
Number of layers to extract from the image
height (`int`):
The height in pixels of the generated image.
width (`int`):
The width in pixels of the generated image.
prompt_embeds_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.

Outputs:
img_shapes (`List`):
The shapes of the image latents, used for RoPE calculation
txt_seq_lens (`List`):
The sequence lengths of the prompt embeds, used for RoPE calculation
negative_txt_seq_lens (`List`):
The sequence lengths of the negative prompt embeds, used for RoPE calculation
additional_t_cond (`Tensor`):
The additional t cond, used for RoPE calculation
"""

model_name = "qwenimage-layered"

@property
@@ -1145,12 +844,12 @@ class QwenImageLayeredRoPEInputsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam.template("batch_size"),
InputParam.template("layers"),
InputParam.template("height", required=True),
InputParam.template("width", required=True),
InputParam.template("prompt_embeds_mask"),
InputParam.template("negative_prompt_embeds_mask"),
InputParam(name="batch_size", required=True),
InputParam(name="layers", required=True),
InputParam(name="height", required=True),
InputParam(name="width", required=True),
InputParam(name="prompt_embeds_mask"),
InputParam(name="negative_prompt_embeds_mask"),
]

@property
@@ -1215,34 +914,7 @@ class QwenImageLayeredRoPEInputsStep(ModularPipelineBlocks):


## ControlNet inputs for denoiser


# auto_docstring
class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
"""
Step that prepares inputs for controlnet. Insert before the Denoise Step, after set_timesteps step.

Components:
controlnet (`QwenImageControlNetModel`)

Inputs:
control_guidance_start (`float`, *optional*, defaults to 0.0):
When to start applying ControlNet.
control_guidance_end (`float`, *optional*, defaults to 1.0):
When to stop applying ControlNet.
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
Scale for ControlNet conditioning.
control_image_latents (`Tensor`):
The control image latents to use for the denoising process. Can be generated in controlnet vae encoder
step.
timesteps (`Tensor`):
The timesteps to use for the denoising process. Can be generated in set_timesteps step.

Outputs:
controlnet_keep (`List`):
The controlnet keep values
"""
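
# --- Illustrative sketch (not part of the diff): how `controlnet_keep` is typically derived
# from `control_guidance_start`/`control_guidance_end` in diffusers-style pipelines,
# simplified to a single ControlNet. The block above may differ in detail.
def make_controlnet_keep(num_inference_steps: int, start: float = 0.0, end: float = 1.0):
    keep = []
    for i in range(num_inference_steps):
        # 1.0 while the current step falls inside the [start, end) window, else 0.0
        outside = (i / num_inference_steps) < start or ((i + 1) / num_inference_steps) > end
        keep.append(0.0 if outside else 1.0)
    return keep

# e.g. make_controlnet_keep(4, start=0.0, end=0.5) -> [1.0, 1.0, 0.0, 0.0]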

model_name = "qwenimage"

@property
@@ -1258,17 +930,12 @@ class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam.template("control_guidance_start"),
InputParam.template("control_guidance_end"),
InputParam.template("controlnet_conditioning_scale"),
InputParam("control_guidance_start", default=0.0),
InputParam("control_guidance_end", default=1.0),
InputParam("controlnet_conditioning_scale", default=1.0),
InputParam("control_image_latents", required=True),
InputParam(
name="control_image_latents",
required=True,
type_hint=torch.Tensor,
description="The control image latents to use for the denoising process. Can be generated in controlnet vae encoder step.",
),
InputParam(
name="timesteps",
"timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",

@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List
from typing import List, Union

import numpy as np
import PIL
import torch

from ...configuration_utils import FrozenDict
@@ -29,30 +31,7 @@ logger = logging.get_logger(__name__)


# after denoising loop (unpack latents)


# auto_docstring
class QwenImageAfterDenoiseStep(ModularPipelineBlocks):
"""
Step that unpacks the latents from 3D tensor (batch_size, sequence_length, channels) into 5D tensor (batch_size,
channels, 1, height, width)

Components:
pachifier (`QwenImagePachifier`)

Inputs:
height (`int`):
The height in pixels of the generated image.
width (`int`):
The width in pixels of the generated image.
latents (`Tensor`):
The latents to decode, can be generated in the denoise step.

Outputs:
latents (`Tensor`):
The denoised latents unpacked to B, C, 1, H, W
"""
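
# --- Illustrative sketch (not part of the diff): the "unpack" this step performs, turning
# packed (B, seq, C*4) latents back into (B, C, 1, H, W). It mirrors the usual 2x2 packing
# layout; treat it as an approximation of the pachifier's unpack, not its exact code.
import torch

def unpack_latents(latents: torch.Tensor, latent_height: int, latent_width: int) -> torch.Tensor:
    b, _, packed_c = latents.shape
    c = packed_c // 4  # 2x2 packing folds four spatial positions into the channel dim
    x = latents.view(b, latent_height // 2, latent_width // 2, c, 2, 2)
    x = x.permute(0, 3, 1, 4, 2, 5)  # (B, C, H/2, 2, W/2, 2)
    return x.reshape(b, c, 1, latent_height, latent_width)

# e.g. unpack_latents(torch.randn(1, 64 * 64, 64), 128, 128).shape == (1, 16, 1, 128, 128)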

model_name = "qwenimage"

@property
@@ -70,21 +49,13 @@ class QwenImageAfterDenoiseStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam.template("height", required=True),
InputParam.template("width", required=True),
InputParam(name="height", required=True),
InputParam(name="width", required=True),
InputParam(
name="latents",
required=True,
type_hint=torch.Tensor,
description="The latents to decode, can be generated in the denoise step.",
),
]

@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name="latents", type_hint=torch.Tensor, description="The denoised latents unpacked to B, C, 1, H, W"
description="The latents to decode, can be generated in the denoise step",
),
]

@@ -101,29 +72,7 @@ class QwenImageAfterDenoiseStep(ModularPipelineBlocks):
return components, state


# auto_docstring
class QwenImageLayeredAfterDenoiseStep(ModularPipelineBlocks):
"""
Unpack latents from (B, seq, C*4) to (B, C, layers+1, H, W) after denoising.

Components:
pachifier (`QwenImageLayeredPachifier`)

Inputs:
latents (`Tensor`):
The denoised latents to decode, can be generated in the denoise step.
height (`int`):
The height in pixels of the generated image.
width (`int`):
The width in pixels of the generated image.
layers (`int`, *optional*, defaults to 4):
Number of layers to extract from the image

Outputs:
latents (`Tensor`):
Denoised latents. (unpacked to B, C, layers+1, H, W)
"""

model_name = "qwenimage-layered"

@property
@@ -139,21 +88,10 @@ class QwenImageLayeredAfterDenoiseStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
name="latents",
required=True,
type_hint=torch.Tensor,
description="The denoised latents to decode, can be generated in the denoise step.",
),
InputParam.template("height", required=True),
InputParam.template("width", required=True),
InputParam.template("layers"),
]

@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam.template("latents", note="unpacked to B, C, layers+1, H, W"),
InputParam("latents", required=True, type_hint=torch.Tensor),
InputParam("height", required=True, type_hint=int),
InputParam("width", required=True, type_hint=int),
InputParam("layers", required=True, type_hint=int),
]

@torch.no_grad()
@@ -174,26 +112,7 @@ class QwenImageLayeredAfterDenoiseStep(ModularPipelineBlocks):


# decode step


# auto_docstring
class QwenImageDecoderStep(ModularPipelineBlocks):
"""
Step that decodes the latents to images

Components:
vae (`AutoencoderKLQwenImage`)

Inputs:
latents (`Tensor`):
The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise
step.

Outputs:
images (`List`):
Generated images. (tensor output of the vae decoder.)
"""
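
# --- Illustrative sketch (not part of the diff): the gist of the decode step. The real block
# also handles AutoencoderKLQwenImage's latent normalization and dtype casting; this is only
# a hedged outline of "latents in, image tensor out" using the standard diffusers VAE API.
import torch

@torch.no_grad()
def decode_latents(vae, latents: torch.Tensor) -> torch.Tensor:
    latents = latents.to(vae.dtype)  # any model-specific scaling/shift would happen before this call
    return vae.decode(latents, return_dict=False)[0]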

model_name = "qwenimage"

@property
@@ -215,13 +134,19 @@ class QwenImageDecoderStep(ModularPipelineBlocks):
name="latents",
required=True,
type_hint=torch.Tensor,
description="The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise step.",
description="The latents to decode, can be generated in the denoise step",
),
]

@property
def intermediate_outputs(self) -> List[OutputParam]:
return [OutputParam.template("images", note="tensor output of the vae decoder.")]
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(
"images",
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array",
)
]

@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
@@ -251,26 +176,7 @@ class QwenImageDecoderStep(ModularPipelineBlocks):
return components, state


# auto_docstring
class QwenImageLayeredDecoderStep(ModularPipelineBlocks):
"""
Decode unpacked latents (B, C, layers+1, H, W) into layer images.

Components:
vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`)

Inputs:
latents (`Tensor`):
The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise
step.
output_type (`str`, *optional*, defaults to pil):
Output format: 'pil', 'np', 'pt'.

Outputs:
images (`List`):
Generated images.
"""

model_name = "qwenimage-layered"

@property
@@ -292,19 +198,14 @@ class QwenImageLayeredDecoderStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
name="latents",
required=True,
type_hint=torch.Tensor,
description="The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise step.",
),
InputParam.template("output_type"),
InputParam("latents", required=True, type_hint=torch.Tensor),
InputParam("output_type", default="pil", type_hint=str),
]

@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam.template("images"),
OutputParam(name="images", type_hint=List[List[PIL.Image.Image]]),
]

@torch.no_grad()
@@ -350,27 +251,7 @@ class QwenImageLayeredDecoderStep(ModularPipelineBlocks):


# postprocess the decoded images


# auto_docstring
class QwenImageProcessImagesOutputStep(ModularPipelineBlocks):
"""
postprocess the generated image

Components:
image_processor (`VaeImageProcessor`)

Inputs:
images (`Tensor`):
the generated image tensor from decoders step
output_type (`str`, *optional*, defaults to pil):
Output format: 'pil', 'np', 'pt'.

Outputs:
images (`List`):
Generated images.
"""
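
# --- Illustrative sketch (not part of the diff): what the postprocess step boils down to.
# `VaeImageProcessor.postprocess` is the diffusers API this block wraps; the validation
# mirrors its `check_inputs`.
def postprocess_images(image_processor, images, output_type: str = "pil"):
    if output_type not in ["pil", "np", "pt"]:
        raise ValueError(f"output_type must be one of 'pil', 'np', 'pt', got {output_type}")
    return image_processor.postprocess(images, output_type=output_type)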

model_name = "qwenimage"

@property
@@ -391,19 +272,15 @@ class QwenImageProcessImagesOutputStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("images", required=True, description="the generated image from decoders step"),
InputParam(
name="images",
required=True,
type_hint=torch.Tensor,
description="the generated image tensor from decoders step",
name="output_type",
default="pil",
type_hint=str,
description="The type of the output images, can be 'pil', 'np', 'pt'",
),
InputParam.template("output_type"),
]

@property
def intermediate_outputs(self) -> List[OutputParam]:
return [OutputParam.template("images")]

@staticmethod
def check_inputs(output_type):
if output_type not in ["pil", "np", "pt"]:
@@ -424,28 +301,7 @@ class QwenImageProcessImagesOutputStep(ModularPipelineBlocks):
return components, state


# auto_docstring
class QwenImageInpaintProcessImagesOutputStep(ModularPipelineBlocks):
"""
postprocess the generated image, optionally apply the mask overlay to the original image.

Components:
image_mask_processor (`InpaintProcessor`)

Inputs:
images (`Tensor`):
the generated image tensor from decoders step
output_type (`str`, *optional*, defaults to pil):
Output format: 'pil', 'np', 'pt'.
mask_overlay_kwargs (`Dict`, *optional*):
The kwargs for the postprocess step to apply the mask overlay. generated in
InpaintProcessImagesInputStep.

Outputs:
images (`List`):
Generated images.
"""

model_name = "qwenimage"

@property
@@ -466,24 +322,16 @@ class QwenImageInpaintProcessImagesOutputStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("images", required=True, description="the generated image from decoders step"),
InputParam(
name="images",
required=True,
type_hint=torch.Tensor,
description="the generated image tensor from decoders step",
),
InputParam.template("output_type"),
InputParam(
name="mask_overlay_kwargs",
type_hint=Dict[str, Any],
description="The kwargs for the postprocess step to apply the mask overlay. generated in InpaintProcessImagesInputStep.",
name="output_type",
default="pil",
type_hint=str,
description="The type of the output images, can be 'pil', 'np', 'pt'",
),
InputParam("mask_overlay_kwargs"),
]

@property
def intermediate_outputs(self) -> List[OutputParam]:
return [OutputParam.template("images")]

@staticmethod
def check_inputs(output_type, mask_overlay_kwargs):
if output_type not in ["pil", "np", "pt"]:

@@ -50,7 +50,7 @@ class QwenImageLoopBeforeDenoiser(ModularPipelineBlocks):
def inputs(self) -> List[InputParam]:
return [
InputParam(
name="latents",
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
@@ -80,12 +80,17 @@ class QwenImageEditLoopBeforeDenoiser(ModularPipelineBlocks):
def inputs(self) -> List[InputParam]:
return [
InputParam(
name="latents",
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
),
InputParam.template("image_latents"),
InputParam(
"image_latents",
required=True,
type_hint=torch.Tensor,
description="The initial image latents to use for the denoising process. Can be encoded in vae_encoder step and packed in prepare_image_latents step.",
),
]

@torch.no_grad()
@@ -129,12 +134,29 @@ class QwenImageLoopBeforeDenoiserControlNet(ModularPipelineBlocks):
type_hint=torch.Tensor,
description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
),
InputParam.template("controlnet_conditioning_scale", note="updated in prepare_controlnet_inputs step."),
InputParam(
name="controlnet_keep",
"controlnet_conditioning_scale",
type_hint=float,
description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
),
InputParam(
"controlnet_keep",
required=True,
type_hint=List[float],
description="The controlnet keep values. Can be generated in prepare_controlnet_inputs step.",
description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
kwargs_type="denoiser_input_fields",
description=(
"All conditional model inputs for the denoiser. "
"It should contain prompt_embeds/negative_prompt_embeds."
),
),
]

@@ -195,13 +217,28 @@ class QwenImageLoopDenoiser(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam.template("attention_kwargs"),
InputParam.template("denoiser_input_fields"),
InputParam("attention_kwargs"),
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The latents to use for the denoising process. Can be generated in prepare_latents step.",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
kwargs_type="denoiser_input_fields",
description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
),
InputParam(
"img_shapes",
required=True,
type_hint=List[Tuple[int, int]],
description="The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step.",
description="The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.",
),
]

@@ -280,8 +317,23 @@ class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam.template("attention_kwargs"),
InputParam.template("denoiser_input_fields"),
InputParam("attention_kwargs"),
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The latents to use for the denoising process. Can be generated in prepare_latents step.",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
kwargs_type="denoiser_input_fields",
description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
),
InputParam(
"img_shapes",
required=True,
@@ -363,7 +415,7 @@ class QwenImageLoopAfterDenoiser(ModularPipelineBlocks):
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam.template("latents"),
OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents."),
]

@torch.no_grad()
@@ -404,19 +456,24 @@ class QwenImageLoopAfterDenoiserInpaint(ModularPipelineBlocks):
type_hint=torch.Tensor,
description="The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.",
),
InputParam.template("image_latents"),
InputParam(
"image_latents",
required=True,
type_hint=torch.Tensor,
description="The image latents to use for the inpainting process. Can be generated in inpaint prepare latents step.",
),
InputParam(
"initial_noise",
required=True,
type_hint=torch.Tensor,
description="The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step.",
),
]

@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam.template("latents"),
InputParam(
"timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
),
]

@torch.no_grad()
@@ -458,12 +515,17 @@ class QwenImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
def loop_inputs(self) -> List[InputParam]:
return [
InputParam(
name="timesteps",
"timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam.template("num_inference_steps", required=True),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
]

@torch.no_grad()
@@ -495,42 +557,7 @@ class QwenImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks):


# Qwen Image (text2image, image2image)


# auto_docstring
class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper):
"""
Denoise step that iteratively denoises the latents.
Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method. At each iteration, it runs blocks
defined in `sub_blocks` sequentially:
- `QwenImageLoopBeforeDenoiser`
- `QwenImageLoopDenoiser`
- `QwenImageLoopAfterDenoiser`
This block supports text2image and image2image tasks for QwenImage.

Components:
guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler
(`FlowMatchEulerDiscreteScheduler`)

Inputs:
timesteps (`Tensor`):
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
num_inference_steps (`int`):
The number of denoising steps.
latents (`Tensor`):
The initial latents to use for the denoising process. Can be generated in prepare_latent step.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
img_shapes (`List`):
The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.

Outputs:
latents (`Tensor`):
Denoised latents.
"""
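
# --- Illustrative sketch (not part of the diff): the control flow the docstring above
# describes. The block/state objects and the per-step call signature are placeholders; the
# real `QwenImageDenoiseLoopWrapper.__call__` also manages the progress bar and guidance.
def run_denoise_loop(sub_blocks, components, block_state, timesteps):
    for i, t in enumerate(timesteps):
        # each iteration runs before-denoiser -> denoiser -> after-denoiser in order
        for block in sub_blocks:
            components, block_state = block(components, block_state, i, t)  # hypothetical signature
    return components, block_state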

model_name = "qwenimage"

block_classes = [
@@ -543,8 +570,8 @@ class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper):
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents.\n"
"Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method\n"
"Denoise step that iteratively denoise the latents. \n"
"Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
" - `QwenImageLoopBeforeDenoiser`\n"
" - `QwenImageLoopDenoiser`\n"
@@ -554,47 +581,7 @@ class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper):


# Qwen Image (inpainting)
# auto_docstring
class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
"""
Denoise step that iteratively denoises the latents.
Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method. At each iteration, it runs blocks
defined in `sub_blocks` sequentially:
- `QwenImageLoopBeforeDenoiser`
- `QwenImageLoopDenoiser`
- `QwenImageLoopAfterDenoiser`
- `QwenImageLoopAfterDenoiserInpaint`
This block supports inpainting tasks for QwenImage.

Components:
guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler
(`FlowMatchEulerDiscreteScheduler`)

Inputs:
timesteps (`Tensor`):
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
num_inference_steps (`int`):
The number of denoising steps.
latents (`Tensor`):
The initial latents to use for the denoising process. Can be generated in prepare_latent step.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
img_shapes (`List`):
The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.
mask (`Tensor`):
The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
initial_noise (`Tensor`):
The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step.

Outputs:
latents (`Tensor`):
Denoised latents.
"""

model_name = "qwenimage"
block_classes = [
QwenImageLoopBeforeDenoiser,
@@ -619,47 +606,7 @@ class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):


# Qwen Image (text2image, image2image) with controlnet
# auto_docstring
class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
"""
Denoise step that iteratively denoises the latents.
Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method. At each iteration, it runs blocks
defined in `sub_blocks` sequentially:
- `QwenImageLoopBeforeDenoiser`
- `QwenImageLoopBeforeDenoiserControlNet`
- `QwenImageLoopDenoiser`
- `QwenImageLoopAfterDenoiser`
This block supports text2img/img2img tasks with controlnet for QwenImage.

Components:
guider (`ClassifierFreeGuidance`) controlnet (`QwenImageControlNetModel`) transformer
(`QwenImageTransformer2DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`)

Inputs:
timesteps (`Tensor`):
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
num_inference_steps (`int`):
The number of denoising steps.
latents (`Tensor`):
The initial latents to use for the denoising process. Can be generated in prepare_latent step.
control_image_latents (`Tensor`):
The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
Scale for ControlNet conditioning. (updated in prepare_controlnet_inputs step.)
controlnet_keep (`List`):
The controlnet keep values. Can be generated in prepare_controlnet_inputs step.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
img_shapes (`List`):
The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.

Outputs:
latents (`Tensor`):
Denoised latents.
"""

model_name = "qwenimage"
block_classes = [
QwenImageLoopBeforeDenoiser,
@@ -684,54 +631,7 @@ class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):


# Qwen Image (inpainting) with controlnet
# auto_docstring
class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
"""
Denoise step that iteratively denoises the latents.
Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method. At each iteration, it runs blocks
defined in `sub_blocks` sequentially:
- `QwenImageLoopBeforeDenoiser`
- `QwenImageLoopBeforeDenoiserControlNet`
- `QwenImageLoopDenoiser`
- `QwenImageLoopAfterDenoiser`
- `QwenImageLoopAfterDenoiserInpaint`
This block supports inpainting tasks with controlnet for QwenImage.

Components:
guider (`ClassifierFreeGuidance`) controlnet (`QwenImageControlNetModel`) transformer
(`QwenImageTransformer2DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`)

Inputs:
timesteps (`Tensor`):
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
num_inference_steps (`int`):
The number of denoising steps.
latents (`Tensor`):
The initial latents to use for the denoising process. Can be generated in prepare_latent step.
control_image_latents (`Tensor`):
The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
Scale for ControlNet conditioning. (updated in prepare_controlnet_inputs step.)
controlnet_keep (`List`):
The controlnet keep values. Can be generated in prepare_controlnet_inputs step.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
img_shapes (`List`):
The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.
mask (`Tensor`):
The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
initial_noise (`Tensor`):
The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step.

Outputs:
latents (`Tensor`):
Denoised latents.
"""

model_name = "qwenimage"
block_classes = [
QwenImageLoopBeforeDenoiser,
@@ -764,42 +664,7 @@ class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):


# Qwen Image Edit (image2image)
# auto_docstring
class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper):
"""
Denoise step that iteratively denoises the latents.
Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method. At each iteration, it runs blocks
defined in `sub_blocks` sequentially:
- `QwenImageEditLoopBeforeDenoiser`
- `QwenImageEditLoopDenoiser`
- `QwenImageLoopAfterDenoiser`
This block supports QwenImage Edit.

Components:
guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler
(`FlowMatchEulerDiscreteScheduler`)

Inputs:
timesteps (`Tensor`):
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
num_inference_steps (`int`):
The number of denoising steps.
latents (`Tensor`):
The initial latents to use for the denoising process. Can be generated in prepare_latent step.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
img_shapes (`List`):
The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.

Outputs:
latents (`Tensor`):
Denoised latents.
"""

model_name = "qwenimage-edit"
block_classes = [
QwenImageEditLoopBeforeDenoiser,
@@ -822,47 +687,7 @@ class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper):


# Qwen Image Edit (inpainting)
# auto_docstring
class QwenImageEditInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
"""
Denoise step that iteratively denoises the latents.
Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method. At each iteration, it runs blocks
defined in `sub_blocks` sequentially:
- `QwenImageEditLoopBeforeDenoiser`
- `QwenImageEditLoopDenoiser`
- `QwenImageLoopAfterDenoiser`
- `QwenImageLoopAfterDenoiserInpaint`
This block supports inpainting tasks for QwenImage Edit.

Components:
guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler
(`FlowMatchEulerDiscreteScheduler`)

Inputs:
timesteps (`Tensor`):
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
num_inference_steps (`int`):
The number of denoising steps.
latents (`Tensor`):
The initial latents to use for the denoising process. Can be generated in prepare_latent step.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
img_shapes (`List`):
The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.
mask (`Tensor`):
The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.
initial_noise (`Tensor`):
The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step.

Outputs:
latents (`Tensor`):
Denoised latents.
"""

model_name = "qwenimage-edit"
block_classes = [
QwenImageEditLoopBeforeDenoiser,
@@ -887,42 +712,7 @@ class QwenImageEditInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):


# Qwen Image Layered (image2image)
# auto_docstring
class QwenImageLayeredDenoiseStep(QwenImageDenoiseLoopWrapper):
"""
Denoise step that iteratively denoises the latents.
Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method. At each iteration, it runs blocks
defined in `sub_blocks` sequentially:
- `QwenImageEditLoopBeforeDenoiser`
- `QwenImageEditLoopDenoiser`
- `QwenImageLoopAfterDenoiser`
This block supports QwenImage Layered.

Components:
guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler
(`FlowMatchEulerDiscreteScheduler`)

Inputs:
timesteps (`Tensor`):
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
num_inference_steps (`int`):
The number of denoising steps.
latents (`Tensor`):
The initial latents to use for the denoising process. Can be generated in prepare_latent step.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
img_shapes (`List`):
The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.

Outputs:
latents (`Tensor`):
Denoised latents.
"""

model_name = "qwenimage-layered"
block_classes = [
QwenImageEditLoopBeforeDenoiser,
File diff suppressed because it is too large
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple
from typing import List, Tuple

import torch

@@ -109,44 +109,7 @@ def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor: in
return height, width


# auto_docstring
class QwenImageTextInputsStep(ModularPipelineBlocks):
"""
Text input processing step that standardizes text embeddings for the pipeline.
This step:
1. Determines `batch_size` and `dtype` based on `prompt_embeds`
2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)

This block should be placed after all encoder steps to process the text embeddings before they are used in
subsequent pipeline steps.

Inputs:
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
prompt_embeds (`Tensor`):
text embeddings used to guide the image generation. Can be generated from text_encoder step.
prompt_embeds_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`Tensor`, *optional*):
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.

Outputs:
batch_size (`int`):
The batch size of the prompt embeddings
dtype (`dtype`):
The data type of the prompt embeddings
prompt_embeds (`Tensor`):
The prompt embeddings. (batch-expanded)
prompt_embeds_mask (`Tensor`):
The encoder attention mask. (batch-expanded)
negative_prompt_embeds (`Tensor`):
The negative prompt embeddings. (batch-expanded)
negative_prompt_embeds_mask (`Tensor`):
The negative prompt embeddings mask. (batch-expanded)
"""
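
# --- Illustrative sketch (not part of the diff): the batch expansion this step standardizes.
# Each (batch_size, seq_len, dim) embedding is repeated so that every prompt yields
# `num_images_per_prompt` samples; a hedged stand-in for the block's helper, not its code.
import torch

def expand_for_num_images(prompt_embeds: torch.Tensor, num_images_per_prompt: int) -> torch.Tensor:
    # final batch size becomes batch_size * num_images_per_prompt
    return prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)

# e.g. expand_for_num_images(torch.randn(2, 77, 3584), 3).shape[0] == 6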

model_name = "qwenimage"

@property
@@ -166,22 +129,26 @@ class QwenImageTextInputsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam.template("num_images_per_prompt"),
InputParam.template("prompt_embeds"),
InputParam.template("prompt_embeds_mask"),
InputParam.template("negative_prompt_embeds"),
InputParam.template("negative_prompt_embeds_mask"),
InputParam(name="num_images_per_prompt", default=1),
InputParam(name="prompt_embeds", required=True, kwargs_type="denoiser_input_fields"),
InputParam(name="prompt_embeds_mask", required=True, kwargs_type="denoiser_input_fields"),
InputParam(name="negative_prompt_embeds", kwargs_type="denoiser_input_fields"),
InputParam(name="negative_prompt_embeds_mask", kwargs_type="denoiser_input_fields"),
]

@property
def intermediate_outputs(self) -> List[OutputParam]:
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(name="batch_size", type_hint=int, description="The batch size of the prompt embeddings"),
OutputParam(name="dtype", type_hint=torch.dtype, description="The data type of the prompt embeddings"),
OutputParam.template("prompt_embeds", note="batch-expanded"),
OutputParam.template("prompt_embeds_mask", note="batch-expanded"),
OutputParam.template("negative_prompt_embeds", note="batch-expanded"),
OutputParam.template("negative_prompt_embeds_mask", note="batch-expanded"),
OutputParam(
"batch_size",
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
),
OutputParam(
"dtype",
type_hint=torch.dtype,
description="Data type of model tensor inputs (determined by `prompt_embeds`)",
),
]

@staticmethod
@@ -254,76 +221,20 @@ class QwenImageTextInputsStep(ModularPipelineBlocks):
return components, state


# auto_docstring
class QwenImageAdditionalInputsStep(ModularPipelineBlocks):
"""
Input processing step that:
1. For image latent inputs: Updates height/width if None, patchifies, and expands batch size
2. For additional batch inputs: Expands batch dimensions to match final batch size

Configured inputs:
- Image latent inputs: ['image_latents']

This block should be placed after the encoder steps and the text input step.

Components:
pachifier (`QwenImagePachifier`)

Inputs:
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
batch_size (`int`, *optional*, defaults to 1):
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
be generated in input step.
height (`int`, *optional*):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.

Outputs:
image_height (`int`):
The image height calculated from the image latents dimension
image_width (`int`):
The image width calculated from the image latents dimension
height (`int`):
if not provided, updated to image height
width (`int`):
if not provided, updated to image width
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and
batch-expanded)
"""
"""Input step for QwenImage: update height/width, expand batch, patchify."""

model_name = "qwenimage"

def __init__(
self,
image_latent_inputs: Optional[List[InputParam]] = None,
additional_batch_inputs: Optional[List[InputParam]] = None,
image_latent_inputs: List[str] = ["image_latents"],
additional_batch_inputs: List[str] = [],
):
# by default, process `image_latents`
if image_latent_inputs is None:
image_latent_inputs = [InputParam.template("image_latents")]
if additional_batch_inputs is None:
additional_batch_inputs = []

if not isinstance(image_latent_inputs, list):
raise ValueError(f"image_latent_inputs must be a list, but got {type(image_latent_inputs)}")
else:
for input_param in image_latent_inputs:
if not isinstance(input_param, InputParam):
raise ValueError(f"image_latent_inputs must be a list of InputParam, but got {type(input_param)}")

image_latent_inputs = [image_latent_inputs]
if not isinstance(additional_batch_inputs, list):
raise ValueError(f"additional_batch_inputs must be a list, but got {type(additional_batch_inputs)}")
else:
for input_param in additional_batch_inputs:
if not isinstance(input_param, InputParam):
raise ValueError(
f"additional_batch_inputs must be a list of InputParam, but got {type(input_param)}"
)
additional_batch_inputs = [additional_batch_inputs]

self._image_latent_inputs = image_latent_inputs
self._additional_batch_inputs = additional_batch_inputs
@@ -341,9 +252,9 @@ class QwenImageAdditionalInputsStep(ModularPipelineBlocks):
if self._image_latent_inputs or self._additional_batch_inputs:
inputs_info = "\n\nConfigured inputs:"
if self._image_latent_inputs:
inputs_info += f"\n - Image latent inputs: {[p.name for p in self._image_latent_inputs]}"
inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
if self._additional_batch_inputs:
inputs_info += f"\n - Additional batch inputs: {[p.name for p in self._additional_batch_inputs]}"
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"

placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."

@@ -358,19 +269,23 @@ class QwenImageAdditionalInputsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
inputs = [
InputParam.template("num_images_per_prompt"),
InputParam.template("batch_size"),
InputParam.template("height"),
InputParam.template("width"),
InputParam(name="num_images_per_prompt", default=1),
InputParam(name="batch_size", required=True),
InputParam(name="height"),
InputParam(name="width"),
]
# default is `image_latents`
inputs += self._image_latent_inputs + self._additional_batch_inputs

for image_latent_input_name in self._image_latent_inputs:
inputs.append(InputParam(name=image_latent_input_name))

for input_name in self._additional_batch_inputs:
inputs.append(InputParam(name=input_name))

return inputs

@property
def intermediate_outputs(self) -> List[OutputParam]:
outputs = [
return [
OutputParam(
name="image_height",
type_hint=int,
@@ -383,43 +298,11 @@ class QwenImageAdditionalInputsStep(ModularPipelineBlocks):
),
]

# `height`/`width` are not new outputs, but they will be updated if any image latent inputs are provided
if len(self._image_latent_inputs) > 0:
outputs.append(
OutputParam(name="height", type_hint=int, description="if not provided, updated to image height")
)
outputs.append(
OutputParam(name="width", type_hint=int, description="if not provided, updated to image width")
)

# image latent inputs are modified in place (patchified and batch-expanded)
for input_param in self._image_latent_inputs:
outputs.append(
OutputParam(
name=input_param.name,
type_hint=input_param.type_hint,
description=input_param.description + " (patchified and batch-expanded)",
)
)

# additional batch inputs (batch-expanded only)
for input_param in self._additional_batch_inputs:
outputs.append(
OutputParam(
name=input_param.name,
type_hint=input_param.type_hint,
description=input_param.description + " (batch-expanded)",
)
)

return outputs

def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)

# Process image latent inputs
for input_param in self._image_latent_inputs:
image_latent_input_name = input_param.name
for image_latent_input_name in self._image_latent_inputs:
image_latent_tensor = getattr(block_state, image_latent_input_name)
if image_latent_tensor is None:
continue
@@ -448,8 +331,7 @@ class QwenImageAdditionalInputsStep(ModularPipelineBlocks):
setattr(block_state, image_latent_input_name, image_latent_tensor)

# Process additional batch inputs (only batch expansion)
for input_param in self._additional_batch_inputs:
input_name = input_param.name
for input_name in self._additional_batch_inputs:
input_tensor = getattr(block_state, input_name)
if input_tensor is None:
continue
@@ -467,76 +349,20 @@ class QwenImageAdditionalInputsStep(ModularPipelineBlocks):
return components, state
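
# --- Illustrative sketch (not part of the diff): the two transformations this input step
# applies to image latents: a 2x2 patchify from (B, C, 1, H, W) into (B, seq, C*4), then
# batch expansion to batch_size * num_images_per_prompt. An approximation of the pachifier
# plus the expansion loop above, not the block's code.
import torch

def pack_latents(latents: torch.Tensor) -> torch.Tensor:
    b, c, _, h, w = latents.shape  # (B, C, 1, H, W)
    x = latents.view(b, c, h // 2, 2, w // 2, 2)
    x = x.permute(0, 2, 4, 1, 3, 5)  # (B, H/2, W/2, C, 2, 2)
    return x.reshape(b, (h // 2) * (w // 2), c * 4)

def expand_batch(tensor: torch.Tensor, final_batch_size: int) -> torch.Tensor:
    repeats = final_batch_size // tensor.shape[0]
    return tensor.repeat_interleave(repeats, dim=0)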


# auto_docstring
class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks):
"""
Input processing step for Edit Plus that:
1. For image latent inputs (list): Collects heights/widths, patchifies each, concatenates, expands batch
2. For additional batch inputs: Expands batch dimensions to match final batch size
Height/width defaults to last image in the list.

Configured inputs:
- Image latent inputs: ['image_latents']

This block should be placed after the encoder steps and the text input step.

Components:
pachifier (`QwenImagePachifier`)

Inputs:
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
batch_size (`int`, *optional*, defaults to 1):
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
be generated in input step.
height (`int`, *optional*):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.

Outputs:
image_height (`List`):
The image heights calculated from the image latents dimension
image_width (`List`):
The image widths calculated from the image latents dimension
height (`int`):
if not provided, updated to image height
width (`int`):
if not provided, updated to image width
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified,
concatenated, and batch-expanded)
"""
"""Input step for QwenImage Edit Plus: handles list of latents with different sizes."""

model_name = "qwenimage-edit-plus"

def __init__(
self,
image_latent_inputs: Optional[List[InputParam]] = None,
additional_batch_inputs: Optional[List[InputParam]] = None,
image_latent_inputs: List[str] = ["image_latents"],
additional_batch_inputs: List[str] = [],
):
if image_latent_inputs is None:
image_latent_inputs = [InputParam.template("image_latents")]
if additional_batch_inputs is None:
additional_batch_inputs = []

if not isinstance(image_latent_inputs, list):
raise ValueError(f"image_latent_inputs must be a list, but got {type(image_latent_inputs)}")
else:
for input_param in image_latent_inputs:
if not isinstance(input_param, InputParam):
raise ValueError(f"image_latent_inputs must be a list of InputParam, but got {type(input_param)}")

image_latent_inputs = [image_latent_inputs]
if not isinstance(additional_batch_inputs, list):
raise ValueError(f"additional_batch_inputs must be a list, but got {type(additional_batch_inputs)}")
else:
for input_param in additional_batch_inputs:
if not isinstance(input_param, InputParam):
raise ValueError(
f"additional_batch_inputs must be a list of InputParam, but got {type(input_param)}"
)
additional_batch_inputs = [additional_batch_inputs]

self._image_latent_inputs = image_latent_inputs
self._additional_batch_inputs = additional_batch_inputs
@@ -555,9 +381,9 @@ class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks):
if self._image_latent_inputs or self._additional_batch_inputs:
inputs_info = "\n\nConfigured inputs:"
if self._image_latent_inputs:
inputs_info += f"\n - Image latent inputs: {[p.name for p in self._image_latent_inputs]}"
inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
if self._additional_batch_inputs:
inputs_info += f"\n - Additional batch inputs: {[p.name for p in self._additional_batch_inputs]}"
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"

placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."

@@ -572,20 +398,23 @@ class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
inputs = [
InputParam.template("num_images_per_prompt"),
InputParam.template("batch_size"),
InputParam.template("height"),
InputParam.template("width"),
InputParam(name="num_images_per_prompt", default=1),
InputParam(name="batch_size", required=True),
InputParam(name="height"),
InputParam(name="width"),
]

# default is `image_latents`
inputs += self._image_latent_inputs + self._additional_batch_inputs
for image_latent_input_name in self._image_latent_inputs:
inputs.append(InputParam(name=image_latent_input_name))

for input_name in self._additional_batch_inputs:
inputs.append(InputParam(name=input_name))

return inputs

@property
def intermediate_outputs(self) -> List[OutputParam]:
outputs = [
return [
OutputParam(
name="image_height",
type_hint=List[int],
@@ -598,43 +427,11 @@ class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks):
),
]

# `height`/`width` are updated if any image latent inputs are provided
if len(self._image_latent_inputs) > 0:
outputs.append(
OutputParam(name="height", type_hint=int, description="if not provided, updated to image height")
)
outputs.append(
OutputParam(name="width", type_hint=int, description="if not provided, updated to image width")
)

# image latent inputs are modified in place (patchified, concatenated, and batch-expanded)
|
||||
for input_param in self._image_latent_inputs:
|
||||
outputs.append(
|
||||
OutputParam(
|
||||
name=input_param.name,
|
||||
type_hint=input_param.type_hint,
|
||||
description=input_param.description + " (patchified, concatenated, and batch-expanded)",
|
||||
)
|
||||
)
|
||||
|
||||
# additional batch inputs (batch-expanded only)
|
||||
for input_param in self._additional_batch_inputs:
|
||||
outputs.append(
|
||||
OutputParam(
|
||||
name=input_param.name,
|
||||
type_hint=input_param.type_hint,
|
||||
description=input_param.description + " (batch-expanded)",
|
||||
)
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# Process image latent inputs
|
||||
for input_param in self._image_latent_inputs:
|
||||
image_latent_input_name = input_param.name
|
||||
for image_latent_input_name in self._image_latent_inputs:
|
||||
image_latent_tensor = getattr(block_state, image_latent_input_name)
|
||||
if image_latent_tensor is None:
|
||||
continue
|
||||
@@ -679,8 +476,7 @@ class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks):
|
||||
setattr(block_state, image_latent_input_name, packed_image_latent_tensors)
|
||||
|
||||
# Process additional batch inputs (only batch expansion)
|
||||
for input_param in self._additional_batch_inputs:
|
||||
input_name = input_param.name
|
||||
for input_name in self._additional_batch_inputs:
|
||||
input_tensor = getattr(block_state, input_name)
|
||||
if input_tensor is None:
|
||||
continue
|
||||
@@ -698,75 +494,22 @@ class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks):
|
||||
return components, state
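
# --- Illustrative sketch (annotation, not part of the original file) ---
# For Edit Plus, `image_latents` is a *list* of latents that may have different spatial
# sizes: the step records each latent's height/width, patchifies every latent,
# concatenates them along the sequence dimension, and defaults `height`/`width` to the
# last image in the list. A rough standalone outline, assuming a `patchify` callable
# (like the pachifier component) and a VAE spatial scale factor of 8 (both assumptions):
def _pack_image_latents_sketch(image_latents_list, patchify, vae_scale_factor=8):
    image_heights, image_widths, packed = [], [], []
    for latents in image_latents_list:  # each latents: (batch, channels, latent_h, latent_w)
        image_heights.append(latents.shape[-2] * vae_scale_factor)
        image_widths.append(latents.shape[-1] * vae_scale_factor)
        packed.append(patchify(latents))  # -> (batch, seq_len_i, patched_channels)
    packed_image_latents = torch.cat(packed, dim=1)  # images contribute different sequence lengths
    height, width = image_heights[-1], image_widths[-1]  # output size defaults to the last image
    return packed_image_latents, image_heights, image_widths, height, width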
|
||||
|
||||
|
||||
# same as QwenImageAdditionalInputsStep, but with layered pachifier.
|
||||
|
||||
|
||||
# auto_docstring
|
||||
# YiYi TODO: support define config default component from the ModularPipeline level.
|
||||
# it is same as QwenImageAdditionalInputsStep, but with layered pachifier.
|
||||
class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks):
|
||||
"""
|
||||
Input processing step for Layered that:
|
||||
1. For image latent inputs: Updates height/width if None, patchifies with layered pachifier, and expands batch
|
||||
size
|
||||
2. For additional batch inputs: Expands batch dimensions to match final batch size
|
||||
|
||||
Configured inputs:
|
||||
- Image latent inputs: ['image_latents']
|
||||
|
||||
This block should be placed after the encoder steps and the text input step.
|
||||
|
||||
Components:
|
||||
pachifier (`QwenImageLayeredPachifier`)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
batch_size (`int`, *optional*, defaults to 1):
|
||||
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
|
||||
be generated in input step.
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
|
||||
Outputs:
|
||||
image_height (`int`):
|
||||
The image height calculated from the image latents dimension
|
||||
image_width (`int`):
|
||||
The image width calculated from the image latents dimension
|
||||
height (`int`):
|
||||
if not provided, updated to image height
|
||||
width (`int`):
|
||||
if not provided, updated to image width
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified
|
||||
with layered pachifier and batch-expanded)
|
||||
"""
|
||||
"""Input step for QwenImage Layered: update height/width, expand batch, patchify with layered pachifier."""
|
||||
|
||||
model_name = "qwenimage-layered"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_latent_inputs: Optional[List[InputParam]] = None,
|
||||
additional_batch_inputs: Optional[List[InputParam]] = None,
|
||||
image_latent_inputs: List[str] = ["image_latents"],
|
||||
additional_batch_inputs: List[str] = [],
|
||||
):
|
||||
if image_latent_inputs is None:
|
||||
image_latent_inputs = [InputParam.template("image_latents")]
|
||||
if additional_batch_inputs is None:
|
||||
additional_batch_inputs = []
|
||||
|
||||
if not isinstance(image_latent_inputs, list):
|
||||
raise ValueError(f"image_latent_inputs must be a list, but got {type(image_latent_inputs)}")
|
||||
else:
|
||||
for input_param in image_latent_inputs:
|
||||
if not isinstance(input_param, InputParam):
|
||||
raise ValueError(f"image_latent_inputs must be a list of InputParam, but got {type(input_param)}")
|
||||
|
||||
image_latent_inputs = [image_latent_inputs]
|
||||
if not isinstance(additional_batch_inputs, list):
|
||||
raise ValueError(f"additional_batch_inputs must be a list, but got {type(additional_batch_inputs)}")
|
||||
else:
|
||||
for input_param in additional_batch_inputs:
|
||||
if not isinstance(input_param, InputParam):
|
||||
raise ValueError(
|
||||
f"additional_batch_inputs must be a list of InputParam, but got {type(input_param)}"
|
||||
)
|
||||
additional_batch_inputs = [additional_batch_inputs]
|
||||
|
||||
self._image_latent_inputs = image_latent_inputs
|
||||
self._additional_batch_inputs = additional_batch_inputs
|
||||
@@ -784,9 +527,9 @@ class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks):
|
||||
if self._image_latent_inputs or self._additional_batch_inputs:
|
||||
inputs_info = "\n\nConfigured inputs:"
|
||||
if self._image_latent_inputs:
|
||||
inputs_info += f"\n - Image latent inputs: {[p.name for p in self._image_latent_inputs]}"
|
||||
inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
|
||||
if self._additional_batch_inputs:
|
||||
inputs_info += f"\n - Additional batch inputs: {[p.name for p in self._additional_batch_inputs]}"
|
||||
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
|
||||
|
||||
placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
|
||||
|
||||
@@ -801,18 +544,21 @@ class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
inputs = [
|
||||
InputParam.template("num_images_per_prompt"),
|
||||
InputParam.template("batch_size"),
|
||||
InputParam(name="num_images_per_prompt", default=1),
|
||||
InputParam(name="batch_size", required=True),
|
||||
]
|
||||
# default is `image_latents`
|
||||
|
||||
inputs += self._image_latent_inputs + self._additional_batch_inputs
|
||||
for image_latent_input_name in self._image_latent_inputs:
|
||||
inputs.append(InputParam(name=image_latent_input_name))
|
||||
|
||||
for input_name in self._additional_batch_inputs:
|
||||
inputs.append(InputParam(name=input_name))
|
||||
|
||||
return inputs
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
outputs = [
|
||||
return [
|
||||
OutputParam(
|
||||
name="image_height",
|
||||
type_hint=int,
|
||||
@@ -823,44 +569,15 @@ class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks):
|
||||
type_hint=int,
|
||||
description="The image width calculated from the image latents dimension",
|
||||
),
|
||||
OutputParam(name="height", type_hint=int, description="The height of the image output"),
|
||||
OutputParam(name="width", type_hint=int, description="The width of the image output"),
|
||||
]
|
||||
|
||||
if len(self._image_latent_inputs) > 0:
|
||||
outputs.append(
|
||||
OutputParam(name="height", type_hint=int, description="if not provided, updated to image height")
|
||||
)
|
||||
outputs.append(
|
||||
OutputParam(name="width", type_hint=int, description="if not provided, updated to image width")
|
||||
)
|
||||
|
||||
# Add outputs for image latent inputs (patchified with layered pachifier and batch-expanded)
|
||||
for input_param in self._image_latent_inputs:
|
||||
outputs.append(
|
||||
OutputParam(
|
||||
name=input_param.name,
|
||||
type_hint=input_param.type_hint,
|
||||
description=input_param.description + " (patchified with layered pachifier and batch-expanded)",
|
||||
)
|
||||
)
|
||||
|
||||
# Add outputs for additional batch inputs (batch-expanded only)
|
||||
for input_param in self._additional_batch_inputs:
|
||||
outputs.append(
|
||||
OutputParam(
|
||||
name=input_param.name,
|
||||
type_hint=input_param.type_hint,
|
||||
description=input_param.description + " (batch-expanded)",
|
||||
)
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# Process image latent inputs
|
||||
for input_param in self._image_latent_inputs:
|
||||
image_latent_input_name = input_param.name
|
||||
for image_latent_input_name in self._image_latent_inputs:
|
||||
image_latent_tensor = getattr(block_state, image_latent_input_name)
|
||||
if image_latent_tensor is None:
|
||||
continue
|
||||
@@ -891,8 +608,7 @@ class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks):
|
||||
setattr(block_state, image_latent_input_name, image_latent_tensor)
|
||||
|
||||
# Process additional batch inputs (only batch expansion)
|
||||
for input_param in self._additional_batch_inputs:
|
||||
input_name = input_param.name
|
||||
for input_name in self._additional_batch_inputs:
|
||||
input_tensor = getattr(block_state, input_name)
|
||||
if input_tensor is None:
|
||||
continue
|
||||
@@ -910,34 +626,7 @@ class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks):
|
||||
return components, state
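
# --- Illustrative usage (annotation, not part of the original file) ---
# After this change the inputs steps are configured with plain input *names* rather than
# `InputParam` objects; the values below are simply the constructor defaults shown above.
_layered_inputs_step_example = QwenImageLayeredAdditionalInputsStep(
    image_latent_inputs=["image_latents"],
    additional_batch_inputs=[],
)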
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class QwenImageControlNetInputsStep(ModularPipelineBlocks):
|
||||
"""
|
||||
Prepare the `control_image_latents` for controlnet. Insert after all the other input steps.
|
||||
|
||||
Inputs:
|
||||
control_image_latents (`Tensor`):
|
||||
The control image latents to use for the denoising process. Can be generated in controlnet vae encoder
|
||||
step.
|
||||
batch_size (`int`, *optional*, defaults to 1):
|
||||
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
|
||||
be generated in input step.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image.
|
||||
|
||||
Outputs:
|
||||
control_image_latents (`Tensor`):
|
||||
The control image latents (patchified and batch-expanded).
|
||||
height (`int`):
|
||||
if not provided, updated to control image height
|
||||
width (`int`):
|
||||
if not provided, updated to control image width
|
||||
"""
|
||||
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
@@ -947,28 +636,11 @@ class QwenImageControlNetInputsStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
name="control_image_latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The control image latents to use for the denoising process. Can be generated in controlnet vae encoder step.",
|
||||
),
|
||||
InputParam.template("batch_size"),
|
||||
InputParam.template("num_images_per_prompt"),
|
||||
InputParam.template("height"),
|
||||
InputParam.template("width"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="control_image_latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="The control image latents (patchified and batch-expanded).",
|
||||
),
|
||||
OutputParam(name="height", type_hint=int, description="if not provided, updated to control image height"),
|
||||
OutputParam(name="width", type_hint=int, description="if not provided, updated to control image width"),
|
||||
InputParam(name="control_image_latents", required=True),
|
||||
InputParam(name="batch_size", required=True),
|
||||
InputParam(name="num_images_per_prompt", default=1),
|
||||
InputParam(name="height"),
|
||||
InputParam(name="width"),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -12,11 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import PIL.Image
import torch

from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InputParam, InsertableDict, OutputParam
from ..modular_pipeline_utils import InsertableDict, OutputParam
from .before_denoise import (
QwenImageControlNetBeforeDenoiserStep,
QwenImageCreateMaskLatentsStep,
|
||||
@@ -56,91 +59,11 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# ====================
|
||||
# 1. TEXT ENCODER
|
||||
# 1. VAE ENCODER
|
||||
# ====================
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class QwenImageAutoTextEncoderStep(AutoPipelineBlocks):
|
||||
"""
|
||||
Text encoder step that encodes the text prompt into a text embedding. This is an auto pipeline block.
|
||||
|
||||
Components:
|
||||
text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use tokenizer (`Qwen2Tokenizer`):
|
||||
The tokenizer to use guider (`ClassifierFreeGuidance`)
|
||||
|
||||
Inputs:
|
||||
prompt (`str`, *optional*):
|
||||
The prompt or prompts to guide image generation.
|
||||
negative_prompt (`str`, *optional*):
|
||||
The prompt or prompts not to guide the image generation.
|
||||
max_sequence_length (`int`, *optional*, defaults to 1024):
|
||||
Maximum sequence length for prompt encoding.
|
||||
|
||||
Outputs:
|
||||
prompt_embeds (`Tensor`):
|
||||
The prompt embeddings.
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
The encoder attention mask.
|
||||
negative_prompt_embeds (`Tensor`):
|
||||
The negative prompt embeddings.
|
||||
negative_prompt_embeds_mask (`Tensor`):
|
||||
The negative prompt embeddings mask.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage"
|
||||
block_classes = [QwenImageTextEncoderStep()]
|
||||
block_names = ["text_encoder"]
|
||||
block_trigger_inputs = ["prompt"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Text encoder step that encodes the text prompt into a text embedding. This is an auto pipeline block."
|
||||
" - `QwenImageTextEncoderStep` (text_encoder) is used when `prompt` is provided."
|
||||
" - if `prompt` is not provided, step will be skipped."
|
||||
|
||||
|
||||
# ====================
|
||||
# 2. VAE ENCODER
|
||||
# ====================
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
This step is used for processing image and mask inputs for inpainting tasks. It:
|
||||
- Resizes the image to the target size, based on `height` and `width`.
|
||||
- Processes and updates `image` and `mask_image`.
|
||||
- Creates `image_latents`.
|
||||
|
||||
Components:
|
||||
image_mask_processor (`InpaintProcessor`) vae (`AutoencoderKLQwenImage`)
|
||||
|
||||
Inputs:
|
||||
mask_image (`Image`):
|
||||
Mask image for inpainting.
|
||||
image (`Union[Image, List]`):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image.
|
||||
padding_mask_crop (`int`, *optional*):
|
||||
Padding for mask cropping in inpainting.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
|
||||
Outputs:
|
||||
processed_image (`Tensor`):
|
||||
The processed image
|
||||
processed_mask_image (`Tensor`):
|
||||
The processed mask image
|
||||
mask_overlay_kwargs (`Dict`):
|
||||
The kwargs for the postprocess step to apply the mask overlay
|
||||
image_latents (`Tensor`):
|
||||
The latent representation of the input image.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage"
|
||||
block_classes = [QwenImageInpaintProcessImagesInputStep(), QwenImageVaeEncoderStep()]
|
||||
block_names = ["preprocess", "encode"]
|
||||
@@ -155,31 +78,7 @@ class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks):
|
||||
)
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
VAE encoder step that preprocesses and encodes the image inputs into their latent representations.
|
||||
|
||||
Components:
|
||||
image_processor (`VaeImageProcessor`) vae (`AutoencoderKLQwenImage`)
|
||||
|
||||
Inputs:
|
||||
image (`Union[Image, List]`):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
|
||||
Outputs:
|
||||
processed_image (`Tensor`):
|
||||
The processed image
|
||||
image_latents (`Tensor`):
|
||||
The latent representation of the input image.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage"
|
||||
|
||||
block_classes = [QwenImageProcessImagesInputStep(), QwenImageVaeEncoderStep()]
|
||||
@@ -190,6 +89,7 @@ class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
|
||||
return "Vae encoder step that preprocess andencode the image inputs into their latent representations."
|
||||
|
||||
|
||||
# Auto VAE encoder
|
||||
class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageInpaintVaeEncoderStep, QwenImageImg2ImgVaeEncoderStep]
|
||||
block_names = ["inpaint", "img2img"]
|
||||
@@ -207,33 +107,7 @@ class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
|
||||
|
||||
# optional controlnet vae encoder
|
||||
# auto_docstring
|
||||
class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks):
|
||||
"""
|
||||
VAE encoder step that encodes the image inputs into their latent representations.
|
||||
This is an auto pipeline block.
|
||||
- `QwenImageControlNetVaeEncoderStep` (controlnet) is used when `control_image` is provided.
|
||||
- if `control_image` is not provided, step will be skipped.
|
||||
|
||||
Components:
|
||||
vae (`AutoencoderKLQwenImage`) controlnet (`QwenImageControlNetModel`) control_image_processor
|
||||
(`VaeImageProcessor`)
|
||||
|
||||
Inputs:
|
||||
control_image (`Image`, *optional*):
|
||||
Control image for ControlNet conditioning.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
|
||||
Outputs:
|
||||
control_image_latents (`Tensor`):
|
||||
The latents representing the control image
|
||||
"""
|
||||
|
||||
block_classes = [QwenImageControlNetVaeEncoderStep]
|
||||
block_names = ["controlnet"]
|
||||
block_trigger_inputs = ["control_image"]
|
||||
@@ -249,65 +123,14 @@ class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks):
|
||||
|
||||
|
||||
# ====================
|
||||
# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise)
|
||||
# 2. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise)
|
||||
# ====================
|
||||
|
||||
|
||||
# assemble input steps
|
||||
# auto_docstring
|
||||
class QwenImageImg2ImgInputStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Input step that prepares the inputs for the img2img denoising step.
|
||||
|
||||
Components:
|
||||
pachifier (`QwenImagePachifier`)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
mask for the text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds_mask (`Tensor`, *optional*):
|
||||
mask for the negative text embeddings. Can be generated from text_encoder step.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image.
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
|
||||
Outputs:
|
||||
batch_size (`int`):
|
||||
The batch size of the prompt embeddings
|
||||
dtype (`dtype`):
|
||||
The data type of the prompt embeddings
|
||||
prompt_embeds (`Tensor`):
|
||||
The prompt embeddings. (batch-expanded)
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
The encoder attention mask. (batch-expanded)
|
||||
negative_prompt_embeds (`Tensor`):
|
||||
The negative prompt embeddings. (batch-expanded)
|
||||
negative_prompt_embeds_mask (`Tensor`):
|
||||
The negative prompt embeddings mask. (batch-expanded)
|
||||
image_height (`int`):
|
||||
The image height calculated from the image latents dimension
|
||||
image_width (`int`):
|
||||
The image width calculated from the image latents dimension
|
||||
height (`int`):
|
||||
if not provided, updated to image height
|
||||
width (`int`):
|
||||
if not provided, updated to image width
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and
|
||||
batch-expanded)
|
||||
"""
|
||||
|
||||
model_name = "qwenimage"
|
||||
block_classes = [QwenImageTextInputsStep(), QwenImageAdditionalInputsStep()]
|
||||
block_classes = [QwenImageTextInputsStep(), QwenImageAdditionalInputsStep(image_latent_inputs=["image_latents"])]
|
||||
block_names = ["text_inputs", "additional_inputs"]
|
||||
|
||||
@property
|
||||
@@ -317,69 +140,12 @@ class QwenImageImg2ImgInputStep(SequentialPipelineBlocks):
|
||||
" - update height/width based `image_latents`, patchify `image_latents`."
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class QwenImageInpaintInputStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Input step that prepares the inputs for the inpainting denoising step.
|
||||
|
||||
Components:
|
||||
pachifier (`QwenImagePachifier`)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
mask for the text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds_mask (`Tensor`, *optional*):
|
||||
mask for the negative text embeddings. Can be generated from text_encoder step.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image.
|
||||
image_latents (`Tensor`, *optional*):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
processed_mask_image (`Tensor`, *optional*):
|
||||
The processed mask image
|
||||
|
||||
Outputs:
|
||||
batch_size (`int`):
|
||||
The batch size of the prompt embeddings
|
||||
dtype (`dtype`):
|
||||
The data type of the prompt embeddings
|
||||
prompt_embeds (`Tensor`):
|
||||
The prompt embeddings. (batch-expanded)
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
The encoder attention mask. (batch-expanded)
|
||||
negative_prompt_embeds (`Tensor`):
|
||||
The negative prompt embeddings. (batch-expanded)
|
||||
negative_prompt_embeds_mask (`Tensor`):
|
||||
The negative prompt embeddings mask. (batch-expanded)
|
||||
image_height (`int`):
|
||||
The image height calculated from the image latents dimension
|
||||
image_width (`int`):
|
||||
The image width calculated from the image latents dimension
|
||||
height (`int`):
|
||||
if not provided, updated to image height
|
||||
width (`int`):
|
||||
if not provided, updated to image width
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and
|
||||
batch-expanded)
|
||||
processed_mask_image (`Tensor`):
|
||||
The processed mask image (batch-expanded)
|
||||
"""
|
||||
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageTextInputsStep(),
|
||||
QwenImageAdditionalInputsStep(
|
||||
additional_batch_inputs=[
|
||||
InputParam(name="processed_mask_image", type_hint=torch.Tensor, description="The processed mask image")
|
||||
]
|
||||
image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
|
||||
),
|
||||
]
|
||||
block_names = ["text_inputs", "additional_inputs"]
|
||||
@@ -392,42 +158,7 @@ class QwenImageInpaintInputStep(SequentialPipelineBlocks):
|
||||
|
||||
|
||||
# assemble prepare latents steps
|
||||
# auto_docstring
|
||||
class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. It:
|
||||
- Adds noise to the image latents to create the latents input for the denoiser.
- Creates the patchified latents `mask` based on the processed mask image.
|
||||
|
||||
Components:
|
||||
scheduler (`FlowMatchEulerDiscreteScheduler`) pachifier (`QwenImagePachifier`)
|
||||
|
||||
Inputs:
|
||||
latents (`Tensor`):
|
||||
The initial random noise; can be generated in the prepare latents step.
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step. (Can be
|
||||
generated from vae encoder and updated in input step.)
|
||||
timesteps (`Tensor`):
|
||||
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
|
||||
processed_mask_image (`Tensor`):
|
||||
The processed mask to use for the inpainting process.
|
||||
height (`int`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`):
|
||||
The width in pixels of the generated image.
|
||||
dtype (`dtype`, *optional*, defaults to torch.float32):
|
||||
The dtype of the model inputs, can be generated in input step.
|
||||
|
||||
Outputs:
|
||||
initial_noise (`Tensor`):
|
||||
The initial random noise used for inpainting denoising.
|
||||
latents (`Tensor`):
|
||||
The scaled noisy latents to use for inpainting/image-to-image denoising.
|
||||
mask (`Tensor`):
|
||||
The mask to use for the inpainting process.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage"
|
||||
block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()]
|
||||
block_names = ["add_noise_to_latents", "create_mask_latents"]
|
||||
@@ -445,49 +176,7 @@ class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks):
|
||||
|
||||
|
||||
# Qwen Image (text2image)
|
||||
# auto_docstring
|
||||
class QwenImageCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Step that denoises noise into an image for the text2image task. It includes the denoise loop, as well as preparing the inputs
(timesteps, latents, RoPE inputs, etc.).
|
||||
|
||||
Components:
|
||||
pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider
|
||||
(`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
mask for the text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds_mask (`Tensor`, *optional*):
|
||||
mask for the negative text embeddings. Can be generated from text_encoder step.
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
Denoised latents.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageTextInputsStep(),
|
||||
@@ -510,63 +199,9 @@ class QwenImageCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
def description(self):
|
||||
return "step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.)."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam.template("latents"),
|
||||
]
|
||||
|
||||
|
||||
# Qwen Image (inpainting)
|
||||
# auto_docstring
|
||||
class QwenImageInpaintCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Before denoise step that prepares the inputs (timesteps, latents, RoPE inputs, etc.) for the denoise step for the inpaint
task.
|
||||
|
||||
Components:
|
||||
pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider
|
||||
(`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
mask for the text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds_mask (`Tensor`, *optional*):
|
||||
mask for the negative text embeddings. Can be generated from text_encoder step.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image.
|
||||
image_latents (`Tensor`, *optional*):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
processed_mask_image (`Tensor`, *optional*):
|
||||
The processed mask image
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
strength (`float`, *optional*, defaults to 0.9):
|
||||
Strength for img2img/inpainting.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
Denoised latents.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageInpaintInputStep(),
|
||||
@@ -591,61 +226,9 @@ class QwenImageInpaintCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam.template("latents"),
|
||||
]
|
||||
|
||||
|
||||
# Qwen Image (image2image)
|
||||
# auto_docstring
|
||||
class QwenImageImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Before denoise step that prepares the inputs (timesteps, latents, RoPE inputs, etc.) for the denoise step for the img2img
task.
|
||||
|
||||
Components:
|
||||
pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider
|
||||
(`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
mask for the text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds_mask (`Tensor`, *optional*):
|
||||
mask for the negative text embeddings. Can be generated from text_encoder step.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image.
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
strength (`float`, *optional*, defaults to 0.9):
|
||||
Strength for img2img/inpainting.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
Denoised latents.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageImg2ImgInputStep(),
|
||||
@@ -670,66 +253,9 @@ class QwenImageImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam.template("latents"),
|
||||
]
|
||||
|
||||
|
||||
# Qwen Image (text2image) with controlnet
|
||||
# auto_docstring
|
||||
class QwenImageControlNetCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Step that denoises noise into an image for the text2image task. It includes the denoise loop, as well as preparing the inputs
(timesteps, latents, RoPE inputs, etc.).
|
||||
|
||||
Components:
|
||||
pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) controlnet
|
||||
(`QwenImageControlNetModel`) guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
mask for the text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds_mask (`Tensor`, *optional*):
|
||||
mask for the negative text embeddings. Can be generated from text_encoder step.
|
||||
control_image_latents (`Tensor`):
|
||||
The control image latents to use for the denoising process. Can be generated in controlnet vae encoder
|
||||
step.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image.
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
control_guidance_start (`float`, *optional*, defaults to 0.0):
|
||||
When to start applying ControlNet.
|
||||
control_guidance_end (`float`, *optional*, defaults to 1.0):
|
||||
When to stop applying ControlNet.
|
||||
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
|
||||
Scale for ControlNet conditioning.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
Denoised latents.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageTextInputsStep(),
|
||||
@@ -756,72 +282,9 @@ class QwenImageControlNetCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
def description(self):
|
||||
return "step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.)."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam.template("latents"),
|
||||
]
|
||||
|
||||
|
||||
# Qwen Image (inpainting) with controlnet
|
||||
# auto_docstring
|
||||
class QwenImageControlNetInpaintCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Before denoise step that prepares the inputs (timesteps, latents, RoPE inputs, etc.) for the denoise step for the inpaint
task.
|
||||
|
||||
Components:
|
||||
pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) controlnet
|
||||
(`QwenImageControlNetModel`) guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
mask for the text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds_mask (`Tensor`, *optional*):
|
||||
mask for the negative text embeddings. Can be generated from text_encoder step.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image.
|
||||
image_latents (`Tensor`, *optional*):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
processed_mask_image (`Tensor`, *optional*):
|
||||
The processed mask image
|
||||
control_image_latents (`Tensor`):
|
||||
The control image latents to use for the denoising process. Can be generated in controlnet vae encoder
|
||||
step.
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
strength (`float`, *optional*, defaults to 0.9):
|
||||
Strength for img2img/inpainting.
|
||||
control_guidance_start (`float`, *optional*, defaults to 0.0):
|
||||
When to start applying ControlNet.
|
||||
control_guidance_end (`float`, *optional*, defaults to 1.0):
|
||||
When to stop applying ControlNet.
|
||||
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
|
||||
Scale for ControlNet conditioning.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
Denoised latents.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageInpaintInputStep(),
|
||||
@@ -850,70 +313,9 @@ class QwenImageControlNetInpaintCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam.template("latents"),
|
||||
]
|
||||
|
||||
|
||||
# Qwen Image (image2image) with controlnet
|
||||
# auto_docstring
|
||||
class QwenImageControlNetImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Before denoise step that prepares the inputs (timesteps, latents, RoPE inputs, etc.) for the denoise step for the img2img
task.
|
||||
|
||||
Components:
|
||||
pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) controlnet
|
||||
(`QwenImageControlNetModel`) guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
mask for the text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds_mask (`Tensor`, *optional*):
|
||||
mask for the negative text embeddings. Can be generated from text_encoder step.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image.
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
control_image_latents (`Tensor`):
|
||||
The control image latents to use for the denoising process. Can be generated in controlnet vae encoder
|
||||
step.
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
strength (`float`, *optional*, defaults to 0.9):
|
||||
Strength for img2img/inpainting.
|
||||
control_guidance_start (`float`, *optional*, defaults to 0.0):
|
||||
When to start applying ControlNet.
|
||||
control_guidance_end (`float`, *optional*, defaults to 1.0):
|
||||
When to stop applying ControlNet.
|
||||
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
|
||||
Scale for ControlNet conditioning.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
Denoised latents.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageImg2ImgInputStep(),
|
||||
@@ -942,12 +344,6 @@ class QwenImageControlNetImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam.template("latents"),
|
||||
]
|
||||
|
||||
|
||||
# Auto denoise step for QwenImage
|
||||
class QwenImageAutoCoreDenoiseStep(ConditionalPipelineBlocks):
|
||||
@@ -1006,36 +402,19 @@ class QwenImageAutoCoreDenoiseStep(ConditionalPipelineBlocks):
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam.template("latents"),
|
||||
OutputParam(
|
||||
name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step"
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# ====================
|
||||
# 4. DECODE
|
||||
# 3. DECODE
|
||||
# ====================
|
||||
|
||||
|
||||
# standard decode step works for most tasks except for inpaint
|
||||
# auto_docstring
|
||||
class QwenImageDecodeStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Decode step that decodes the latents to images and postprocesses the generated image.
|
||||
|
||||
Components:
|
||||
vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`)
|
||||
|
||||
Inputs:
|
||||
latents (`Tensor`):
|
||||
The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise
|
||||
step.
|
||||
output_type (`str`, *optional*, defaults to pil):
|
||||
Output format: 'pil', 'np', 'pt'.
|
||||
|
||||
Outputs:
|
||||
images (`List`):
|
||||
Generated images. (tensor output of the vae decoder.)
|
||||
"""
|
||||
|
||||
model_name = "qwenimage"
|
||||
block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
|
||||
block_names = ["decode", "postprocess"]
|
||||
@@ -1046,30 +425,7 @@ class QwenImageDecodeStep(SequentialPipelineBlocks):
|
||||
|
||||
|
||||
# Inpaint decode step
|
||||
# auto_docstring
|
||||
class QwenImageInpaintDecodeStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Decode step that decodes the latents to images and postprocesses the generated image, optionally applying the mask
overlay to the original image.
|
||||
|
||||
Components:
|
||||
vae (`AutoencoderKLQwenImage`) image_mask_processor (`InpaintProcessor`)
|
||||
|
||||
Inputs:
|
||||
latents (`Tensor`):
|
||||
The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise
|
||||
step.
|
||||
output_type (`str`, *optional*, defaults to pil):
|
||||
Output format: 'pil', 'np', 'pt'.
|
||||
mask_overlay_kwargs (`Dict`, *optional*):
|
||||
The kwargs for the postprocess step to apply the mask overlay. generated in
|
||||
InpaintProcessImagesInputStep.
|
||||
|
||||
Outputs:
|
||||
images (`List`):
|
||||
Generated images. (tensor output of the vae decoder.)
|
||||
"""
|
||||
|
||||
model_name = "qwenimage"
|
||||
block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()]
|
||||
block_names = ["decode", "postprocess"]
|
||||
@@ -1096,11 +452,11 @@ class QwenImageAutoDecodeStep(AutoPipelineBlocks):
|
||||
|
||||
|
||||
# ====================
|
||||
# 5. AUTO BLOCKS & PRESETS
|
||||
# 4. AUTO BLOCKS & PRESETS
|
||||
# ====================
|
||||
AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", QwenImageAutoTextEncoderStep()),
|
||||
("text_encoder", QwenImageTextEncoderStep()),
|
||||
("vae_encoder", QwenImageAutoVaeEncoderStep()),
|
||||
("controlnet_vae_encoder", QwenImageOptionalControlNetVaeEncoderStep()),
|
||||
("denoise", QwenImageAutoCoreDenoiseStep()),
|
||||
@@ -1109,89 +465,7 @@ AUTO_BLOCKS = InsertableDict(
|
||||
)
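
# --- Illustrative usage (annotation, not part of the original file) ---
# AUTO_BLOCKS is an ordered (name -> block) mapping, and QwenImageAutoBlocks below simply
# consumes AUTO_BLOCKS.values(). A trimmed preset can be assembled the same way; note the
# "decode" entry name and its QwenImageAutoDecodeStep pairing are assumed from the elided
# tail of the preset above:
TEXT2IMAGE_BLOCKS_SKETCH = InsertableDict(
    [
        ("text_encoder", QwenImageTextEncoderStep()),
        ("denoise", QwenImageAutoCoreDenoiseStep()),
        ("decode", QwenImageAutoDecodeStep()),
    ]
)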
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class QwenImageAutoBlocks(SequentialPipelineBlocks):
|
||||
"""
|
||||
Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.
|
||||
- for image-to-image generation, you need to provide `image`
|
||||
- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`.
- to run the controlnet workflow, you need to provide `control_image`
- for text-to-image generation, all you need to provide is `prompt`

Components:
text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use tokenizer (`Qwen2Tokenizer`):
The tokenizer to use guider (`ClassifierFreeGuidance`) image_mask_processor (`InpaintProcessor`) vae
(`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) controlnet (`QwenImageControlNetModel`)
control_image_processor (`VaeImageProcessor`) pachifier (`QwenImagePachifier`) scheduler
(`FlowMatchEulerDiscreteScheduler`) transformer (`QwenImageTransformer2DModel`)

Inputs:
prompt (`str`, *optional*):
    The prompt or prompts to guide image generation.
negative_prompt (`str`, *optional*):
    The prompt or prompts not to guide the image generation.
max_sequence_length (`int`, *optional*, defaults to 1024):
    Maximum sequence length for prompt encoding.
mask_image (`Image`, *optional*):
    Mask image for inpainting.
image (`Union[Image, List]`, *optional*):
    Reference image(s) for denoising. Can be a single image or list of images.
height (`int`, *optional*):
    The height in pixels of the generated image.
width (`int`, *optional*):
    The width in pixels of the generated image.
padding_mask_crop (`int`, *optional*):
    Padding for mask cropping in inpainting.
generator (`Generator`, *optional*):
    Torch generator for deterministic generation.
control_image (`Image`, *optional*):
    Control image for ControlNet conditioning.
num_images_per_prompt (`int`, *optional*, defaults to 1):
    The number of images to generate per prompt.
prompt_embeds (`Tensor`):
    text embeddings used to guide the image generation. Can be generated from text_encoder step.
prompt_embeds_mask (`Tensor`):
    mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`Tensor`, *optional*):
    negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
    mask for the negative text embeddings. Can be generated from text_encoder step.
latents (`Tensor`):
    Pre-generated noisy latents for image generation.
num_inference_steps (`int`):
    The number of denoising steps.
sigmas (`List`, *optional*):
    Custom sigmas for the denoising process.
attention_kwargs (`Dict`, *optional*):
    Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
    conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
image_latents (`Tensor`, *optional*):
    image latents used to guide the image generation. Can be generated from vae_encoder step.
processed_mask_image (`Tensor`, *optional*):
    The processed mask image
strength (`float`, *optional*, defaults to 0.9):
    Strength for img2img/inpainting.
control_image_latents (`Tensor`, *optional*):
    The control image latents to use for the denoising process. Can be generated in controlnet vae encoder
    step.
control_guidance_start (`float`, *optional*, defaults to 0.0):
    When to start applying ControlNet.
control_guidance_end (`float`, *optional*, defaults to 1.0):
    When to stop applying ControlNet.
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
    Scale for ControlNet conditioning.
output_type (`str`, *optional*, defaults to pil):
    Output format: 'pil', 'np', 'pt'.
mask_overlay_kwargs (`Dict`, *optional*):
    The kwargs for the postprocess step to apply the mask overlay. generated in
    InpaintProcessImagesInputStep.

Outputs:
images (`List`):
    Generated images.
"""

    model_name = "qwenimage"

    block_classes = AUTO_BLOCKS.values()
@@ -1202,7 +476,7 @@ class QwenImageAutoBlocks(SequentialPipelineBlocks):
        return (
            "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.\n"
            + "- for image-to-image generation, you need to provide `image`\n"
            + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`.\n"
            + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n"
            + "- to run the controlnet workflow, you need to provide `control_image`\n"
            + "- for text-to-image generation, all you need to provide is `prompt`"
        )
@@ -1210,5 +484,5 @@ class QwenImageAutoBlocks(SequentialPipelineBlocks):
    @property
    def outputs(self):
        return [
            OutputParam.template("images"),
            OutputParam(name="images", type_hint=List[List[PIL.Image.Image]]),
        ]

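A minimal usage sketch for the auto blocks documented above (the checkpoint id, the `init_pipeline`/`load_components` helpers, and the `output="images"` call pattern follow other modular-pipeline examples and are assumptions, not taken from this diff):

import torch
from diffusers.modular_pipelines import QwenImageAutoBlocks

pipe = QwenImageAutoBlocks().init_pipeline("Qwen/Qwen-Image")  # assumed checkpoint id
pipe.load_components(torch_dtype=torch.bfloat16)  # assumed helper name
pipe.to("cuda")

# text-to-image: only `prompt` is required
image = pipe(prompt="a cat wearing sunglasses", num_inference_steps=30, output="images")[0]

# image-to-image: additionally pass `image`
# inpainting: pass `image` and `mask_image` (optionally `padding_mask_crop`)
# controlnet: pass `control_image`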
@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
from typing import List, Optional

import PIL.Image
import torch

from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InputParam, InsertableDict, OutputParam
from ..modular_pipeline_utils import InsertableDict, OutputParam
from .before_denoise import (
    QwenImageCreateMaskLatentsStep,
    QwenImageEditRoPEInputsStep,
@@ -58,35 +59,8 @@ logger = logging.get_logger(__name__)
# ====================
|
||||
|
||||
# auto_docstring
|
||||
class QwenImageEditVLEncoderStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
QwenImage-Edit VL encoder step that encode the image and text prompts together.
|
||||
|
||||
Components:
|
||||
image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor
|
||||
(`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`)
|
||||
|
||||
Inputs:
|
||||
image (`Union[Image, List]`):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
prompt (`str`):
|
||||
The prompt or prompts to guide image generation.
|
||||
negative_prompt (`str`, *optional*):
|
||||
The prompt or prompts not to guide the image generation.
|
||||
|
||||
Outputs:
|
||||
resized_image (`List`):
|
||||
The resized images
|
||||
prompt_embeds (`Tensor`):
|
||||
The prompt embeddings.
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
The encoder attention mask.
|
||||
negative_prompt_embeds (`Tensor`):
|
||||
The negative prompt embeddings.
|
||||
negative_prompt_embeds_mask (`Tensor`):
|
||||
The negative prompt embeddings mask.
|
||||
"""
|
||||
"""VL encoder that takes both image and text prompts."""
|
||||
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
@@ -106,30 +80,7 @@ class QwenImageEditVLEncoderStep(SequentialPipelineBlocks):
|
||||
|
||||
|
||||
# Edit VAE encoder
|
||||
# auto_docstring
|
||||
class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Vae encoder step that encode the image inputs into their latent representations.
|
||||
|
||||
Components:
|
||||
image_resize_processor (`VaeImageProcessor`) image_processor (`VaeImageProcessor`) vae
|
||||
(`AutoencoderKLQwenImage`)
|
||||
|
||||
Inputs:
|
||||
image (`Union[Image, List]`):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
|
||||
Outputs:
|
||||
resized_image (`List`):
|
||||
The resized images
|
||||
processed_image (`Tensor`):
|
||||
The processed image
|
||||
image_latents (`Tensor`):
|
||||
The latent representation of the input image.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageEditResizeStep(),
|
||||
@@ -144,46 +95,12 @@ class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks):
|
||||
|
||||
|
||||
# Edit Inpaint VAE encoder
|
||||
# auto_docstring
|
||||
class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
This step is used for processing image and mask inputs for QwenImage-Edit inpaint tasks. It:
|
||||
- resize the image for target area (1024 * 1024) while maintaining the aspect ratio.
|
||||
- process the resized image and mask image.
|
||||
- create image latents.
|
||||
|
||||
Components:
|
||||
image_resize_processor (`VaeImageProcessor`) image_mask_processor (`InpaintProcessor`) vae
|
||||
(`AutoencoderKLQwenImage`)
|
||||
|
||||
Inputs:
|
||||
image (`Union[Image, List]`):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
mask_image (`Image`):
|
||||
Mask image for inpainting.
|
||||
padding_mask_crop (`int`, *optional*):
|
||||
Padding for mask cropping in inpainting.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
|
||||
Outputs:
|
||||
resized_image (`List`):
|
||||
The resized images
|
||||
processed_image (`Tensor`):
|
||||
The processed image
|
||||
processed_mask_image (`Tensor`):
|
||||
The processed mask image
|
||||
mask_overlay_kwargs (`Dict`):
|
||||
The kwargs for the postprocess step to apply the mask overlay
|
||||
image_latents (`Tensor`):
|
||||
The latent representation of the input image.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageEditResizeStep(),
|
||||
QwenImageEditInpaintProcessImagesInputStep(),
|
||||
QwenImageVaeEncoderStep(),
|
||||
QwenImageVaeEncoderStep(input_name="processed_image", output_name="image_latents"),
|
||||
]
|
||||
block_names = ["resize", "preprocess", "encode"]
|
||||
|
||||
@@ -220,64 +137,11 @@ class QwenImageEditAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
|
||||
|
||||
# assemble input steps
|
||||
# auto_docstring
|
||||
class QwenImageEditInputStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Input step that prepares the inputs for the edit denoising step. It:
|
||||
- make sure the text embeddings have consistent batch size as well as the additional inputs.
|
||||
- update height/width based `image_latents`, patchify `image_latents`.
|
||||
|
||||
Components:
|
||||
pachifier (`QwenImagePachifier`)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
mask for the text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds_mask (`Tensor`, *optional*):
|
||||
mask for the negative text embeddings. Can be generated from text_encoder step.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image.
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
|
||||
Outputs:
|
||||
batch_size (`int`):
|
||||
The batch size of the prompt embeddings
|
||||
dtype (`dtype`):
|
||||
The data type of the prompt embeddings
|
||||
prompt_embeds (`Tensor`):
|
||||
The prompt embeddings. (batch-expanded)
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
The encoder attention mask. (batch-expanded)
|
||||
negative_prompt_embeds (`Tensor`):
|
||||
The negative prompt embeddings. (batch-expanded)
|
||||
negative_prompt_embeds_mask (`Tensor`):
|
||||
The negative prompt embeddings mask. (batch-expanded)
|
||||
image_height (`int`):
|
||||
The image height calculated from the image latents dimension
|
||||
image_width (`int`):
|
||||
The image width calculated from the image latents dimension
|
||||
height (`int`):
|
||||
if not provided, updated to image height
|
||||
width (`int`):
|
||||
if not provided, updated to image width
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and
|
||||
batch-expanded)
|
||||
"""
|
||||
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageTextInputsStep(),
|
||||
QwenImageAdditionalInputsStep(),
|
||||
QwenImageAdditionalInputsStep(image_latent_inputs=["image_latents"]),
|
||||
]
|
||||
block_names = ["text_inputs", "additional_inputs"]
|
||||
|
||||
@@ -290,71 +154,12 @@ class QwenImageEditInputStep(SequentialPipelineBlocks):
|
||||
)
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class QwenImageEditInpaintInputStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Input step that prepares the inputs for the edit inpaint denoising step. It:
|
||||
- make sure the text embeddings have consistent batch size as well as the additional inputs.
|
||||
- update height/width based `image_latents`, patchify `image_latents`.
|
||||
|
||||
Components:
|
||||
pachifier (`QwenImagePachifier`)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
mask for the text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds_mask (`Tensor`, *optional*):
|
||||
mask for the negative text embeddings. Can be generated from text_encoder step.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image.
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
processed_mask_image (`Tensor`, *optional*):
|
||||
The processed mask image
|
||||
|
||||
Outputs:
|
||||
batch_size (`int`):
|
||||
The batch size of the prompt embeddings
|
||||
dtype (`dtype`):
|
||||
The data type of the prompt embeddings
|
||||
prompt_embeds (`Tensor`):
|
||||
The prompt embeddings. (batch-expanded)
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
The encoder attention mask. (batch-expanded)
|
||||
negative_prompt_embeds (`Tensor`):
|
||||
The negative prompt embeddings. (batch-expanded)
|
||||
negative_prompt_embeds_mask (`Tensor`):
|
||||
The negative prompt embeddings mask. (batch-expanded)
|
||||
image_height (`int`):
|
||||
The image height calculated from the image latents dimension
|
||||
image_width (`int`):
|
||||
The image width calculated from the image latents dimension
|
||||
height (`int`):
|
||||
if not provided, updated to image height
|
||||
width (`int`):
|
||||
if not provided, updated to image width
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and
|
||||
batch-expanded)
|
||||
processed_mask_image (`Tensor`):
|
||||
The processed mask image (batch-expanded)
|
||||
"""
|
||||
|
||||
    model_name = "qwenimage-edit"
    block_classes = [
        QwenImageTextInputsStep(),
        QwenImageAdditionalInputsStep(
            additional_batch_inputs=[
                InputParam(name="processed_mask_image", type_hint=torch.Tensor, description="The processed mask image")
            ]
            image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
        ),
    ]
    block_names = ["text_inputs", "additional_inputs"]
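The hunk above replaces the explicit `InputParam` declaration with plain input names; a minimal sketch of the resulting configuration (the keyword signature is read off the new line in the diff and is otherwise an assumption):

# Hypothetical sketch: the step now receives plain names and builds the input
# metadata itself instead of being handed a ready-made InputParam object.
additional_inputs = QwenImageAdditionalInputsStep(
    image_latent_inputs=["image_latents"],             # latents that get patchified and batch-expanded
    additional_batch_inputs=["processed_mask_image"],  # extra tensors that only get batch-expanded
)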
@@ -369,42 +174,7 @@ class QwenImageEditInpaintInputStep(SequentialPipelineBlocks):
|
||||
|
||||
|
||||
# assemble prepare latents steps
|
||||
# auto_docstring
|
||||
class QwenImageEditInpaintPrepareLatentsStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
This step prepares the latents/image_latents and mask inputs for the edit inpainting denoising step. It:
|
||||
- Add noise to the image latents to create the latents input for the denoiser.
|
||||
- Create the patchified latents `mask` based on the processed mask image.
|
||||
|
||||
Components:
|
||||
scheduler (`FlowMatchEulerDiscreteScheduler`) pachifier (`QwenImagePachifier`)
|
||||
|
||||
Inputs:
|
||||
latents (`Tensor`):
|
||||
The initial random noised, can be generated in prepare latent step.
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step. (Can be
|
||||
generated from vae encoder and updated in input step.)
|
||||
timesteps (`Tensor`):
|
||||
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
|
||||
processed_mask_image (`Tensor`):
|
||||
The processed mask to use for the inpainting process.
|
||||
height (`int`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`):
|
||||
The width in pixels of the generated image.
|
||||
dtype (`dtype`, *optional*, defaults to torch.float32):
|
||||
The dtype of the model inputs, can be generated in input step.
|
||||
|
||||
Outputs:
|
||||
initial_noise (`Tensor`):
|
||||
The initial random noised used for inpainting denoising.
|
||||
latents (`Tensor`):
|
||||
The scaled noisy latents to use for inpainting/image-to-image denoising.
|
||||
mask (`Tensor`):
|
||||
The mask to use for the inpainting process.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()]
|
||||
block_names = ["add_noise_to_latents", "create_mask_latents"]
|
||||
@@ -419,50 +189,7 @@ class QwenImageEditInpaintPrepareLatentsStep(SequentialPipelineBlocks):
|
||||
|
||||
|
||||
# Qwen Image Edit (image2image) core denoise step
|
||||
# auto_docstring
|
||||
class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Core denoising workflow for QwenImage-Edit edit (img2img) task.
|
||||
|
||||
Components:
|
||||
pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider
|
||||
(`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
mask for the text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds_mask (`Tensor`, *optional*):
|
||||
mask for the negative text embeddings. Can be generated from text_encoder step.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image.
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
Denoised latents.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageEditInputStep(),
|
||||
@@ -485,62 +212,9 @@ class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
def description(self):
|
||||
return "Core denoising workflow for QwenImage-Edit edit (img2img) task."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam.template("latents"),
|
||||
]
|
||||
|
||||
|
||||
# Qwen Image Edit (inpainting) core denoise step
|
||||
# auto_docstring
|
||||
class QwenImageEditInpaintCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Core denoising workflow for QwenImage-Edit edit inpaint task.
|
||||
|
||||
Components:
|
||||
pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider
|
||||
(`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
mask for the text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds_mask (`Tensor`, *optional*):
|
||||
mask for the negative text embeddings. Can be generated from text_encoder step.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image.
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
processed_mask_image (`Tensor`, *optional*):
|
||||
The processed mask image
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
strength (`float`, *optional*, defaults to 0.9):
|
||||
Strength for img2img/inpainting.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
Denoised latents.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageEditInpaintInputStep(),
|
||||
@@ -565,12 +239,6 @@ class QwenImageEditInpaintCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
def description(self):
|
||||
return "Core denoising workflow for QwenImage-Edit edit inpaint task."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam.template("latents"),
|
||||
]
|
||||
|
||||
|
||||
# Auto core denoise step for QwenImage Edit
|
||||
class QwenImageEditAutoCoreDenoiseStep(ConditionalPipelineBlocks):
|
||||
@@ -599,12 +267,6 @@ class QwenImageEditAutoCoreDenoiseStep(ConditionalPipelineBlocks):
|
||||
"Supports edit (img2img) and edit inpainting tasks for QwenImage-Edit."
|
||||
)
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam.template("latents"),
|
||||
]
|
||||
|
||||
|
||||
# ====================
|
||||
# 4. DECODE
|
||||
@@ -612,26 +274,7 @@ class QwenImageEditAutoCoreDenoiseStep(ConditionalPipelineBlocks):
|
||||
|
||||
|
||||
# Decode step (standard)
|
||||
# auto_docstring
|
||||
class QwenImageEditDecodeStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Decode step that decodes the latents to images and postprocess the generated image.
|
||||
|
||||
Components:
|
||||
vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`)
|
||||
|
||||
Inputs:
|
||||
latents (`Tensor`):
|
||||
The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise
|
||||
step.
|
||||
output_type (`str`, *optional*, defaults to pil):
|
||||
Output format: 'pil', 'np', 'pt'.
|
||||
|
||||
Outputs:
|
||||
images (`List`):
|
||||
Generated images. (tensor output of the vae decoder.)
|
||||
"""
|
||||
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
|
||||
block_names = ["decode", "postprocess"]
|
||||
@@ -642,30 +285,7 @@ class QwenImageEditDecodeStep(SequentialPipelineBlocks):
|
||||
|
||||
|
||||
# Inpaint decode step
|
||||
# auto_docstring
|
||||
class QwenImageEditInpaintDecodeStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Decode step that decodes the latents to images and postprocess the generated image, optionally apply the mask
|
||||
overlay to the original image.
|
||||
|
||||
Components:
|
||||
vae (`AutoencoderKLQwenImage`) image_mask_processor (`InpaintProcessor`)
|
||||
|
||||
Inputs:
|
||||
latents (`Tensor`):
|
||||
The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise
|
||||
step.
|
||||
output_type (`str`, *optional*, defaults to pil):
|
||||
Output format: 'pil', 'np', 'pt'.
|
||||
mask_overlay_kwargs (`Dict`, *optional*):
|
||||
The kwargs for the postprocess step to apply the mask overlay. generated in
|
||||
InpaintProcessImagesInputStep.
|
||||
|
||||
Outputs:
|
||||
images (`List`):
|
||||
Generated images. (tensor output of the vae decoder.)
|
||||
"""
|
||||
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()]
|
||||
block_names = ["decode", "postprocess"]
|
||||
@@ -693,7 +313,9 @@ class QwenImageEditAutoDecodeStep(AutoPipelineBlocks):
    @property
    def outputs(self):
        return [
            OutputParam.template("latents"),
            OutputParam(
                name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step"
            ),
        ]


@@ -711,66 +333,7 @@ EDIT_AUTO_BLOCKS = InsertableDict(
|
||||
)
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
|
||||
"""
|
||||
Auto Modular pipeline for edit (img2img) and edit inpaint tasks using QwenImage-Edit.
|
||||
- for edit (img2img) generation, you need to provide `image`
|
||||
- for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide
|
||||
`padding_mask_crop`
|
||||
|
||||
Components:
|
||||
image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor
|
||||
(`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`) image_mask_processor (`InpaintProcessor`) vae
|
||||
(`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) pachifier (`QwenImagePachifier`) scheduler
|
||||
(`FlowMatchEulerDiscreteScheduler`) transformer (`QwenImageTransformer2DModel`)
|
||||
|
||||
Inputs:
|
||||
image (`Union[Image, List]`):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
prompt (`str`):
|
||||
The prompt or prompts to guide image generation.
|
||||
negative_prompt (`str`, *optional*):
|
||||
The prompt or prompts not to guide the image generation.
|
||||
mask_image (`Image`, *optional*):
|
||||
Mask image for inpainting.
|
||||
padding_mask_crop (`int`, *optional*):
|
||||
Padding for mask cropping in inpainting.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
height (`int`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`):
|
||||
The width in pixels of the generated image.
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
processed_mask_image (`Tensor`, *optional*):
|
||||
The processed mask image
|
||||
latents (`Tensor`):
|
||||
Pre-generated noisy latents for image generation.
|
||||
num_inference_steps (`int`):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
strength (`float`, *optional*, defaults to 0.9):
|
||||
Strength for img2img/inpainting.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
output_type (`str`, *optional*, defaults to pil):
|
||||
Output format: 'pil', 'np', 'pt'.
|
||||
mask_overlay_kwargs (`Dict`, *optional*):
|
||||
The kwargs for the postprocess step to apply the mask overlay. generated in
|
||||
InpaintProcessImagesInputStep.
|
||||
|
||||
Outputs:
|
||||
images (`List`):
|
||||
Generated images.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = EDIT_AUTO_BLOCKS.values()
|
||||
block_names = EDIT_AUTO_BLOCKS.keys()
|
||||
@@ -786,5 +349,5 @@ class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
    @property
    def outputs(self):
        return [
            OutputParam.template("images"),
            OutputParam(name="images", type_hint=List[List[PIL.Image.Image]], description="The generated images"),
        ]

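An analogous usage sketch for the edit auto blocks (same assumptions about helper names, call pattern, and checkpoint id as the sketch above; file paths are placeholders):

import torch
from diffusers.modular_pipelines import QwenImageEditAutoBlocks
from diffusers.utils import load_image

pipe = QwenImageEditAutoBlocks().init_pipeline("Qwen/Qwen-Image-Edit")  # assumed checkpoint id
pipe.load_components(torch_dtype=torch.bfloat16)  # assumed helper name
pipe.to("cuda")

source_image = load_image("input.png")  # placeholder path

# edit (img2img): `image` + `prompt`
edited = pipe(image=source_image, prompt="make the sky stormy", output="images")[0]

# edit inpainting: additionally pass `mask_image` (and optionally `padding_mask_crop`)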
@@ -12,6 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import PIL.Image
import torch

from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict, OutputParam
@@ -48,41 +53,12 @@ logger = logging.get_logger(__name__)
# ====================
|
||||
|
||||
# auto_docstring
|
||||
class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
QwenImage-Edit Plus VL encoder step that encodes the image and text prompts together.
|
||||
|
||||
Components:
|
||||
image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor
|
||||
(`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`)
|
||||
|
||||
Inputs:
|
||||
image (`Union[Image, List]`):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
prompt (`str`):
|
||||
The prompt or prompts to guide image generation.
|
||||
negative_prompt (`str`, *optional*):
|
||||
The prompt or prompts not to guide the image generation.
|
||||
|
||||
Outputs:
|
||||
resized_image (`List`):
|
||||
Images resized to 1024x1024 target area for VAE encoding
|
||||
resized_cond_image (`List`):
|
||||
Images resized to 384x384 target area for VL text encoding
|
||||
prompt_embeds (`Tensor`):
|
||||
The prompt embeddings.
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
The encoder attention mask.
|
||||
negative_prompt_embeds (`Tensor`):
|
||||
The negative prompt embeddings.
|
||||
negative_prompt_embeds_mask (`Tensor`):
|
||||
The negative prompt embeddings mask.
|
||||
"""
|
||||
"""VL encoder that takes both image and text prompts. Uses 384x384 target area."""
|
||||
|
||||
model_name = "qwenimage-edit-plus"
|
||||
block_classes = [
|
||||
QwenImageEditPlusResizeStep(),
|
||||
QwenImageEditPlusResizeStep(target_area=384 * 384, output_name="resized_cond_image"),
|
||||
QwenImageEditPlusTextEncoderStep(),
|
||||
]
|
||||
block_names = ["resize", "encode"]
|
||||
@@ -97,36 +73,12 @@ class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks):
|
||||
# ====================
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
VAE encoder step that encodes image inputs into latent representations.
|
||||
Each image is resized independently based on its own aspect ratio to 1024x1024 target area.
|
||||
|
||||
Components:
|
||||
image_resize_processor (`VaeImageProcessor`) image_processor (`VaeImageProcessor`) vae
|
||||
(`AutoencoderKLQwenImage`)
|
||||
|
||||
Inputs:
|
||||
image (`Union[Image, List]`):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
|
||||
Outputs:
|
||||
resized_image (`List`):
|
||||
Images resized to 1024x1024 target area for VAE encoding
|
||||
resized_cond_image (`List`):
|
||||
Images resized to 384x384 target area for VL text encoding
|
||||
processed_image (`Tensor`):
|
||||
The processed image
|
||||
image_latents (`Tensor`):
|
||||
The latent representation of the input image.
|
||||
"""
|
||||
"""VAE encoder that handles multiple images with different sizes. Uses 1024x1024 target area."""
|
||||
|
||||
model_name = "qwenimage-edit-plus"
|
||||
block_classes = [
|
||||
QwenImageEditPlusResizeStep(),
|
||||
QwenImageEditPlusResizeStep(target_area=1024 * 1024, output_name="resized_image"),
|
||||
QwenImageEditPlusProcessImagesInputStep(),
|
||||
QwenImageVaeEncoderStep(),
|
||||
]
|
||||
@@ -146,66 +98,11 @@ class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):
|
||||
|
||||
|
||||
# assemble input steps
|
||||
# auto_docstring
|
||||
class QwenImageEditPlusInputStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Input step that prepares the inputs for the Edit Plus denoising step. It:
|
||||
- Standardizes text embeddings batch size.
|
||||
- Processes list of image latents: patchifies, concatenates along dim=1, expands batch.
|
||||
- Outputs lists of image_height/image_width for RoPE calculation.
|
||||
- Defaults height/width from last image in the list.
|
||||
|
||||
Components:
|
||||
pachifier (`QwenImagePachifier`)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
mask for the text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds_mask (`Tensor`, *optional*):
|
||||
mask for the negative text embeddings. Can be generated from text_encoder step.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image.
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
|
||||
Outputs:
|
||||
batch_size (`int`):
|
||||
The batch size of the prompt embeddings
|
||||
dtype (`dtype`):
|
||||
The data type of the prompt embeddings
|
||||
prompt_embeds (`Tensor`):
|
||||
The prompt embeddings. (batch-expanded)
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
The encoder attention mask. (batch-expanded)
|
||||
negative_prompt_embeds (`Tensor`):
|
||||
The negative prompt embeddings. (batch-expanded)
|
||||
negative_prompt_embeds_mask (`Tensor`):
|
||||
The negative prompt embeddings mask. (batch-expanded)
|
||||
image_height (`List`):
|
||||
The image heights calculated from the image latents dimension
|
||||
image_width (`List`):
|
||||
The image widths calculated from the image latents dimension
|
||||
height (`int`):
|
||||
if not provided, updated to image height
|
||||
width (`int`):
|
||||
if not provided, updated to image width
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified,
|
||||
concatenated, and batch-expanded)
|
||||
"""
|
||||
|
||||
model_name = "qwenimage-edit-plus"
|
||||
block_classes = [
|
||||
QwenImageTextInputsStep(),
|
||||
QwenImageEditPlusAdditionalInputsStep(),
|
||||
QwenImageEditPlusAdditionalInputsStep(image_latent_inputs=["image_latents"]),
|
||||
]
|
||||
block_names = ["text_inputs", "additional_inputs"]
|
||||
|
||||
@@ -221,50 +118,7 @@ class QwenImageEditPlusInputStep(SequentialPipelineBlocks):
|
||||
|
||||
|
||||
# Qwen Image Edit Plus (image2image) core denoise step
|
||||
# auto_docstring
|
||||
class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Core denoising workflow for QwenImage-Edit Plus edit (img2img) task.
|
||||
|
||||
Components:
|
||||
pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider
|
||||
(`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
mask for the text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds_mask (`Tensor`, *optional*):
|
||||
mask for the negative text embeddings. Can be generated from text_encoder step.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image.
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
Denoised latents.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage-edit-plus"
|
||||
block_classes = [
|
||||
QwenImageEditPlusInputStep(),
|
||||
@@ -290,7 +144,9 @@ class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam.template("latents"),
|
||||
OutputParam(
|
||||
name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step"
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@@ -299,26 +155,7 @@ class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
# ====================
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class QwenImageEditPlusDecodeStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Decode step that decodes the latents to images and postprocesses the generated image.
|
||||
|
||||
Components:
|
||||
vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`)
|
||||
|
||||
Inputs:
|
||||
latents (`Tensor`):
|
||||
The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise
|
||||
step.
|
||||
output_type (`str`, *optional*, defaults to pil):
|
||||
Output format: 'pil', 'np', 'pt'.
|
||||
|
||||
Outputs:
|
||||
images (`List`):
|
||||
Generated images. (tensor output of the vae decoder.)
|
||||
"""
|
||||
|
||||
model_name = "qwenimage-edit-plus"
|
||||
block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
|
||||
block_names = ["decode", "postprocess"]
|
||||
@@ -342,53 +179,7 @@ EDIT_PLUS_AUTO_BLOCKS = InsertableDict(
|
||||
)
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks):
|
||||
"""
|
||||
Auto Modular pipeline for edit (img2img) tasks using QwenImage-Edit Plus.
|
||||
- `image` is required input (can be single image or list of images).
|
||||
- Each image is resized independently based on its own aspect ratio.
|
||||
- VL encoder uses 384x384 target area, VAE encoder uses 1024x1024 target area.
|
||||
|
||||
Components:
|
||||
image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor
|
||||
(`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`) image_processor (`VaeImageProcessor`) vae
|
||||
(`AutoencoderKLQwenImage`) pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`)
|
||||
transformer (`QwenImageTransformer2DModel`)
|
||||
|
||||
Inputs:
|
||||
image (`Union[Image, List]`):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
prompt (`str`):
|
||||
The prompt or prompts to guide image generation.
|
||||
negative_prompt (`str`, *optional*):
|
||||
The prompt or prompts not to guide the image generation.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image.
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
output_type (`str`, *optional*, defaults to pil):
|
||||
Output format: 'pil', 'np', 'pt'.
|
||||
|
||||
Outputs:
|
||||
images (`List`):
|
||||
Generated images.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage-edit-plus"
|
||||
block_classes = EDIT_PLUS_AUTO_BLOCKS.values()
|
||||
block_names = EDIT_PLUS_AUTO_BLOCKS.keys()
|
||||
@@ -405,5 +196,5 @@ class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks):
    @property
    def outputs(self):
        return [
            OutputParam.template("images"),
            OutputParam(name="images", type_hint=List[List[PIL.Image.Image]], description="The generated images"),
        ]

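A usage sketch for the Edit Plus auto blocks, which accept a list of reference images that are resized independently (same assumptions as the earlier sketches; the checkpoint id and file names are placeholders):

import torch
from diffusers.modular_pipelines import QwenImageEditPlusAutoBlocks
from diffusers.utils import load_image

pipe = QwenImageEditPlusAutoBlocks().init_pipeline("Qwen/Qwen-Image-Edit-2509")  # assumed checkpoint id
pipe.load_components(torch_dtype=torch.bfloat16)  # assumed helper name
pipe.to("cuda")

# Each image is resized on its own aspect ratio: 384x384 target area for the VL
# encoder and 1024x1024 for the VAE encoder, per the docstrings above.
images = [load_image("chair.png"), load_image("lamp.png")]  # placeholder paths
result = pipe(image=images, prompt="place the lamp next to the chair", output="images")[0]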
@@ -12,6 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List

import PIL.Image
import torch

from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict, OutputParam
@@ -49,44 +55,8 @@ logger = logging.get_logger(__name__)
# ====================
|
||||
|
||||
# auto_docstring
|
||||
class QwenImageLayeredTextEncoderStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
QwenImage-Layered Text encoder step that encode the text prompt, will generate a prompt based on image if not
|
||||
provided.
|
||||
|
||||
Components:
|
||||
image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor
|
||||
(`Qwen2VLProcessor`) tokenizer (`Qwen2Tokenizer`): The tokenizer to use guider (`ClassifierFreeGuidance`)
|
||||
|
||||
Inputs:
|
||||
image (`Union[Image, List]`):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
resolution (`int`, *optional*, defaults to 640):
|
||||
The target area to resize the image to, can be 1024 or 640
|
||||
prompt (`str`, *optional*):
|
||||
The prompt or prompts to guide image generation.
|
||||
use_en_prompt (`bool`, *optional*, defaults to False):
|
||||
Whether to use English prompt template
|
||||
negative_prompt (`str`, *optional*):
|
||||
The prompt or prompts not to guide the image generation.
|
||||
max_sequence_length (`int`, *optional*, defaults to 1024):
|
||||
Maximum sequence length for prompt encoding.
|
||||
|
||||
Outputs:
|
||||
resized_image (`List`):
|
||||
The resized images
|
||||
prompt (`str`):
|
||||
The prompt or prompts to guide image generation. If not provided, updated using image caption
|
||||
prompt_embeds (`Tensor`):
|
||||
The prompt embeddings.
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
The encoder attention mask.
|
||||
negative_prompt_embeds (`Tensor`):
|
||||
The negative prompt embeddings.
|
||||
negative_prompt_embeds_mask (`Tensor`):
|
||||
The negative prompt embeddings mask.
|
||||
"""
|
||||
"""Text encoder that takes text prompt, will generate a prompt based on image if not provided."""
|
||||
|
||||
model_name = "qwenimage-layered"
|
||||
block_classes = [
|
||||
@@ -107,32 +77,7 @@ class QwenImageLayeredTextEncoderStep(SequentialPipelineBlocks):
|
||||
|
||||
|
||||
# Edit VAE encoder
|
||||
# auto_docstring
|
||||
class QwenImageLayeredVaeEncoderStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Vae encoder step that encode the image inputs into their latent representations.
|
||||
|
||||
Components:
|
||||
image_resize_processor (`VaeImageProcessor`) image_processor (`VaeImageProcessor`) vae
|
||||
(`AutoencoderKLQwenImage`)
|
||||
|
||||
Inputs:
|
||||
image (`Union[Image, List]`):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
resolution (`int`, *optional*, defaults to 640):
|
||||
The target area to resize the image to, can be 1024 or 640
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
|
||||
Outputs:
|
||||
resized_image (`List`):
|
||||
The resized images
|
||||
processed_image (`Tensor`):
|
||||
The processed image
|
||||
image_latents (`Tensor`):
|
||||
The latent representation of the input image.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage-layered"
|
||||
block_classes = [
|
||||
QwenImageLayeredResizeStep(),
|
||||
@@ -153,60 +98,11 @@ class QwenImageLayeredVaeEncoderStep(SequentialPipelineBlocks):
|
||||
|
||||
|
||||
# assemble input steps
|
||||
# auto_docstring
|
||||
class QwenImageLayeredInputStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Input step that prepares the inputs for the layered denoising step. It:
|
||||
- make sure the text embeddings have consistent batch size as well as the additional inputs.
|
||||
- update height/width based `image_latents`, patchify `image_latents`.
|
||||
|
||||
Components:
|
||||
pachifier (`QwenImageLayeredPachifier`)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
mask for the text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds_mask (`Tensor`, *optional*):
|
||||
mask for the negative text embeddings. Can be generated from text_encoder step.
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
|
||||
Outputs:
|
||||
batch_size (`int`):
|
||||
The batch size of the prompt embeddings
|
||||
dtype (`dtype`):
|
||||
The data type of the prompt embeddings
|
||||
prompt_embeds (`Tensor`):
|
||||
The prompt embeddings. (batch-expanded)
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
The encoder attention mask. (batch-expanded)
|
||||
negative_prompt_embeds (`Tensor`):
|
||||
The negative prompt embeddings. (batch-expanded)
|
||||
negative_prompt_embeds_mask (`Tensor`):
|
||||
The negative prompt embeddings mask. (batch-expanded)
|
||||
image_height (`int`):
|
||||
The image height calculated from the image latents dimension
|
||||
image_width (`int`):
|
||||
The image width calculated from the image latents dimension
|
||||
height (`int`):
|
||||
if not provided, updated to image height
|
||||
width (`int`):
|
||||
if not provided, updated to image width
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified
|
||||
with layered pachifier and batch-expanded)
|
||||
"""
|
||||
|
||||
model_name = "qwenimage-layered"
|
||||
block_classes = [
|
||||
QwenImageTextInputsStep(),
|
||||
QwenImageLayeredAdditionalInputsStep(),
|
||||
QwenImageLayeredAdditionalInputsStep(image_latent_inputs=["image_latents"]),
|
||||
]
|
||||
block_names = ["text_inputs", "additional_inputs"]
|
||||
|
||||
@@ -220,48 +116,7 @@ class QwenImageLayeredInputStep(SequentialPipelineBlocks):
|
||||
|
||||
|
||||
# Qwen Image Layered (image2image) core denoise step
|
||||
# auto_docstring
|
||||
class QwenImageLayeredCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Core denoising workflow for QwenImage-Layered img2img task.
|
||||
|
||||
Components:
|
||||
pachifier (`QwenImageLayeredPachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider
|
||||
(`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
mask for the text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds_mask (`Tensor`, *optional*):
|
||||
mask for the negative text embeddings. Can be generated from text_encoder step.
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
layers (`int`, *optional*, defaults to 4):
|
||||
Number of layers to extract from the image
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
Denoised latents.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage-layered"
|
||||
block_classes = [
|
||||
QwenImageLayeredInputStep(),
|
||||
@@ -287,7 +142,9 @@ class QwenImageLayeredCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam.template("latents"),
|
||||
OutputParam(
|
||||
name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step"
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@@ -305,54 +162,7 @@ LAYERED_AUTO_BLOCKS = InsertableDict(
|
||||
)
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class QwenImageLayeredAutoBlocks(SequentialPipelineBlocks):
|
||||
"""
|
||||
Auto Modular pipeline for layered denoising tasks using QwenImage-Layered.
|
||||
|
||||
Components:
|
||||
image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor
|
||||
(`Qwen2VLProcessor`) tokenizer (`Qwen2Tokenizer`): The tokenizer to use guider (`ClassifierFreeGuidance`)
|
||||
image_processor (`VaeImageProcessor`) vae (`AutoencoderKLQwenImage`) pachifier (`QwenImageLayeredPachifier`)
|
||||
scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`QwenImageTransformer2DModel`)
|
||||
|
||||
Inputs:
|
||||
image (`Union[Image, List]`):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
resolution (`int`, *optional*, defaults to 640):
|
||||
The target area to resize the image to, can be 1024 or 640
|
||||
prompt (`str`, *optional*):
|
||||
The prompt or prompts to guide image generation.
|
||||
use_en_prompt (`bool`, *optional*, defaults to False):
|
||||
Whether to use English prompt template
|
||||
negative_prompt (`str`, *optional*):
|
||||
The prompt or prompts not to guide the image generation.
|
||||
max_sequence_length (`int`, *optional*, defaults to 1024):
|
||||
Maximum sequence length for prompt encoding.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
layers (`int`, *optional*, defaults to 4):
|
||||
Number of layers to extract from the image
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
output_type (`str`, *optional*, defaults to pil):
|
||||
Output format: 'pil', 'np', 'pt'.
|
||||
|
||||
Outputs:
|
||||
images (`List`):
|
||||
Generated images.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage-layered"
|
||||
block_classes = LAYERED_AUTO_BLOCKS.values()
|
||||
block_names = LAYERED_AUTO_BLOCKS.keys()
|
||||
@@ -364,5 +174,5 @@ class QwenImageLayeredAutoBlocks(SequentialPipelineBlocks):
    @property
    def outputs(self):
        return [
            OutputParam.template("images"),
            OutputParam(name="images", type_hint=List[List[PIL.Image.Image]], description="The generated images"),
        ]

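A usage sketch for the layered auto blocks, which decompose a reference image into multiple layers (same assumptions as the earlier sketches; `layers` and `resolution` are the inputs documented above):

import torch
from diffusers.modular_pipelines import QwenImageLayeredAutoBlocks
from diffusers.utils import load_image

pipe = QwenImageLayeredAutoBlocks().init_pipeline("Qwen/Qwen-Image")  # assumed checkpoint id
pipe.load_components(torch_dtype=torch.bfloat16)  # assumed helper name
pipe.to("cuda")

layer_images = pipe(
    image=load_image("poster.png"),  # placeholder path
    layers=4,                        # number of layers to extract from the image
    resolution=640,                  # target resize area, 640 or 1024 per the docstring
    output="images",
)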
@@ -131,7 +131,7 @@ class ZImageLoopDenoiser(ModularPipelineBlocks):
        ),
        InputParam(
            kwargs_type="denoiser_input_fields",
            description="The conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
            description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
        ),
    ]
    guider_input_names = []

@@ -482,6 +482,8 @@ class ChromaInpaintPipeline(
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        pooled_prompt_embeds=None,
        negative_pooled_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
        padding_mask_crop=None,
        max_sequence_length=None,
@@ -529,6 +531,15 @@
                f" {negative_prompt_embeds.shape}."
            )

        if prompt_embeds is not None and pooled_prompt_embeds is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
            )
        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
            raise ValueError(
                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
            )

        if prompt_embeds is not None and prompt_attention_mask is None:
            raise ValueError("Cannot provide `prompt_embeds` without also providing `prompt_attention_mask")

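The added checks make pooled embeddings mandatory whenever precomputed embeddings are passed; a hedged sketch of the resulting calling pattern (pipeline construction elided, argument names taken from the diff, everything else assumed):

# Assumes `pipe` is a ChromaInpaintPipeline and all embeddings come from its own text encoder.
result = pipe(
    image=init_image,
    mask_image=mask,
    prompt_embeds=prompt_embeds,
    prompt_attention_mask=prompt_attention_mask,    # already required before this change
    pooled_prompt_embeds=pooled_prompt_embeds,      # now required alongside prompt_embeds
    negative_prompt_embeds=negative_prompt_embeds,
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,  # now required alongside negative_prompt_embeds
)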
@@ -782,11 +793,13 @@ class ChromaInpaintPipeline(
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        negative_ip_adapter_image: Optional[PipelineImageInput] = None,
        negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,

@@ -52,15 +52,6 @@ else:
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
DEFAULT_NEGATIVE_PROMPT = (
|
||||
"The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
|
||||
"over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
|
||||
"underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
|
||||
"jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
|
||||
"fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
|
||||
"Overall, the video is of poor quality."
|
||||
)
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
@@ -368,7 +359,7 @@ class Cosmos2_5_PredictBasePipeline(DiffusionPipeline):
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
@@ -558,7 +549,6 @@ class Cosmos2_5_PredictBasePipeline(DiffusionPipeline):
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
conditional_frame_timestep: float = 0.1,
|
||||
num_latent_conditional_frames: int = 2,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation. Supports three modes:
|
||||
@@ -624,10 +614,6 @@ class Cosmos2_5_PredictBasePipeline(DiffusionPipeline):
|
||||
max_sequence_length (`int`, defaults to `512`):
|
||||
The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If
|
||||
the prompt is shorter than this length, it will be padded.
|
||||
num_latent_conditional_frames (`int`, defaults to `2`):
|
||||
Number of latent conditional frames to use for Video2World conditioning. The number of pixel frames
|
||||
extracted from the input video is calculated as `4 * (num_latent_conditional_frames - 1) + 1`. Set to 1
|
||||
for Image2World-like behavior (single frame conditioning).
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -706,38 +692,19 @@ class Cosmos2_5_PredictBasePipeline(DiffusionPipeline):
|
||||
video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=torch.uint8)
|
||||
num_frames_in = 0
|
||||
else:
|
||||
num_frames_in = len(video)
|
||||
|
||||
if batch_size != 1:
|
||||
raise ValueError(f"batch_size must be 1 for video input (given {batch_size})")
|
||||
|
||||
if num_latent_conditional_frames not in [1, 2]:
|
||||
raise ValueError(
|
||||
f"num_latent_conditional_frames must be 1 or 2, but got {num_latent_conditional_frames}"
|
||||
)
|
||||
|
||||
frames_to_extract = 4 * (num_latent_conditional_frames - 1) + 1
|
||||
|
||||
total_input_frames = len(video)
|
||||
|
||||
if total_input_frames < frames_to_extract:
|
||||
raise ValueError(
|
||||
f"Input video has only {total_input_frames} frames but Video2World requires at least "
|
||||
f"{frames_to_extract} frames for conditioning."
|
||||
)
|
||||
|
||||
num_frames_in = frames_to_extract
|
||||
|
||||
assert video is not None
|
||||
video = self.video_processor.preprocess_video(video, height, width)
|
||||
|
||||
# For Video2World: extract last frames_to_extract frames from input, then pad
|
||||
if image is None and num_frames_in > 0 and num_frames_in < video.shape[2]:
|
||||
video = video[:, :, -num_frames_in:, :, :]
|
||||
|
||||
# pad with last frame (for video2world)
|
||||
num_frames_out = num_frames
|
||||
|
||||
if video.shape[2] < num_frames_out:
|
||||
n_pad_frames = num_frames_out - video.shape[2]
|
||||
last_frame = video[:, :, -1:, :, :] # [B, C, T==1, H, W]
|
||||
n_pad_frames = num_frames_out - num_frames_in
|
||||
last_frame = video[0, :, -1:, :, :] # [C, T==1, H, W]
|
||||
pad_frames = last_frame.repeat(1, 1, n_pad_frames, 1, 1) # [B, C, T, H, W]
|
||||
video = torch.cat((video, pad_frames), dim=2)
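
For reference, the Video2World frame handling above reduces to the following minimal sketch (hypothetical tensor shapes, assuming a [B, C, T, H, W] video tensor; not pipeline code):

import torch

num_latent_conditional_frames = 2
frames_to_extract = 4 * (num_latent_conditional_frames - 1) + 1   # -> 5 pixel frames
num_frames = 21                                                   # target clip length

video = torch.randn(1, 3, 16, 64, 64)                             # [B, C, T, H, W]
video = video[:, :, -frames_to_extract:, :, :]                    # keep the last 5 frames
if video.shape[2] < num_frames:
    n_pad = num_frames - video.shape[2]
    pad = video[:, :, -1:, :, :].repeat(1, 1, n_pad, 1, 1)        # repeat the last frame
    video = torch.cat((video, pad), dim=2)                        # [1, 3, 21, 64, 64]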
@@ -49,14 +49,6 @@ else:
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
DEFAULT_NEGATIVE_PROMPT = (
|
||||
"The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
|
||||
"over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
|
||||
"underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
|
||||
"jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
|
||||
"fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
|
||||
"Overall, the video is of poor quality."
|
||||
)
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
@@ -308,7 +300,7 @@ class Cosmos2TextToImagePipeline(DiffusionPipeline):
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
|
||||
@@ -50,14 +50,6 @@ else:
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
DEFAULT_NEGATIVE_PROMPT = (
|
||||
"The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
|
||||
"over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
|
||||
"underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
|
||||
"jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
|
||||
"fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
|
||||
"Overall, the video is of poor quality."
|
||||
)
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
@@ -327,7 +319,7 @@ class Cosmos2VideoToWorldPipeline(DiffusionPipeline):
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
|
||||
@@ -49,14 +49,6 @@ else:
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
DEFAULT_NEGATIVE_PROMPT = (
|
||||
"The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
|
||||
"over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
|
||||
"underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
|
||||
"jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
|
||||
"fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
|
||||
"Overall, the video is of poor quality."
|
||||
)
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
@@ -293,7 +285,7 @@ class CosmosTextToWorldPipeline(DiffusionPipeline):
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
|
||||
@@ -50,14 +50,6 @@ else:
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
DEFAULT_NEGATIVE_PROMPT = (
|
||||
"The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
|
||||
"over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
|
||||
"underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
|
||||
"jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
|
||||
"fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
|
||||
"Overall, the video is of poor quality."
|
||||
)
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
@@ -339,7 +331,7 @@ class CosmosVideoToWorldPipeline(DiffusionPipeline):
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
|
||||
@@ -260,115 +260,25 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
token_ids = token_ids.view(1, -1)
|
||||
return token_ids
|
||||
|
||||
@staticmethod
|
||||
def _validate_and_normalize_images(
|
||||
image: Optional[Union[List[PIL.Image.Image], List[List[PIL.Image.Image]]]],
|
||||
batch_size: int,
|
||||
) -> Optional[List[List[PIL.Image.Image]]]:
|
||||
"""
|
||||
Validate and normalize image inputs to List[List[PIL.Image]].
|
||||
|
||||
Rules:
|
||||
- batch_size > 1: Only accepts List[List[PIL.Image]], each sublist must have equal length
|
||||
- batch_size == 1: Accepts List[PIL.Image] for legacy compatibility (converted to [[img1, img2, ...]])
|
||||
- Other formats raise ValueError
|
||||
|
||||
Args:
|
||||
image: Input images in various formats
|
||||
batch_size: Number of prompts in the batch
|
||||
|
||||
Returns:
|
||||
Normalized images as List[List[PIL.Image]], or None if no images provided
|
||||
"""
|
||||
if image is None or len(image) == 0:
|
||||
return None
|
||||
|
||||
first_element = image[0]
|
||||
|
||||
if batch_size == 1:
|
||||
# Legacy format: List[PIL.Image] -> [[img1, img2, ...]]
|
||||
if not isinstance(first_element, (list, tuple)):
|
||||
return [list(image)]
|
||||
# Already in List[List[PIL.Image]] format
|
||||
if len(image) != 1:
|
||||
raise ValueError(
|
||||
f"For batch_size=1 with List[List[PIL.Image]] format, expected 1 image list, got {len(image)}."
|
||||
)
|
||||
return [list(image[0])]
|
||||
|
||||
# batch_size > 1: must be List[List[PIL.Image]]
|
||||
if not isinstance(first_element, (list, tuple)):
|
||||
raise ValueError(
|
||||
f"For batch_size > 1, images must be List[List[PIL.Image]] format. "
|
||||
f"Got List[{type(first_element).__name__}] instead. "
|
||||
f"Each prompt requires its own list of condition images."
|
||||
)
|
||||
|
||||
if len(image) != batch_size:
|
||||
raise ValueError(f"Number of image lists ({len(image)}) must match batch size ({batch_size}).")
|
||||
|
||||
# Validate homogeneous: all sublists must have same length
|
||||
num_input_images_per_prompt = len(image[0])
|
||||
for idx, imgs in enumerate(image):
|
||||
if len(imgs) != num_input_images_per_prompt:
|
||||
raise ValueError(
|
||||
f"All prompts must have the same number of condition images. "
|
||||
f"Prompt 0 has {num_input_images_per_prompt} images, but prompt {idx} has {len(imgs)} images."
|
||||
)
|
||||
|
||||
return [list(imgs) for imgs in image]
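
A condensed, standalone sketch of the normalization rules documented above (`normalize_images` is a hypothetical helper; in practice the list elements are PIL images):

from typing import List, Optional, Sequence

def normalize_images(image: Optional[Sequence], batch_size: int) -> Optional[List[list]]:
    # Normalize image inputs to List[List[image]] following the rules in the docstring above.
    if image is None or len(image) == 0:
        return None
    if batch_size == 1:
        if not isinstance(image[0], (list, tuple)):
            return [list(image)]                       # legacy flat list -> single sublist
        if len(image) != 1:
            raise ValueError("batch_size=1 expects exactly one image list")
        return [list(image[0])]
    if not isinstance(image[0], (list, tuple)) or len(image) != batch_size:
        raise ValueError("batch_size>1 expects one image list per prompt")
    if len({len(imgs) for imgs in image}) != 1:
        raise ValueError("all prompts must have the same number of condition images")
    return [list(imgs) for imgs in image]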
def generate_prior_tokens(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt: str,
|
||||
height: int,
|
||||
width: int,
|
||||
image: Optional[List[List[PIL.Image.Image]]] = None,
|
||||
image: Optional[List[PIL.Image.Image]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
):
|
||||
"""
|
||||
Generate prior tokens for the DiT model using the AR model.
|
||||
|
||||
Args:
|
||||
prompt: Single prompt or list of prompts
|
||||
height: Target image height
|
||||
width: Target image width
|
||||
image: Normalized image input as List[List[PIL.Image]]. Should be pre-validated
|
||||
using _validate_and_normalize_images() before calling this method.
|
||||
device: Target device
|
||||
generator: Random generator for reproducibility
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
- prior_token_ids: Tensor of shape (batch_size, num_tokens) with upsampled prior tokens
|
||||
- prior_token_image_ids_per_sample: List of tensors, one per sample. Each tensor contains
|
||||
the upsampled prior token ids for all condition images in that sample. None for t2i.
|
||||
- source_image_grid_thw_per_sample: List of tensors, one per sample. Each tensor has shape
|
||||
(num_condition_images, 3) with upsampled grid info. None for t2i.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
# Normalize prompt to list format
|
||||
prompt_list = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt_list)
|
||||
|
||||
# Image is already normalized by _validate_and_normalize_images(): None or List[List[PIL.Image]]
|
||||
is_text_to_image = image is None
|
||||
# Build messages for each sample in the batch
|
||||
all_messages = []
|
||||
for idx, p in enumerate(prompt_list):
|
||||
content = []
|
||||
if not is_text_to_image:
|
||||
for img in image[idx]:
|
||||
content.append({"type": "image", "image": img})
|
||||
content.append({"type": "text", "text": p})
|
||||
all_messages.append([{"role": "user", "content": content}])
|
||||
# Process with the processor (supports batch with left padding)
|
||||
is_text_to_image = image is None or len(image) == 0
|
||||
content = []
|
||||
if image is not None:
|
||||
for img in image:
|
||||
content.append({"type": "image", "image": img})
|
||||
content.append({"type": "text", "text": prompt})
|
||||
messages = [{"role": "user", "content": content}]
|
||||
inputs = self.processor.apply_chat_template(
|
||||
all_messages,
|
||||
messages,
|
||||
tokenize=True,
|
||||
padding=True if batch_size > 1 else False,
|
||||
target_h=height,
|
||||
target_w=width,
|
||||
return_dict=True,
|
||||
@@ -376,117 +286,44 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
).to(device)
|
||||
|
||||
image_grid_thw = inputs.get("image_grid_thw")
|
||||
images_per_sample = inputs.get("images_per_sample")
|
||||
|
||||
# Determine number of condition images and grids per sample
|
||||
num_condition_images = 0 if is_text_to_image else len(image[0])
|
||||
if images_per_sample is not None:
|
||||
num_grids_per_sample = images_per_sample[0].item()
|
||||
else:
|
||||
# Fallback for batch_size=1: total grids is for single sample
|
||||
num_grids_per_sample = image_grid_thw.shape[0]
|
||||
|
||||
# Compute generation params (same for all samples in homogeneous batch)
|
||||
first_sample_grids = image_grid_thw[:num_grids_per_sample]
|
||||
max_new_tokens, large_image_offset, token_h, token_w = self._compute_generation_params(
|
||||
image_grid_thw=first_sample_grids, is_text_to_image=is_text_to_image
|
||||
image_grid_thw=image_grid_thw, is_text_to_image=is_text_to_image
|
||||
)
|
||||
|
||||
# Generate source image tokens (prior_token_image_ids) for i2i mode
|
||||
prior_token_image_ids = None
|
||||
source_image_grid_thw = None
|
||||
if not is_text_to_image:
|
||||
# Extract source grids by selecting condition image indices (skip target grids)
|
||||
# Grid order from processor: [s0_cond1, s0_cond2, ..., s0_target, s1_cond1, s1_cond2, ..., s1_target, ...]
|
||||
# We need indices: [0, 1, ..., num_condition_images-1, num_grids_per_sample, num_grids_per_sample+1, ...]
|
||||
source_indices = []
|
||||
for sample_idx in range(batch_size):
|
||||
base = sample_idx * num_grids_per_sample
|
||||
source_indices.extend(range(base, base + num_condition_images))
|
||||
source_grids = image_grid_thw[source_indices]
|
||||
if image is not None:
|
||||
prior_token_image_embed = self.vision_language_encoder.get_image_features(
|
||||
inputs["pixel_values"], image_grid_thw[:-1]
|
||||
)
|
||||
prior_token_image_embed = torch.cat(prior_token_image_embed, dim=0)
|
||||
prior_token_image_ids = self.vision_language_encoder.get_image_tokens(
|
||||
prior_token_image_embed, image_grid_thw[:-1]
|
||||
)
|
||||
|
||||
if len(source_grids) > 0:
|
||||
prior_token_image_embed = self.vision_language_encoder.get_image_features(
|
||||
inputs["pixel_values"], source_grids, return_dict=False
|
||||
)
|
||||
prior_token_image_embed = torch.cat(prior_token_image_embed, dim=0)
|
||||
prior_token_image_ids_d32 = self.vision_language_encoder.get_image_tokens(
|
||||
prior_token_image_embed, source_grids
|
||||
)
|
||||
# Upsample each source image's prior tokens to match VAE/DiT resolution
|
||||
split_sizes = source_grids.prod(dim=-1).tolist()
|
||||
prior_ids_per_source = torch.split(prior_token_image_ids_d32, split_sizes)
|
||||
upsampled_prior_ids = []
|
||||
for i, prior_ids in enumerate(prior_ids_per_source):
|
||||
t, h, w = source_grids[i].tolist()
|
||||
upsampled = self._upsample_token_ids(prior_ids, int(h), int(w))
|
||||
upsampled_prior_ids.append(upsampled.squeeze(0))
|
||||
prior_token_image_ids = torch.cat(upsampled_prior_ids, dim=0)
|
||||
# Upsample grid dimensions for later splitting
|
||||
upsampled_grids = source_grids.clone()
|
||||
upsampled_grids[:, 1] = upsampled_grids[:, 1] * 2
|
||||
upsampled_grids[:, 2] = upsampled_grids[:, 2] * 2
|
||||
source_image_grid_thw = upsampled_grids
|
||||
|
||||
# Generate with AR model
|
||||
# Set torch random seed from generator for reproducibility
|
||||
# (transformers generate() doesn't accept generator parameter)
|
||||
if generator is not None:
|
||||
seed = generator.initial_seed()
|
||||
torch.manual_seed(seed)
|
||||
if device is not None and device.type == "cuda":
|
||||
torch.cuda.manual_seed(seed)
|
||||
# For GLM-Image, greedy decoding is not allowed, as it may cause repetitive outputs.
|
||||
# max_new_tokens must be exactly grid_h * grid_w + 1 (the +1 is for EOS).
|
||||
outputs = self.vision_language_encoder.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=True,
|
||||
)
|
||||
|
||||
# Extract and upsample prior tokens for each sample
|
||||
# For left-padded inputs, generated tokens start after the padded input sequence
|
||||
all_prior_token_ids = []
|
||||
max_input_length = inputs["input_ids"].shape[-1]
|
||||
for idx in range(batch_size):
|
||||
# For left-padded sequences, generated tokens start at max_input_length
|
||||
# (padding is on the left, so all sequences end at the same position)
|
||||
prior_token_ids_d32 = self._extract_large_image_tokens(
|
||||
outputs[idx : idx + 1], max_input_length, large_image_offset, token_h * token_w
|
||||
)
|
||||
prior_token_ids = self._upsample_token_ids(prior_token_ids_d32, token_h, token_w)
|
||||
all_prior_token_ids.append(prior_token_ids)
|
||||
prior_token_ids = torch.cat(all_prior_token_ids, dim=0)
|
||||
prior_token_ids_d32 = self._extract_large_image_tokens(
|
||||
outputs, inputs["input_ids"].shape[-1], large_image_offset, token_h * token_w
|
||||
)
|
||||
prior_token_ids = self._upsample_token_ids(prior_token_ids_d32, token_h, token_w)
|
||||
|
||||
# Split prior_token_image_ids and source_image_grid_thw into per-sample lists for easier consumption
|
||||
prior_token_image_ids_per_sample = None
|
||||
source_image_grid_thw_per_sample = None
|
||||
if prior_token_image_ids is not None and source_image_grid_thw is not None:
|
||||
# Split grids: each sample has num_condition_images grids
|
||||
source_image_grid_thw_per_sample = list(torch.split(source_image_grid_thw, num_condition_images))
|
||||
# Split prior_token_image_ids: tokens per sample may vary due to different image sizes
|
||||
tokens_per_image = source_image_grid_thw.prod(dim=-1).tolist()
|
||||
tokens_per_sample = []
|
||||
for i in range(batch_size):
|
||||
start_idx = i * num_condition_images
|
||||
end_idx = start_idx + num_condition_images
|
||||
tokens_per_sample.append(sum(tokens_per_image[start_idx:end_idx]))
|
||||
prior_token_image_ids_per_sample = list(torch.split(prior_token_image_ids, tokens_per_sample))
|
||||
|
||||
return prior_token_ids, prior_token_image_ids_per_sample, source_image_grid_thw_per_sample
|
||||
return prior_token_ids, prior_token_image_ids
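
A minimal sketch of the per-sample splitting returned above, assuming two prompts with two condition images each (toy grid sizes, not real model outputs):

import torch

source_image_grid_thw = torch.tensor([[1, 4, 4], [1, 2, 2], [1, 4, 4], [1, 2, 2]])  # 2 prompts x 2 images
num_condition_images = 2
prior_token_image_ids = torch.arange(source_image_grid_thw.prod(dim=-1).sum().item())

tokens_per_image = source_image_grid_thw.prod(dim=-1).tolist()        # [16, 4, 16, 4]
tokens_per_sample = [sum(tokens_per_image[i:i + num_condition_images])
                     for i in range(0, len(tokens_per_image), num_condition_images)]  # [20, 20]
prior_ids_per_sample = list(torch.split(prior_token_image_ids, tokens_per_sample))
grids_per_sample = list(torch.split(source_image_grid_thw, num_condition_images))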
def get_glyph_texts(self, prompt):
|
||||
"""Extract glyph texts from prompt(s). Returns a list of lists for batch processing."""
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
all_ocr_texts = []
|
||||
for p in prompt:
|
||||
ocr_texts = (
|
||||
re.findall(r"'([^']*)'", p)
|
||||
+ re.findall(r"\u201c([^\u201c\u201d]*)\u201d", p)
|
||||
+ re.findall(r'"([^"]*)"', p)
|
||||
+ re.findall(r"「([^「」]*)」", p)
|
||||
)
|
||||
all_ocr_texts.append(ocr_texts)
|
||||
return all_ocr_texts
|
||||
prompt = prompt[0] if isinstance(prompt, list) else prompt
|
||||
ocr_texts = (
|
||||
re.findall(r"'([^']*)'", prompt)
|
||||
+ re.findall(r"“([^“”]*)”", prompt)
|
||||
+ re.findall(r'"([^"]*)"', prompt)
|
||||
+ re.findall(r"「([^「」]*)」", prompt)
|
||||
)
|
||||
return ocr_texts
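
The quoted-span extraction above boils down to this sketch (`extract_glyph_texts` is a hypothetical helper shown for a single prompt):

import re

def extract_glyph_texts(prompt: str) -> list:
    # Pull text wrapped in single quotes, curly double quotes, straight double quotes,
    # or CJK corner brackets, in that order.
    return (
        re.findall(r"'([^']*)'", prompt)
        + re.findall(r"\u201c([^\u201c\u201d]*)\u201d", prompt)
        + re.findall(r'"([^"]*)"', prompt)
        + re.findall(r"「([^「」]*)」", prompt)
    )

extract_glyph_texts('A poster saying "GRAND OPENING" and 「春」')  # ['GRAND OPENING', '春']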
def _get_glyph_embeds(
|
||||
self,
|
||||
@@ -495,51 +332,29 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
"""Get glyph embeddings for each prompt in the batch."""
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
# get_glyph_texts now returns a list of lists (one per prompt)
|
||||
all_glyph_texts = self.get_glyph_texts(prompt)
|
||||
glyph_texts = self.get_glyph_texts(prompt)
|
||||
input_ids = self.tokenizer(
|
||||
glyph_texts if len(glyph_texts) > 0 else [""],
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
).input_ids
|
||||
input_ids = [
|
||||
[self.tokenizer.pad_token_id] * ((len(input_ids) + 1) % 2) + input_ids_ for input_ids_ in input_ids
|
||||
]
|
||||
max_length = max(len(input_ids_) for input_ids_ in input_ids)
|
||||
attention_mask = torch.tensor(
|
||||
[[1] * len(input_ids_) + [0] * (max_length - len(input_ids_)) for input_ids_ in input_ids], device=device
|
||||
)
|
||||
input_ids = torch.tensor(
|
||||
[input_ids_ + [self.tokenizer.pad_token_id] * (max_length - len(input_ids_)) for input_ids_ in input_ids],
|
||||
device=device,
|
||||
)
|
||||
outputs = self.text_encoder(input_ids, attention_mask=attention_mask)
|
||||
glyph_embeds = outputs.last_hidden_state[attention_mask.bool()].unsqueeze(0)
|
||||
|
||||
all_glyph_embeds = []
|
||||
for glyph_texts in all_glyph_texts:
|
||||
if len(glyph_texts) == 0:
|
||||
glyph_texts = [""]
|
||||
input_ids = self.tokenizer(
|
||||
glyph_texts,
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
).input_ids
|
||||
input_ids = [
|
||||
[self.tokenizer.pad_token_id] * ((len(input_ids) + 1) % 2) + input_ids_ for input_ids_ in input_ids
|
||||
]
|
||||
max_length = max(len(input_ids_) for input_ids_ in input_ids)
|
||||
attention_mask = torch.tensor(
|
||||
[[1] * len(input_ids_) + [0] * (max_length - len(input_ids_)) for input_ids_ in input_ids],
|
||||
device=device,
|
||||
)
|
||||
input_ids = torch.tensor(
|
||||
[
|
||||
input_ids_ + [self.tokenizer.pad_token_id] * (max_length - len(input_ids_))
|
||||
for input_ids_ in input_ids
|
||||
],
|
||||
device=device,
|
||||
)
|
||||
outputs = self.text_encoder(input_ids, attention_mask=attention_mask)
|
||||
glyph_embeds = outputs.last_hidden_state[attention_mask.bool()].unsqueeze(0)
|
||||
all_glyph_embeds.append(glyph_embeds)
|
||||
|
||||
# Pad to same sequence length and stack (use left padding to match transformers)
|
||||
max_seq_len = max(emb.size(1) for emb in all_glyph_embeds)
|
||||
padded_embeds = []
|
||||
for emb in all_glyph_embeds:
|
||||
if emb.size(1) < max_seq_len:
|
||||
pad = torch.zeros(emb.size(0), max_seq_len - emb.size(1), emb.size(2), device=device, dtype=emb.dtype)
|
||||
emb = torch.cat([pad, emb], dim=1) # left padding
|
||||
padded_embeds.append(emb)
|
||||
|
||||
glyph_embeds = torch.cat(padded_embeds, dim=0)
|
||||
return glyph_embeds.to(device=device, dtype=dtype)
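
A minimal sketch of the left-padding step above, assuming two variable-length [1, seq, dim] glyph embeddings (toy shapes):

import torch

embeds = [torch.randn(1, 3, 8), torch.randn(1, 5, 8)]
max_seq_len = max(e.size(1) for e in embeds)
padded = []
for e in embeds:
    if e.size(1) < max_seq_len:
        pad = torch.zeros(e.size(0), max_seq_len - e.size(1), e.size(2), dtype=e.dtype)
        e = torch.cat([pad, e], dim=1)      # left padding, matching tokenizer-side padding
    padded.append(e)
glyph_embeds = torch.cat(padded, dim=0)     # [2, 5, 8]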
def encode_prompt(
|
||||
@@ -584,9 +399,9 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_glyph_embeds(prompt, max_sequence_length, device, dtype)
|
||||
|
||||
# Repeat embeddings for num_images_per_prompt
|
||||
if num_images_per_prompt > 1:
|
||||
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
seq_len = prompt_embeds.size(1)
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For GLM-Image, negative_prompt must be "" instead of None
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
@@ -594,8 +409,9 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
negative_prompt_embeds = self._get_glyph_embeds(negative_prompt, max_sequence_length, device, dtype)
|
||||
|
||||
if num_images_per_prompt > 1:
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
seq_len = negative_prompt_embeds.size(1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
@@ -626,9 +442,7 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prior_token_ids=None,
|
||||
prior_token_image_ids=None,
|
||||
source_image_grid_thw=None,
|
||||
image=None,
|
||||
prior_image_token_ids=None,
|
||||
):
|
||||
if (
|
||||
height is not None
|
||||
@@ -674,24 +488,12 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
# Validate prior token inputs: for i2i mode, all three must be provided together
|
||||
# For t2i mode, only prior_token_ids is needed (prior_token_image_ids and source_image_grid_thw should be None)
|
||||
prior_image_inputs = [prior_token_image_ids, source_image_grid_thw]
|
||||
num_prior_image_inputs = sum(x is not None for x in prior_image_inputs)
|
||||
if num_prior_image_inputs > 0 and num_prior_image_inputs < len(prior_image_inputs):
|
||||
if (prior_token_ids is None and prior_image_token_ids is not None) or (
|
||||
prior_token_ids is not None and prior_image_token_ids is None
|
||||
):
|
||||
raise ValueError(
|
||||
"`prior_token_image_ids` and `source_image_grid_thw` must be provided together for i2i mode. "
|
||||
f"Got prior_token_image_ids={prior_token_image_ids is not None}, "
|
||||
f"source_image_grid_thw={source_image_grid_thw is not None}."
|
||||
)
|
||||
if num_prior_image_inputs > 0 and prior_token_ids is None:
|
||||
raise ValueError(
|
||||
"`prior_token_ids` must be provided when `prior_token_image_ids` and `source_image_grid_thw` are provided."
|
||||
)
|
||||
if num_prior_image_inputs > 0 and image is None:
|
||||
raise ValueError(
|
||||
"`image` must be provided when `prior_token_image_ids` and `source_image_grid_thw` are provided "
|
||||
"for i2i mode, as the images are needed for VAE encoding to build the KV cache."
|
||||
f"Cannot forward only one `prior_token_ids`: {prior_token_ids} or `prior_image_token_ids`:"
|
||||
f" {prior_image_token_ids} provided. Please make sure both are provided or neither."
|
||||
)
|
||||
|
||||
if prior_token_ids is not None and prompt_embeds is None:
|
||||
@@ -743,8 +545,7 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prior_token_ids: Optional[torch.FloatTensor] = None,
|
||||
prior_token_image_ids: Optional[List[torch.Tensor]] = None,
|
||||
source_image_grid_thw: Optional[List[torch.Tensor]] = None,
|
||||
prior_image_token_ids: Optional[torch.Tensor] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
output_type: str = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -797,9 +598,7 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
prior_token_ids,
|
||||
prior_token_image_ids,
|
||||
source_image_grid_thw,
|
||||
image,
|
||||
prior_image_token_ids,
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
@@ -812,47 +611,34 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
if batch_size != 1:
|
||||
raise ValueError(f"batch_size must be 1 due to AR model limitations, got {batch_size}")
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 2. Validate and normalize image format
|
||||
normalized_image = self._validate_and_normalize_images(image, batch_size)
|
||||
|
||||
# 3. Generate prior tokens (batch mode)
|
||||
# Get a single generator for AR model (use first if list provided)
|
||||
ar_generator = generator[0] if isinstance(generator, list) else generator
|
||||
# 2. Preprocess image tokens and prompt tokens
|
||||
if prior_token_ids is None:
|
||||
prior_token_ids, prior_token_image_ids_per_sample, source_image_grid_thw_per_sample = (
|
||||
self.generate_prior_tokens(
|
||||
prompt=prompt,
|
||||
image=normalized_image,
|
||||
height=height,
|
||||
width=width,
|
||||
device=device,
|
||||
generator=ar_generator,
|
||||
)
|
||||
prior_token_ids, prior_token_image_ids = self.generate_prior_tokens(
|
||||
prompt=prompt[0] if isinstance(prompt, list) else prompt,
|
||||
image=image,
|
||||
height=height,
|
||||
width=width,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
# User provided prior_token_ids directly (from generate_prior_tokens)
|
||||
prior_token_image_ids_per_sample = prior_token_image_ids
|
||||
source_image_grid_thw_per_sample = source_image_grid_thw
|
||||
|
||||
# 4. Preprocess images for VAE encoding
|
||||
preprocessed_images = None
|
||||
if normalized_image is not None:
|
||||
preprocessed_images = []
|
||||
for prompt_images in normalized_image:
|
||||
prompt_preprocessed = []
|
||||
for img in prompt_images:
|
||||
image_height, image_width = img.size[::-1] if isinstance(img, PIL.Image.Image) else img.shape[:2]
|
||||
multiple_of = self.vae_scale_factor * self.transformer.config.patch_size
|
||||
image_height = (image_height // multiple_of) * multiple_of
|
||||
image_width = (image_width // multiple_of) * multiple_of
|
||||
img = self.image_processor.preprocess(img, height=image_height, width=image_width)
|
||||
prompt_preprocessed.append(img)
|
||||
height = height or image_height
|
||||
width = width or image_width
|
||||
preprocessed_images.append(prompt_preprocessed)
|
||||
# 3. Preprocess image
|
||||
if image is not None:
|
||||
preprocessed_condition_images = []
|
||||
for img in image:
|
||||
image_height, image_width = img.size[::-1] if isinstance(img, PIL.Image.Image) else img.shape[:2]
|
||||
multiple_of = self.vae_scale_factor * self.transformer.config.patch_size
|
||||
image_height = (image_height // multiple_of) * multiple_of
|
||||
image_width = (image_width // multiple_of) * multiple_of
|
||||
img = self.image_processor.preprocess(img, height=image_height, width=image_width)
|
||||
preprocessed_condition_images.append(img)
|
||||
height = height or image_height
|
||||
width = width or image_width
|
||||
image = preprocessed_condition_images
|
||||
|
||||
# 5. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
@@ -866,7 +652,7 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
# 6. Prepare latents and (optional) image kv cache
|
||||
# 4. Prepare latents and (optional) image kv cache
|
||||
latent_channels = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
@@ -880,7 +666,7 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
)
|
||||
kv_caches = GlmImageKVCache(num_layers=self.transformer.config.num_layers)
|
||||
|
||||
if normalized_image is not None:
|
||||
if image is not None:
|
||||
kv_caches.set_mode("write")
|
||||
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.latent_channels, 1, 1)
|
||||
latents_std = torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.latent_channels, 1, 1)
|
||||
@@ -888,38 +674,29 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
latents_mean = latents_mean.to(device=device, dtype=prompt_embeds.dtype)
|
||||
latents_std = latents_std.to(device=device, dtype=prompt_embeds.dtype)
|
||||
|
||||
# Process each sample's condition images
|
||||
for prompt_idx in range(batch_size):
|
||||
prompt_images = preprocessed_images[prompt_idx]
|
||||
prompt_prior_ids = prior_token_image_ids_per_sample[prompt_idx]
|
||||
prompt_grid_thw = source_image_grid_thw_per_sample[prompt_idx]
|
||||
for condition_image, condition_image_prior_token_id in zip(image, prior_token_image_ids):
|
||||
condition_image = condition_image.to(device=device, dtype=prompt_embeds.dtype)
|
||||
condition_latent = retrieve_latents(
|
||||
self.vae.encode(condition_image), generator=generator, sample_mode="argmax"
|
||||
)
|
||||
condition_latent = (condition_latent - latents_mean) / latents_std
|
||||
|
||||
# Split this sample's prior_token_image_ids by each image's token count
|
||||
split_sizes = prompt_grid_thw.prod(dim=-1).tolist()
|
||||
prior_ids_per_image = torch.split(prompt_prior_ids, split_sizes)
|
||||
# Process each condition image for this sample
|
||||
for condition_image, condition_image_prior_token_id in zip(prompt_images, prior_ids_per_image):
|
||||
condition_image = condition_image.to(device=device, dtype=prompt_embeds.dtype)
|
||||
condition_latent = retrieve_latents(
|
||||
self.vae.encode(condition_image), generator=generator, sample_mode="argmax"
|
||||
)
|
||||
condition_latent = (condition_latent - latents_mean) / latents_std
|
||||
# Do not remove.
|
||||
# It is used to run the reference image through a
|
||||
# forward pass at timestep 0 and keep the KV cache.
|
||||
_ = self.transformer(
|
||||
hidden_states=condition_latent,
|
||||
encoder_hidden_states=torch.zeros_like(prompt_embeds)[:1, :0, ...],
|
||||
prior_token_id=condition_image_prior_token_id,
|
||||
prior_token_drop=torch.full_like(condition_image_prior_token_id, False, dtype=torch.bool),
|
||||
timestep=torch.zeros((1,), device=device),
|
||||
target_size=torch.tensor([condition_image.shape[-2:]], device=device),
|
||||
crop_coords=torch.zeros((1, 2), device=device),
|
||||
attention_kwargs=attention_kwargs,
|
||||
kv_caches=kv_caches,
|
||||
)
|
||||
|
||||
_ = self.transformer(
|
||||
hidden_states=condition_latent,
|
||||
encoder_hidden_states=torch.zeros_like(prompt_embeds)[:1, :0, ...],
|
||||
prior_token_id=condition_image_prior_token_id,
|
||||
prior_token_drop=torch.full_like(condition_image_prior_token_id, False, dtype=torch.bool),
|
||||
timestep=torch.zeros((1,), device=device),
|
||||
target_size=torch.tensor([condition_image.shape[-2:]], device=device),
|
||||
crop_coords=torch.zeros((1, 2), device=device),
|
||||
attention_kwargs=attention_kwargs,
|
||||
kv_caches=kv_caches,
|
||||
)
|
||||
# Move to next sample's cache slot
|
||||
kv_caches.next_sample()
|
||||
|
||||
# 7. Prepare additional timestep conditions
|
||||
# 6. Prepare additional timestep conditions
|
||||
target_size = (height, width)
|
||||
target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device)
|
||||
crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device)
|
||||
@@ -949,13 +726,10 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 8. Denoising loop
|
||||
# 7. Denoising loop
|
||||
transformer_dtype = self.transformer.dtype
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
# Repeat prior_token_ids for num_images_per_prompt
|
||||
if num_images_per_prompt > 1:
|
||||
prior_token_ids = prior_token_ids.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
prior_token_drop_cond = torch.full_like(prior_token_ids, False, dtype=torch.bool)
|
||||
prior_token_drop_uncond = torch.full_like(prior_token_ids, True, dtype=torch.bool)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
@@ -968,7 +742,7 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
|
||||
timestep = t.expand(latents.shape[0]) - 1
|
||||
|
||||
if prior_token_image_ids_per_sample is not None:
|
||||
if image is not None:
|
||||
kv_caches.set_mode("read")
|
||||
|
||||
noise_pred_cond = self.transformer(
|
||||
@@ -986,7 +760,7 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
if prior_token_image_ids_per_sample is not None:
|
||||
if image is not None:
|
||||
kv_caches.set_mode("skip")
|
||||
noise_pred_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
|
||||
@@ -262,9 +262,6 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
|
||||
prompt_embeds_mask = None
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
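
A tiny sketch of the simplification above: an all-True prompt mask carries no information, so it is dropped and downstream attention can skip masking entirely:

import torch

prompt_embeds_mask = torch.ones(2, 77, dtype=torch.bool)   # every token is valid
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
    prompt_embeds_mask = None                              # no-op mask -> treat as unmasked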
def check_inputs(
|
||||
|
||||
@@ -324,9 +324,6 @@ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
|
||||
prompt_embeds_mask = None
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
def check_inputs(
|
||||
|
||||
@@ -305,9 +305,6 @@ class QwenImageControlNetInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderM
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
|
||||
prompt_embeds_mask = None
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
def check_inputs(
|
||||
|
||||
@@ -309,9 +309,6 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
|
||||
prompt_embeds_mask = None
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
def check_inputs(
|
||||
|
||||
@@ -321,9 +321,6 @@ class QwenImageEditInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
|
||||
prompt_embeds_mask = None
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_inpaint.QwenImageInpaintPipeline.check_inputs
|
||||
|
||||
@@ -323,9 +323,6 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
|
||||
prompt_embeds_mask = None
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.check_inputs
|
||||
|
||||
@@ -305,9 +305,6 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
|
||||
prompt_embeds_mask = None
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
def check_inputs(
|
||||
|
||||
@@ -316,9 +316,6 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
|
||||
prompt_embeds_mask = None
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
def check_inputs(
|
||||
|
||||
@@ -328,9 +328,6 @@ the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>as
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
|
||||
prompt_embeds_mask = None
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
def get_image_caption(self, prompt_image, use_en_prompt=True, device=None):
|
||||
|
||||
@@ -281,7 +281,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
Args:
|
||||
num_inference_steps (`int`, *optional*):
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used,
|
||||
`timesteps` must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
@@ -646,7 +646,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
def __len__(self) -> int:
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
def previous_timestep(self, timestep: int) -> Union[int, torch.Tensor]:
|
||||
def previous_timestep(self, timestep: int) -> int:
|
||||
"""
|
||||
Compute the previous timestep in the diffusion chain.
|
||||
|
||||
@@ -655,7 +655,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
The current timestep.
|
||||
|
||||
Returns:
|
||||
`int` or `torch.Tensor`:
|
||||
`int`:
|
||||
The previous timestep.
|
||||
"""
|
||||
if self.custom_timesteps or self.num_inference_steps:
|
||||
|
||||
@@ -149,41 +149,38 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
For more details, see the original paper: https://huggingface.co/papers/2006.11239
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`, defaults to 1000):
|
||||
The number of diffusion steps to train the model.
|
||||
beta_start (`float`, defaults to 0.0001):
|
||||
The starting `beta` value of inference.
|
||||
beta_end (`float`, defaults to 0.02):
|
||||
The final `beta` value.
|
||||
beta_schedule (`str`, defaults to `"linear"`):
|
||||
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
||||
num_train_timesteps (`int`): number of diffusion steps used to train the model.
|
||||
beta_start (`float`): the starting `beta` value of inference.
|
||||
beta_end (`float`): the final `beta` value.
|
||||
beta_schedule (`str`):
|
||||
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
||||
`linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`.
|
||||
trained_betas (`np.ndarray`, *optional*):
|
||||
Option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
||||
variance_type (`str`, defaults to `"fixed_small"`):
|
||||
Options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
|
||||
trained_betas (`np.ndarray`, optional):
|
||||
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
||||
variance_type (`str`):
|
||||
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
|
||||
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
|
||||
clip_sample (`bool`, defaults to `True`):
|
||||
Option to clip predicted sample for numerical stability.
|
||||
prediction_type (`str`, defaults to `"epsilon"`):
|
||||
Prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
|
||||
clip_sample (`bool`, default `True`):
|
||||
option to clip predicted sample for numerical stability.
|
||||
clip_sample_range (`float`, default `1.0`):
|
||||
the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
|
||||
prediction_type (`str`, default `epsilon`, optional):
|
||||
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
|
||||
process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
|
||||
https://huggingface.co/papers/2210.02303)
|
||||
thresholding (`bool`, defaults to `False`):
|
||||
Whether to use the "dynamic thresholding" method (introduced by Imagen,
|
||||
thresholding (`bool`, default `False`):
|
||||
whether to use the "dynamic thresholding" method (introduced by Imagen,
|
||||
https://huggingface.co/papers/2205.11487). Note that the thresholding method is unsuitable for latent-space
|
||||
diffusion models (such as stable-diffusion).
|
||||
dynamic_thresholding_ratio (`float`, defaults to 0.995):
|
||||
The ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
|
||||
dynamic_thresholding_ratio (`float`, default `0.995`):
|
||||
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
|
||||
(https://huggingface.co/papers/2205.11487). Valid only when `thresholding=True`.
|
||||
clip_sample_range (`float`, defaults to 1.0):
|
||||
The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
|
||||
sample_max_value (`float`, defaults to 1.0):
|
||||
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
|
||||
timestep_spacing (`str`, defaults to `"leading"`):
|
||||
sample_max_value (`float`, default `1.0`):
|
||||
the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
|
||||
timestep_spacing (`str`, default `"leading"`):
|
||||
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
|
||||
Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
||||
steps_offset (`int`, defaults to 0):
|
||||
steps_offset (`int`, default `0`):
|
||||
An offset added to the inference steps, as required by some model families.
|
||||
rescale_betas_zero_snr (`bool`, defaults to `False`):
|
||||
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
||||
@@ -296,7 +293,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
Args:
|
||||
num_inference_steps (`int`, *optional*):
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used,
|
||||
`timesteps` must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
@@ -481,7 +478,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
model_output: torch.Tensor,
|
||||
timestep: int,
|
||||
sample: torch.Tensor,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator=None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[DDPMParallelSchedulerOutput, Tuple]:
|
||||
"""
|
||||
@@ -493,8 +490,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
timestep (`int`): current discrete timestep in the diffusion chain.
|
||||
sample (`torch.Tensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
Random number generator.
|
||||
generator: random number generator.
|
||||
return_dict (`bool`): option for returning tuple rather than DDPMParallelSchedulerOutput class
|
||||
|
||||
Returns:
|
||||
@@ -507,10 +503,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
prev_t = self.previous_timestep(t)
|
||||
|
||||
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [
|
||||
"learned",
|
||||
"learned_range",
|
||||
]:
|
||||
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
|
||||
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
|
||||
else:
|
||||
predicted_variance = None
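
A minimal sketch of the channel split above, assuming a model that predicts a learned variance (its output has twice the sample's channels):

import torch

sample = torch.randn(2, 4, 32, 32)
model_output = torch.randn(2, 8, 32, 32)                   # mean prediction + variance logits
if model_output.shape[1] == sample.shape[1] * 2:
    model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
else:
    predicted_variance = None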
|
||||
@@ -559,10 +552,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
if t > 0:
|
||||
device = model_output.device
|
||||
variance_noise = randn_tensor(
|
||||
model_output.shape,
|
||||
generator=generator,
|
||||
device=device,
|
||||
dtype=model_output.dtype,
|
||||
model_output.shape, generator=generator, device=device, dtype=model_output.dtype
|
||||
)
|
||||
if self.variance_type == "fixed_small_log":
|
||||
variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
|
||||
@@ -585,7 +575,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
def batch_step_no_noise(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
timesteps: List[int],
|
||||
sample: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -598,8 +588,8 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
Args:
|
||||
model_output (`torch.Tensor`): direct output from learned diffusion model.
|
||||
timesteps (`torch.Tensor`):
|
||||
Current discrete timesteps in the diffusion chain. This is a tensor of integers.
|
||||
timesteps (`List[int]`):
|
||||
current discrete timesteps in the diffusion chain. This is a list of integers.
|
||||
sample (`torch.Tensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
|
||||
@@ -613,10 +603,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
t = t.view(-1, *([1] * (model_output.ndim - 1)))
|
||||
prev_t = prev_t.view(-1, *([1] * (model_output.ndim - 1)))
|
||||
|
||||
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [
|
||||
"learned",
|
||||
"learned_range",
|
||||
]:
|
||||
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
|
||||
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
|
||||
else:
|
||||
pass
|
||||
@@ -747,7 +734,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
|
||||
def previous_timestep(self, timestep: int) -> Union[int, torch.Tensor]:
|
||||
def previous_timestep(self, timestep):
|
||||
"""
|
||||
Compute the previous timestep in the diffusion chain.
|
||||
|
||||
@@ -756,7 +743,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
The current timestep.
|
||||
|
||||
Returns:
|
||||
`int` or `torch.Tensor`:
|
||||
`int`:
|
||||
The previous timestep.
|
||||
"""
|
||||
if self.custom_timesteps or self.num_inference_steps:
|
||||
|
||||
@@ -22,7 +22,6 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import logging
|
||||
from .scheduling_utils_flax import (
|
||||
CommonSchedulerState,
|
||||
FlaxKarrasDiffusionSchedulers,
|
||||
@@ -32,9 +31,6 @@ from .scheduling_utils_flax import (
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class DPMSolverMultistepSchedulerState:
|
||||
common: CommonSchedulerState
|
||||
@@ -175,10 +171,6 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
timestep_spacing: str = "linspace",
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
):
|
||||
logger.warning(
|
||||
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
|
||||
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
|
||||
)
|
||||
self.dtype = dtype
|
||||
|
||||
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolverMultistepSchedulerState:
|
||||
@@ -211,10 +203,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
def set_timesteps(
|
||||
self,
|
||||
state: DPMSolverMultistepSchedulerState,
|
||||
num_inference_steps: int,
|
||||
shape: Tuple,
|
||||
self, state: DPMSolverMultistepSchedulerState, num_inference_steps: int, shape: Tuple
|
||||
) -> DPMSolverMultistepSchedulerState:
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
@@ -312,13 +301,10 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
if self.config.thresholding:
|
||||
# Dynamic thresholding in https://huggingface.co/papers/2205.11487
|
||||
dynamic_max_val = jnp.percentile(
|
||||
jnp.abs(x0_pred),
|
||||
self.config.dynamic_thresholding_ratio,
|
||||
axis=tuple(range(1, x0_pred.ndim)),
|
||||
jnp.abs(x0_pred), self.config.dynamic_thresholding_ratio, axis=tuple(range(1, x0_pred.ndim))
|
||||
)
|
||||
dynamic_max_val = jnp.maximum(
|
||||
dynamic_max_val,
|
||||
self.config.sample_max_value * jnp.ones_like(dynamic_max_val),
|
||||
dynamic_max_val, self.config.sample_max_value * jnp.ones_like(dynamic_max_val)
|
||||
)
|
||||
x0_pred = jnp.clip(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||
return x0_pred
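
A standalone sketch of the dynamic thresholding idea above (Imagen, https://huggingface.co/papers/2205.11487); note that the percentile here is expressed in percent, which may differ from how the surrounding scheduler config stores the ratio:

import jax.numpy as jnp

def dynamic_threshold(x0_pred, ratio=0.995, sample_max_value=1.0):
    # Per-sample: clip to the `ratio` quantile of absolute values, never below
    # sample_max_value, then rescale the result to [-1, 1].
    s = jnp.percentile(jnp.abs(x0_pred), ratio * 100, axis=tuple(range(1, x0_pred.ndim)))
    s = jnp.maximum(s, sample_max_value)
    s = s.reshape((-1,) + (1,) * (x0_pred.ndim - 1))
    return jnp.clip(x0_pred, -s, s) / s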
@@ -399,11 +385,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
|
||||
m0, m1 = model_output_list[-1], model_output_list[-2]
|
||||
lambda_t, lambda_s0, lambda_s1 = (
|
||||
state.lambda_t[t],
|
||||
state.lambda_t[s0],
|
||||
state.lambda_t[s1],
|
||||
)
|
||||
lambda_t, lambda_s0, lambda_s1 = state.lambda_t[t], state.lambda_t[s0], state.lambda_t[s1]
|
||||
alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0]
|
||||
sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0]
|
||||
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
|
||||
@@ -461,12 +443,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
Returns:
|
||||
`jnp.ndarray`: the sample tensor at the previous timestep.
|
||||
"""
|
||||
t, s0, s1, s2 = (
|
||||
prev_timestep,
|
||||
timestep_list[-1],
|
||||
timestep_list[-2],
|
||||
timestep_list[-3],
|
||||
)
|
||||
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
|
||||
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
|
||||
lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
|
||||
state.lambda_t[t],
|
||||
@@ -638,10 +615,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
return FlaxDPMSolverMultistepSchedulerOutput(prev_sample=prev_sample, state=state)
|
||||
|
||||
def scale_model_input(
|
||||
self,
|
||||
state: DPMSolverMultistepSchedulerState,
|
||||
sample: jnp.ndarray,
|
||||
timestep: Optional[int] = None,
|
||||
self, state: DPMSolverMultistepSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
|
||||
) -> jnp.ndarray:
|
||||
"""
|
||||
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
||||
|
||||
@@ -19,7 +19,6 @@ import flax
import jax.numpy as jnp

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
@@ -29,9 +28,6 @@ from .scheduling_utils_flax import (
)


logger = logging.get_logger(__name__)


@flax.struct.dataclass
class EulerDiscreteSchedulerState:
common: CommonSchedulerState
@@ -44,18 +40,9 @@ class EulerDiscreteSchedulerState:

@classmethod
def create(
cls,
common: CommonSchedulerState,
init_noise_sigma: jnp.ndarray,
timesteps: jnp.ndarray,
sigmas: jnp.ndarray,
cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray
):
return cls(
common=common,
init_noise_sigma=init_noise_sigma,
timesteps=timesteps,
sigmas=sigmas,
)
return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas)


@dataclass
@@ -112,10 +99,6 @@ class FlaxEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
timestep_spacing: str = "linspace",
dtype: jnp.dtype = jnp.float32,
):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.dtype = dtype

def create_state(self, common: Optional[CommonSchedulerState] = None) -> EulerDiscreteSchedulerState:
@@ -163,10 +146,7 @@ class FlaxEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
return sample

def set_timesteps(
self,
state: EulerDiscreteSchedulerState,
num_inference_steps: int,
shape: Tuple = (),
self, state: EulerDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = ()
) -> EulerDiscreteSchedulerState:
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -179,12 +159,7 @@ class FlaxEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
"""

if self.config.timestep_spacing == "linspace":
timesteps = jnp.linspace(
self.config.num_train_timesteps - 1,
0,
num_inference_steps,
dtype=self.dtype,
)
timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype)
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // num_inference_steps
timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)

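To make the two spacing branches above concrete, here is what they produce for illustrative values (numbers chosen for the example, not taken from the diff):

import jax.numpy as jnp

num_train_timesteps, num_inference_steps = 1000, 4
# "linspace": evenly spaced floats from num_train_timesteps - 1 down to 0
jnp.linspace(num_train_timesteps - 1, 0, num_inference_steps)  # [999., 666., 333., 0.]
# "leading": integer multiples of the step ratio, reversed
step_ratio = num_train_timesteps // num_inference_steps
(jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(float)  # [750., 500., 250., 0.]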
@@ -22,13 +22,10 @@ import jax.numpy as jnp
from jax import random

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging
from ..utils import BaseOutput
from .scheduling_utils_flax import FlaxSchedulerMixin


logger = logging.get_logger(__name__)


@flax.struct.dataclass
class KarrasVeSchedulerState:
# setable values
@@ -105,10 +102,7 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
s_min: float = 0.05,
s_max: float = 50,
):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
pass

def create_state(self):
return KarrasVeSchedulerState.create()

@@ -722,7 +722,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
The current timestep.

Returns:
`int` or `torch.Tensor`:
`int`:
The previous timestep.
"""
if self.custom_timesteps or self.num_inference_steps:

@@ -20,7 +20,6 @@ import jax.numpy as jnp
from scipy import integrate

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
@@ -30,9 +29,6 @@ from .scheduling_utils_flax import (
)


logger = logging.get_logger(__name__)


@flax.struct.dataclass
class LMSDiscreteSchedulerState:
common: CommonSchedulerState
@@ -48,18 +44,9 @@ class LMSDiscreteSchedulerState:

@classmethod
def create(
cls,
common: CommonSchedulerState,
init_noise_sigma: jnp.ndarray,
timesteps: jnp.ndarray,
sigmas: jnp.ndarray,
cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray
):
return cls(
common=common,
init_noise_sigma=init_noise_sigma,
timesteps=timesteps,
sigmas=sigmas,
)
return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas)


@dataclass
@@ -114,10 +101,6 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
prediction_type: str = "epsilon",
dtype: jnp.dtype = jnp.float32,
):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.dtype = dtype

def create_state(self, common: Optional[CommonSchedulerState] = None) -> LMSDiscreteSchedulerState:
@@ -182,10 +165,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
return integrated_coeff

def set_timesteps(
self,
state: LMSDiscreteSchedulerState,
num_inference_steps: int,
shape: Tuple = (),
self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = ()
) -> LMSDiscreteSchedulerState:
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -197,12 +177,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
the number of diffusion steps used when generating samples with a pre-trained model.
"""

timesteps = jnp.linspace(
self.config.num_train_timesteps - 1,
0,
num_inference_steps,
dtype=self.dtype,
)
timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype)

low_idx = jnp.floor(timesteps).astype(jnp.int32)
high_idx = jnp.ceil(timesteps).astype(jnp.int32)

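The floor/ceil indices above exist so that fractional timesteps can be mapped onto the sigma table by linear interpolation between the two neighbouring entries. A self-contained sketch of that idea (the numbers are made up for illustration):

import jax.numpy as jnp

timesteps = jnp.array([999.0, 666.33, 333.66, 0.0])  # example fractional timesteps
sigmas = jnp.linspace(0.1, 10.0, 1000)               # example sigma table
low_idx = jnp.floor(timesteps).astype(jnp.int32)
high_idx = jnp.ceil(timesteps).astype(jnp.int32)
frac = jnp.mod(timesteps, 1.0)
interp = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]  # linear interpolation between neighbours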
@@ -22,7 +22,6 @@ import jax
import jax.numpy as jnp

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
@@ -32,9 +31,6 @@ from .scheduling_utils_flax import (
)


logger = logging.get_logger(__name__)


@flax.struct.dataclass
class PNDMSchedulerState:
common: CommonSchedulerState
@@ -135,10 +131,6 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
prediction_type: str = "epsilon",
dtype: jnp.dtype = jnp.float32,
):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.dtype = dtype

# For now we only support F-PNDM, i.e. the runge-kutta method
@@ -198,10 +190,7 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):

else:
prk_timesteps = _timesteps[-self.pndm_order :].repeat(2) + jnp.tile(
jnp.array(
[0, self.config.num_train_timesteps // num_inference_steps // 2],
dtype=jnp.int32,
),
jnp.array([0, self.config.num_train_timesteps // num_inference_steps // 2], dtype=jnp.int32),
self.pndm_order,
)

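A small worked example of the PRK timestep arithmetic above (all numbers hypothetical; assuming pndm_order = 4 and num_train_timesteps // num_inference_steps = 250):

import jax.numpy as jnp

_timesteps = jnp.array([0, 250, 500, 750])
offset = jnp.array([0, 250 // 2], dtype=jnp.int32)              # [0, 125]
prk_timesteps = _timesteps[-4:].repeat(2) + jnp.tile(offset, 4)
# repeat(2) -> [  0,   0, 250, 250, 500, 500, 750, 750]
# + tile    -> [  0, 125, 250, 375, 500, 625, 750, 875]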
@@ -229,10 +218,7 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
)

def scale_model_input(
self,
state: PNDMSchedulerState,
sample: jnp.ndarray,
timestep: Optional[int] = None,
self, state: PNDMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
) -> jnp.ndarray:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
@@ -334,9 +320,7 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
)

diff_to_prev = jnp.where(
state.counter % 2,
0,
self.config.num_train_timesteps // state.num_inference_steps // 2,
state.counter % 2, 0, self.config.num_train_timesteps // state.num_inference_steps // 2
)
prev_timestep = timestep - diff_to_prev
timestep = state.prk_timesteps[state.counter // 4 * 4]
@@ -417,9 +401,7 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):

prev_timestep = jnp.where(state.counter == 1, timestep, prev_timestep)
timestep = jnp.where(
state.counter == 1,
timestep + self.config.num_train_timesteps // state.num_inference_steps,
timestep,
state.counter == 1, timestep + self.config.num_train_timesteps // state.num_inference_steps, timestep
)

# Reference:
@@ -484,9 +466,7 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
# prev_sample -> x_(t−δ)
alpha_prod_t = state.common.alphas_cumprod[timestep]
alpha_prod_t_prev = jnp.where(
prev_timestep >= 0,
state.common.alphas_cumprod[prev_timestep],
state.final_alpha_cumprod,
prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod
)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev

@@ -23,15 +23,7 @@ import jax.numpy as jnp
from jax import random

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .scheduling_utils_flax import (
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
)


logger = logging.get_logger(__name__)
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


@flax.struct.dataclass
@@ -103,10 +95,7 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
sampling_eps: float = 1e-5,
correct_steps: int = 1,
):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
pass

def create_state(self):
state = ScoreSdeVeSchedulerState.create()
@@ -119,11 +108,7 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
)

def set_timesteps(
self,
state: ScoreSdeVeSchedulerState,
num_inference_steps: int,
shape: Tuple = (),
sampling_eps: float = None,
self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple = (), sampling_eps: float = None
) -> ScoreSdeVeSchedulerState:
"""
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.

@@ -777,7 +777,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
The current timestep.

Returns:
`int` or `torch.Tensor`:
`int`:
The previous timestep.
"""
if self.custom_timesteps or self.num_inference_steps:

@@ -23,6 +23,7 @@ from .constants import (
DEFAULT_HF_PARALLEL_LOADING_WORKERS,
DEPRECATED_REVISION_ARGS,
DIFFUSERS_DYNAMIC_MODULE_NAME,
DIFFUSERS_LOAD_ID_FIELDS,
FLAX_WEIGHTS_NAME,
GGUF_FILE_EXTENSION,
HF_ENABLE_PARALLEL_LOADING,

@@ -73,3 +73,11 @@ DECODE_ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoint
ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/"
ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/"
ENCODE_ENDPOINT_FLUX = "https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/"


DIFFUSERS_LOAD_ID_FIELDS = [
"pretrained_model_name_or_path",
"subfolder",
"variant",
"revision",
]

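DIFFUSERS_LOAD_ID_FIELDS is simply the list of argument names that identify where a component was loaded from. A hypothetical illustration of how such a field list can be used to filter a kwargs dict (the helper name and the inputs are made up for the sketch):

DIFFUSERS_LOAD_ID_FIELDS = ["pretrained_model_name_or_path", "subfolder", "variant", "revision"]

def collect_load_id(kwargs):
    # keep only the identifying fields that were actually provided
    return {k: kwargs[k] for k in DIFFUSERS_LOAD_ID_FIELDS if kwargs.get(k) is not None}

collect_load_id({"pretrained_model_name_or_path": "org/repo", "subfolder": "unet", "revision": None})
# -> {"pretrained_model_name_or_path": "org/repo", "subfolder": "unet"}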
@@ -17,51 +17,6 @@ class Flux2AutoBlocks(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])


class Flux2KleinAutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])


class Flux2KleinBaseAutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])


class Flux2KleinModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])


class Flux2ModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]


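The classes above are dummy placeholders: they exist even when torch/transformers are not installed, and any use raises an import error naming the missing backends. A simplified, standalone sketch of that pattern (not the exact diffusers implementation, and the example class name is made up):

class DummyObject(type):
    # Metaclass so that even classmethod access on the placeholder fails loudly.
    def __getattr__(cls, key):
        raise ImportError(f"{cls.__name__} requires the backends {cls._backends}, which are not installed.")

class ExampleBlocks(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        raise ImportError(f"{type(self).__name__} requires the backends {self._backends}.")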
@@ -276,74 +276,3 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCas

def test_torch_compile_recompilation_and_graph_break(self):
super().test_torch_compile_recompilation_and_graph_break()

def test_torch_compile_with_and_without_mask(self):
"""Test that torch.compile works with both None mask and padding mask."""
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model.compile(mode="default", fullgraph=True)

# Test 1: Run with None mask (no padding, all tokens are valid)
inputs_no_mask = inputs.copy()
inputs_no_mask["encoder_hidden_states_mask"] = None

# First run to allow compilation
with torch.no_grad():
output_no_mask = model(**inputs_no_mask)

# Second run to verify no recompilation
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
torch.no_grad(),
):
output_no_mask_2 = model(**inputs_no_mask)

self.assertEqual(output_no_mask.sample.shape[1], inputs["hidden_states"].shape[1])
self.assertEqual(output_no_mask_2.sample.shape[1], inputs["hidden_states"].shape[1])

# Test 2: Run with all-ones mask (should behave like None)
inputs_all_ones = inputs.copy()
# Keep the all-ones mask
self.assertTrue(inputs_all_ones["encoder_hidden_states_mask"].all().item())

# First run to allow compilation
with torch.no_grad():
output_all_ones = model(**inputs_all_ones)

# Second run to verify no recompilation
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
torch.no_grad(),
):
output_all_ones_2 = model(**inputs_all_ones)

self.assertEqual(output_all_ones.sample.shape[1], inputs["hidden_states"].shape[1])
self.assertEqual(output_all_ones_2.sample.shape[1], inputs["hidden_states"].shape[1])

# Test 3: Run with actual padding mask (has zeros)
inputs_with_padding = inputs.copy()
mask_with_padding = inputs["encoder_hidden_states_mask"].clone()
mask_with_padding[:, 4:] = 0 # Last 3 tokens are padding

inputs_with_padding["encoder_hidden_states_mask"] = mask_with_padding

# First run to allow compilation
with torch.no_grad():
output_with_padding = model(**inputs_with_padding)

# Second run to verify no recompilation
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
torch.no_grad(),
):
output_with_padding_2 = model(**inputs_with_padding)

self.assertEqual(output_with_padding.sample.shape[1], inputs["hidden_states"].shape[1])
self.assertEqual(output_with_padding_2.sample.shape[1], inputs["hidden_states"].shape[1])

# Verify that outputs are different (mask should affect results)
self.assertFalse(torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3))

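The test code above relies on a reusable compile-guard pattern: compile and run once, then run again under error_on_recompile to prove the second call reuses the compiled graph. A minimal sketch of that pattern on a toy module (the module and shapes are made up):

import torch

model = torch.compile(torch.nn.Linear(8, 8).eval(), fullgraph=True)
x = torch.randn(2, 8)

with torch.no_grad():
    model(x)  # first call triggers compilation

with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
    model(x)  # raises if this call would trigger a recompile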
@@ -1,91 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

import numpy as np
import PIL
import pytest

from diffusers.modular_pipelines import (
Flux2KleinAutoBlocks,
Flux2KleinModularPipeline,
)

from ...testing_utils import floats_tensor, torch_device
from ..test_modular_pipelines_common import ModularPipelineTesterMixin


class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = Flux2KleinModularPipeline
pipeline_blocks_class = Flux2KleinAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-modular"

params = frozenset(["prompt", "height", "width"])
batch_params = frozenset(["prompt"])

def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
# TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
"max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch
"text_encoder_out_layers": (1,),
"generator": generator,
"num_inference_steps": 2,
"height": 32,
"width": 32,
"output_type": "pt",
}
return inputs

def test_float16_inference(self):
super().test_float16_inference(9e-2)


class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = Flux2KleinModularPipeline
pipeline_blocks_class = Flux2KleinAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-modular"

params = frozenset(["prompt", "height", "width", "image"])
batch_params = frozenset(["prompt", "image"])

def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
# TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
"max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch
"text_encoder_out_layers": (1,),
"generator": generator,
"num_inference_steps": 2,
"height": 32,
"width": 32,
"output_type": "pt",
}
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(torch_device)
image = image.cpu().permute(0, 2, 3, 1)[0]
init_image = PIL.Image.fromarray(np.uint8(image * 255)).convert("RGB")
inputs["image"] = init_image

return inputs

def test_float16_inference(self):
super().test_float16_inference(9e-2)

@pytest.mark.skip(reason="batched inference is currently not supported")
def test_inference_batch_single_identical(self, batch_size=2, expected_max_diff=0.0001):
return
@@ -1,91 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

import numpy as np
import PIL
import pytest

from diffusers.modular_pipelines import (
Flux2KleinBaseAutoBlocks,
Flux2KleinModularPipeline,
)

from ...testing_utils import floats_tensor, torch_device
from ..test_modular_pipelines_common import ModularPipelineTesterMixin


class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = Flux2KleinModularPipeline
pipeline_blocks_class = Flux2KleinBaseAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-base-modular"

params = frozenset(["prompt", "height", "width"])
batch_params = frozenset(["prompt"])

def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
# TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
"max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch
"text_encoder_out_layers": (1,),
"generator": generator,
"num_inference_steps": 2,
"height": 32,
"width": 32,
"output_type": "pt",
}
return inputs

def test_float16_inference(self):
super().test_float16_inference(9e-2)


class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = Flux2KleinModularPipeline
pipeline_blocks_class = Flux2KleinBaseAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-base-modular"

params = frozenset(["prompt", "height", "width", "image"])
batch_params = frozenset(["prompt", "image"])

def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
# TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
"max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch
"text_encoder_out_layers": (1,),
"generator": generator,
"num_inference_steps": 2,
"height": 32,
"width": 32,
"output_type": "pt",
}
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(torch_device)
image = image.cpu().permute(0, 2, 3, 1)[0]
init_image = PIL.Image.fromarray(np.uint8(image * 255)).convert("RGB")
inputs["image"] = init_image

return inputs

def test_float16_inference(self):
super().test_float16_inference(9e-2)

@pytest.mark.skip(reason="batched inference is currently not supported")
def test_inference_batch_single_identical(self, batch_size=2, expected_max_diff=0.0001):
return
@@ -169,7 +169,7 @@ class GlmImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
# fmt: off
expected_slice = np.array(
[
0.5849247, 0.50278825, 0.45747858, 0.45895284, 0.43804976, 0.47044256, 0.5239665, 0.47904694, 0.3323419, 0.38725388, 0.28505728, 0.3161863, 0.35026982, 0.37546024, 0.4090118, 0.46629113
0.5796329, 0.5005878, 0.45881274, 0.45331675, 0.43688118, 0.4899527, 0.54017603, 0.50983673, 0.3387968, 0.38074082, 0.29942477, 0.33733928, 0.3672544, 0.38462338, 0.40991822, 0.46641728
]
)
# fmt: on
@@ -177,109 +177,20 @@ class GlmImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
self.assertEqual(image.shape, (3, 32, 32))
self.assertTrue(np.allclose(expected_slice, generated_slice, atol=1e-4, rtol=1e-4))

@unittest.skip("Not supported.")
def test_inference_batch_single_identical(self):
"""Test that batch=1 produces consistent results with the same seed."""
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
# GLM-Image has batch_size=1 constraint due to AR model
pass

# Run twice with same seed
inputs1 = self.get_dummy_inputs(device, seed=42)
inputs2 = self.get_dummy_inputs(device, seed=42)

image1 = pipe(**inputs1).images[0]
image2 = pipe(**inputs2).images[0]

self.assertTrue(torch.allclose(image1, image2, atol=1e-4))

def test_inference_batch_multiple_prompts(self):
"""Test batch processing with multiple prompts."""
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)

generator = torch.Generator(device=device).manual_seed(42)
height, width = 32, 32

inputs = {
"prompt": ["A photo of a cat", "A photo of a dog"],
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 1.5,
"height": height,
"width": width,
"max_sequence_length": 16,
"output_type": "pt",
}

images = pipe(**inputs).images

# Should return 2 images
self.assertEqual(len(images), 2)
self.assertEqual(images[0].shape, (3, 32, 32))
self.assertEqual(images[1].shape, (3, 32, 32))
@unittest.skip("Not supported.")
def test_inference_batch_consistent(self):
# GLM-Image has batch_size=1 constraint due to AR model
pass

@unittest.skip("Not supported.")
def test_num_images_per_prompt(self):
"""Test generating multiple images per prompt."""
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)

generator = torch.Generator(device=device).manual_seed(42)
height, width = 32, 32

inputs = {
"prompt": "A photo of a cat",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 1.5,
"height": height,
"width": width,
"max_sequence_length": 16,
"output_type": "pt",
"num_images_per_prompt": 2,
}

images = pipe(**inputs).images

# Should return 2 images for single prompt
self.assertEqual(len(images), 2)
self.assertEqual(images[0].shape, (3, 32, 32))
self.assertEqual(images[1].shape, (3, 32, 32))

def test_batch_with_num_images_per_prompt(self):
"""Test batch prompts with num_images_per_prompt > 1."""
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)

generator = torch.Generator(device=device).manual_seed(42)
height, width = 32, 32

inputs = {
"prompt": ["A photo of a cat", "A photo of a dog"],
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 1.5,
"height": height,
"width": width,
"max_sequence_length": 16,
"output_type": "pt",
"num_images_per_prompt": 2,
}

images = pipe(**inputs).images

# Should return 4 images (2 prompts × 2 images per prompt)
self.assertEqual(len(images), 4)
# GLM-Image has batch_size=1 constraint due to AR model
pass

@unittest.skip("Needs to be revisited.")
def test_encode_prompt_works_in_isolation(self):

@@ -1,352 +0,0 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Auto Docstring Generator for Modular Pipeline Blocks

This script scans Python files for classes that have `# auto_docstring` comment above them
and inserts/updates the docstring from the class's `doc` property.

Run from the root of the repo:
python utils/modular_auto_docstring.py [path] [--fix_and_overwrite]

Examples:
# Check for auto_docstring markers (will error if found without proper docstring)
python utils/modular_auto_docstring.py

# Check specific directory
python utils/modular_auto_docstring.py src/diffusers/modular_pipelines/

# Fix and overwrite the docstrings
python utils/modular_auto_docstring.py --fix_and_overwrite

Usage in code:
# auto_docstring
class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks):
# docstring will be automatically inserted here

@property
def doc(self):
return "Your docstring content..."
"""

import argparse
import ast
import glob
import importlib
import os
import re
import subprocess
import sys


# All paths are set with the intent you should run this script from the root of the repo
DIFFUSERS_PATH = "src/diffusers"
REPO_PATH = "."

# Pattern to match the auto_docstring comment
AUTO_DOCSTRING_PATTERN = re.compile(r"^\s*#\s*auto_docstring\s*$")


def setup_diffusers_import():
"""Setup import path to use the local diffusers module."""
src_path = os.path.join(REPO_PATH, "src")
if src_path not in sys.path:
sys.path.insert(0, src_path)


def get_module_from_filepath(filepath: str) -> str:
"""Convert a filepath to a module name."""
filepath = os.path.normpath(filepath)

if filepath.startswith("src" + os.sep):
filepath = filepath[4:]

if filepath.endswith(".py"):
filepath = filepath[:-3]

module_name = filepath.replace(os.sep, ".")
return module_name


def load_module(filepath: str):
"""Load a module from filepath."""
setup_diffusers_import()
module_name = get_module_from_filepath(filepath)

try:
module = importlib.import_module(module_name)
return module
except Exception as e:
print(f"Warning: Could not import module {module_name}: {e}")
return None


def get_doc_from_class(module, class_name: str) -> str:
"""Get the doc property from an instantiated class."""
if module is None:
return None

cls = getattr(module, class_name, None)
if cls is None:
return None

try:
instance = cls()
if hasattr(instance, "doc"):
return instance.doc
except Exception as e:
print(f"Warning: Could not instantiate {class_name}: {e}")

return None


def find_auto_docstring_classes(filepath: str) -> list:
"""
Find all classes in a file that have # auto_docstring comment above them.

Returns list of (class_name, class_line_number, has_existing_docstring, docstring_end_line)
"""
with open(filepath, "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()

# Parse AST to find class locations and their docstrings
content = "".join(lines)
try:
tree = ast.parse(content)
except SyntaxError as e:
print(f"Syntax error in {filepath}: {e}")
return []

# Build a map of class_name -> (class_line, has_docstring, docstring_end_line)
class_info = {}
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef):
has_docstring = False
docstring_end_line = node.lineno # default to class line

if node.body and isinstance(node.body[0], ast.Expr):
first_stmt = node.body[0]
if isinstance(first_stmt.value, ast.Constant) and isinstance(first_stmt.value.value, str):
has_docstring = True
docstring_end_line = first_stmt.end_lineno or first_stmt.lineno

class_info[node.name] = (node.lineno, has_docstring, docstring_end_line)

# Now scan for # auto_docstring comments
classes_to_update = []

for i, line in enumerate(lines):
if AUTO_DOCSTRING_PATTERN.match(line):
# Found the marker, look for class definition on next non-empty, non-comment line
j = i + 1
while j < len(lines):
next_line = lines[j].strip()
if next_line and not next_line.startswith("#"):
break
j += 1

if j < len(lines) and lines[j].strip().startswith("class "):
# Extract class name
match = re.match(r"class\s+(\w+)", lines[j].strip())
if match:
class_name = match.group(1)
if class_name in class_info:
class_line, has_docstring, docstring_end_line = class_info[class_name]
classes_to_update.append((class_name, class_line, has_docstring, docstring_end_line))

return classes_to_update


def strip_class_name_line(doc: str, class_name: str) -> str:
"""Remove the 'class ClassName' line from the doc if present."""
lines = doc.strip().split("\n")
if lines and lines[0].strip() == f"class {class_name}":
# Remove the class line and any blank line following it
lines = lines[1:]
while lines and not lines[0].strip():
lines = lines[1:]
return "\n".join(lines)


def format_docstring(doc: str, indent: str = " ") -> str:
"""Format a doc string as a properly indented docstring."""
lines = doc.strip().split("\n")

if len(lines) == 1:
return f'{indent}"""{lines[0]}"""\n'
else:
result = [f'{indent}"""\n']
for line in lines:
if line.strip():
result.append(f"{indent}{line}\n")
else:
result.append("\n")
result.append(f'{indent}"""\n')
return "".join(result)


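# Worked illustration of format_docstring above (derived from the function body, default indent assumed):
#   format_docstring("Single line.")      -> f'{indent}"""Single line."""\n'
#   format_docstring("First.\nSecond.")   -> f'{indent}"""\n' + f'{indent}First.\n' + f'{indent}Second.\n' + f'{indent}"""\n'
#   Blank input lines are emitted as a bare "\n" so the indentation does not become trailing whitespace.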
def run_ruff_format(filepath: str):
"""Run ruff check --fix, ruff format, and doc-builder style on a file to ensure consistent formatting."""
try:
# First run ruff check --fix to fix any linting issues (including line length)
subprocess.run(
["ruff", "check", "--fix", filepath],
check=False, # Don't fail if there are unfixable issues
capture_output=True,
text=True,
)
# Then run ruff format for code formatting
subprocess.run(
["ruff", "format", filepath],
check=True,
capture_output=True,
text=True,
)
# Finally run doc-builder style for docstring formatting
subprocess.run(
["doc-builder", "style", filepath, "--max_len", "119"],
check=False, # Don't fail if doc-builder has issues
capture_output=True,
text=True,
)
print(f"Formatted {filepath}")
except subprocess.CalledProcessError as e:
print(f"Warning: formatting failed for {filepath}: {e.stderr}")
except FileNotFoundError as e:
print(f"Warning: tool not found ({e}). Skipping formatting.")
except Exception as e:
print(f"Warning: unexpected error formatting {filepath}: {e}")


def get_existing_docstring(lines: list, class_line: int, docstring_end_line: int) -> str:
"""Extract the existing docstring content from lines."""
# class_line is 1-indexed, docstring starts at class_line (0-indexed: class_line)
# and ends at docstring_end_line (1-indexed, inclusive)
docstring_lines = lines[class_line:docstring_end_line]
return "".join(docstring_lines)


def process_file(filepath: str, overwrite: bool = False) -> list:
"""
Process a file and find/insert docstrings for # auto_docstring marked classes.

Returns list of classes that need updating.
"""
classes_to_update = find_auto_docstring_classes(filepath)

if not classes_to_update:
return []

if not overwrite:
# Check mode: only verify that docstrings exist
# Content comparison is not reliable due to formatting differences
classes_needing_update = []
for class_name, class_line, has_docstring, docstring_end_line in classes_to_update:
if not has_docstring:
# No docstring exists, needs update
classes_needing_update.append((filepath, class_name, class_line))
return classes_needing_update

# Load the module to get doc properties
module = load_module(filepath)

with open(filepath, "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()

# Process in reverse order to maintain line numbers
updated = False
for class_name, class_line, has_docstring, docstring_end_line in reversed(classes_to_update):
doc = get_doc_from_class(module, class_name)

if doc is None:
print(f"Warning: Could not get doc for {class_name} in {filepath}")
continue

# Remove the "class ClassName" line since it's redundant in a docstring
doc = strip_class_name_line(doc, class_name)

# Format the new docstring with 4-space indent
new_docstring = format_docstring(doc, " ")

if has_docstring:
# Replace existing docstring (line after class definition to docstring_end_line)
# class_line is 1-indexed, we want to replace from class_line+1 to docstring_end_line
lines = lines[:class_line] + [new_docstring] + lines[docstring_end_line:]
else:
# Insert new docstring right after class definition line
# class_line is 1-indexed, so lines[class_line-1] is the class line
# Insert at position class_line (which is right after the class line)
lines = lines[:class_line] + [new_docstring] + lines[class_line:]

updated = True
print(f"Updated docstring for {class_name} in {filepath}")

if updated:
with open(filepath, "w", encoding="utf-8", newline="\n") as f:
f.writelines(lines)
# Run ruff format to ensure consistent line wrapping
run_ruff_format(filepath)

return [(filepath, cls_name, line) for cls_name, line, _, _ in classes_to_update]


def check_auto_docstrings(path: str = None, overwrite: bool = False):
"""
Check all files for # auto_docstring markers and optionally fix them.
"""
if path is None:
path = DIFFUSERS_PATH

if os.path.isfile(path):
all_files = [path]
else:
all_files = glob.glob(os.path.join(path, "**/*.py"), recursive=True)

all_markers = []

for filepath in all_files:
markers = process_file(filepath, overwrite)
all_markers.extend(markers)

if not overwrite and len(all_markers) > 0:
message = "\n".join([f"- {f}: {cls} at line {line}" for f, cls, line in all_markers])
raise ValueError(
f"Found the following # auto_docstring markers that need docstrings:\n{message}\n\n"
f"Run `python utils/modular_auto_docstring.py --fix_and_overwrite` to fix them."
)

if overwrite and len(all_markers) > 0:
print(f"\nProcessed {len(all_markers)} docstring(s).")
elif not overwrite and len(all_markers) == 0:
print("All # auto_docstring markers have valid docstrings.")
elif len(all_markers) == 0:
print("No # auto_docstring markers found.")


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Check and fix # auto_docstring markers in modular pipeline blocks",
)
parser.add_argument("path", nargs="?", default=None, help="File or directory to process (default: src/diffusers)")
parser.add_argument(
"--fix_and_overwrite",
action="store_true",
help="Whether to fix the docstrings by inserting them from doc property.",
)

args = parser.parse_args()

check_auto_docstrings(args.path, args.fix_and_overwrite)