Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 06:54:54 +08:00)

Compare commits: modular-lo...main (15 commits)
| SHA1 |
|---|
| 0ab2124958 |
| 74a0f0b694 |
| 2c669e8480 |
| 2ac39ba664 |
| ef913010d4 |
| 53d8a1e310 |
| d54669a73e |
| 22ac6fae24 |
| 71a865b742 |
| 53279ef017 |
| d9959bd53b |
| b1c77f67ac |
| 956bdcc3ea |
| 2af7baa040 |
| a7cb14efbe |
.github/workflows/build_docker_images.yml (vendored, 8 changes)
@@ -25,7 +25,7 @@ jobs:
     if: github.event_name == 'pull_request'
     steps:
       - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v1
+        uses: docker/setup-buildx-action@v3

       - name: Check out code
         uses: actions/checkout@v6
@@ -101,14 +101,14 @@ jobs:
       - name: Checkout repository
         uses: actions/checkout@v6
       - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v1
+        uses: docker/setup-buildx-action@v3
       - name: Login to Docker Hub
-        uses: docker/login-action@v2
+        uses: docker/login-action@v3
        with:
          username: ${{ env.REGISTRY }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}
       - name: Build and push
-        uses: docker/build-push-action@v3
+        uses: docker/build-push-action@v6
        with:
          no-cache: true
          context: ./docker/${{ matrix.image-name }}
.github/workflows/pr_modular_tests.yml (vendored, 20 changes)
@@ -75,9 +75,27 @@ jobs:
       if: ${{ failure() }}
       run: |
         echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
+  check_auto_docs:
+    runs-on: ubuntu-22.04
+    steps:
+      - uses: actions/checkout@v6
+      - name: Set up Python
+        uses: actions/setup-python@v6
+        with:
+          python-version: "3.10"
+      - name: Install dependencies
+        run: |
+          pip install --upgrade pip
+          pip install .[quality]
+      - name: Check auto docs
+        run: make modular-autodoctrings
+      - name: Check if failure
+        if: ${{ failure() }}
+        run: |
+          echo "Auto docstring checks failed. Please run `python utils/modular_auto_docstring.py --fix_and_overwrite`." >> $GITHUB_STEP_SUMMARY

   run_fast_tests:
-    needs: [check_code_quality, check_repository_consistency]
+    needs: [check_code_quality, check_repository_consistency, check_auto_docs]
     name: Fast PyTorch Modular Pipeline CPU tests

     runs-on:
.github/workflows/typos.yml (vendored, 2 changes)
@@ -11,4 +11,4 @@ jobs:
       - uses: actions/checkout@v6

       - name: typos-action
-        uses: crate-ci/typos@v1.12.4
+        uses: crate-ci/typos@v1.42.1
Makefile (4 changes)
@@ -70,6 +70,10 @@ fix-copies:
 	python utils/check_copies.py --fix_and_overwrite
 	python utils/check_dummies.py --fix_and_overwrite

+# Auto docstrings in modular blocks
+modular-autodoctrings:
+	python utils/modular_auto_docstring.py
+
 # Run tests for the library

 test:
@@ -1,8 +1,8 @@
-FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
+FROM nvidia/cuda:12.9.0-runtime-ubuntu20.04
 LABEL maintainer="Hugging Face"
 LABEL repository="diffusers"

-ARG PYTHON_VERSION=3.11
+ARG PYTHON_VERSION=3.10
 ENV DEBIAN_FRONTEND=noninteractive

 RUN apt-get -y update \
@@ -36,8 +36,12 @@ ENV PATH="$VIRTUAL_ENV/bin:$PATH"
 RUN uv pip install --no-cache-dir \
     torch \
     torchvision \
-    torchaudio \
-    --index-url https://download.pytorch.org/whl/cu121
+    torchaudio
+
+# Install compatible versions of numba/llvmlite for Python 3.10+
+RUN uv pip install --no-cache-dir \
+    "llvmlite>=0.40.0" \
+    "numba>=0.57.0"

 RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"
@@ -1,8 +1,8 @@
-FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
+FROM nvidia/cuda:12.9.0-runtime-ubuntu20.04
 LABEL maintainer="Hugging Face"
 LABEL repository="diffusers"

-ARG PYTHON_VERSION=3.11
+ARG PYTHON_VERSION=3.10
 ENV DEBIAN_FRONTEND=noninteractive

 RUN apt-get -y update \
@@ -36,8 +36,12 @@ ENV PATH="$VIRTUAL_ENV/bin:$PATH"
 RUN uv pip install --no-cache-dir \
     torch \
     torchvision \
-    torchaudio \
-    --index-url https://download.pytorch.org/whl/cu121
+    torchaudio
+
+# Install compatible versions of numba/llvmlite for Python 3.10+
+RUN uv pip install --no-cache-dir \
+    "llvmlite>=0.40.0" \
+    "numba>=0.57.0"

 RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"
@@ -83,25 +83,6 @@ Refer to this [table](https://github.com/huggingface/diffusers/pull/10009#issue-
 > [!TIP]
 > The FP8 post-training quantization schemes in torchao are effective for GPUs with compute capability of at least 8.9 (RTX-4090, Hopper, etc.). FP8 often provides the best speed, memory, and quality trade-off when generating images and videos. We recommend combining FP8 and torch.compile if your GPU is compatible.

-## autoquant
-
-torchao provides [autoquant](https://docs.pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant), an automatic quantization API. Autoquantization chooses the best quantization strategy by comparing the performance of each strategy on chosen input types and shapes. This is only supported in Diffusers for individual models at the moment.
-
-```py
-import torch
-from diffusers import DiffusionPipeline
-from torchao.quantization import autoquant
-
-# Load the pipeline
-pipeline = DiffusionPipeline.from_pretrained(
-    "black-forest-labs/FLUX.1-schnell",
-    torch_dtype=torch.bfloat16,
-    device_map="cuda"
-)
-
-transformer = autoquant(pipeline.transformer)
-```
-
 ## Supported quantization types

 torchao supports weight-only quantization and weight and dynamic-activation quantization for int8, float3-float8, and uint1-uint7.
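For context on the TIP kept above, here is a minimal sketch of pairing torchao quantization with torch.compile in Diffusers. The int8 weight-only quant type and the compile flags are illustrative choices, not part of this diff; swap in an FP8 type only on GPUs with compute capability >= 8.9.

```python
import torch
from diffusers import DiffusionPipeline, FluxTransformer2DModel, TorchAoConfig

# Quantize just the transformer with torchao at load time (illustrative settings).
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    subfolder="transformer",
    quantization_config=TorchAoConfig("int8wo"),  # FP8 variants need compute capability >= 8.9
    torch_dtype=torch.bfloat16,
)
# Compile the quantized transformer for additional speedup.
transformer = torch.compile(transformer, fullgraph=True)

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")
```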
@@ -1467,7 +1467,8 @@ def main(args):
                 else:
                     num_repeat_elements = len(prompts)
                 prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
-                prompt_embeds_mask = prompt_embeds_mask.repeat(num_repeat_elements, 1)
+                if prompt_embeds_mask is not None:
+                    prompt_embeds_mask = prompt_embeds_mask.repeat(num_repeat_elements, 1)

             # Convert images to latent space
             if args.cache_latents:
                 model_input = latents_cache[step].sample()
@@ -413,6 +413,9 @@ else:
     _import_structure["modular_pipelines"].extend(
         [
             "Flux2AutoBlocks",
+            "Flux2KleinAutoBlocks",
+            "Flux2KleinBaseAutoBlocks",
+            "Flux2KleinModularPipeline",
             "Flux2ModularPipeline",
             "FluxAutoBlocks",
             "FluxKontextAutoBlocks",
@@ -1146,6 +1149,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
 else:
     from .modular_pipelines import (
         Flux2AutoBlocks,
+        Flux2KleinAutoBlocks,
+        Flux2KleinBaseAutoBlocks,
+        Flux2KleinModularPipeline,
         Flux2ModularPipeline,
         FluxAutoBlocks,
         FluxKontextAutoBlocks,
@@ -152,6 +152,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
         "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "WanAnimateTransformer3DModel": {
+        "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
+        "default_subfolder": "transformer",
+    },
     "AutoencoderKLWan": {
         "checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
         "default_subfolder": "vae",
@@ -136,6 +136,7 @@ CHECKPOINT_KEY_NAMES = {
     "wan": ["model.diffusion_model.head.modulation", "head.modulation"],
     "wan_vae": "decoder.middle.0.residual.0.gamma",
     "wan_vace": "vace_blocks.0.after_proj.bias",
+    "wan_animate": "motion_encoder.dec.direction.weight",
     "hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias",
     "cosmos-1.0": [
         "net.x_embedder.proj.1.weight",
@@ -219,6 +220,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
     "wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
     "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
     "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
+    "wan-animate-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.2-Animate-14B-Diffusers"},
     "wan-vace-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-1.3B-diffusers"},
     "wan-vace-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-14B-diffusers"},
     "hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
@@ -759,6 +761,9 @@ def infer_diffusers_model_type(checkpoint):
             elif checkpoint[target_key].shape[0] == 5120:
                 model_type = "wan-vace-14B"

+        if CHECKPOINT_KEY_NAMES["wan_animate"] in checkpoint:
+            model_type = "wan-animate-14B"
+
         elif checkpoint[target_key].shape[0] == 1536:
             model_type = "wan-t2v-1.3B"
         elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
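An illustrative sketch (toy checkpoint, made-up shapes) of the detection rule this hunk adds: single-file loading marks a checkpoint as Wan-Animate when the motion-encoder key is present, before falling back to the shape-based Wan heuristics.

```python
import torch

CHECKPOINT_KEY_NAMES = {"wan_animate": "motion_encoder.dec.direction.weight"}

checkpoint = {
    "head.modulation": torch.zeros(5120, 2),                     # generic Wan key
    "motion_encoder.dec.direction.weight": torch.zeros(512, 20),  # Animate-only key
}

# Key presence wins over the shape-based heuristics.
if CHECKPOINT_KEY_NAMES["wan_animate"] in checkpoint:
    model_type = "wan-animate-14B"
else:
    model_type = "wan-t2v-14B"
print(model_type)  # wan-animate-14B
```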
@@ -3154,13 +3159,64 @@ def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):


 def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
+    def generate_motion_encoder_mappings():
+        mappings = {
+            "motion_encoder.dec.direction.weight": "motion_encoder.motion_synthesis_weight",
+            "motion_encoder.enc.net_app.convs.0.0.weight": "motion_encoder.conv_in.weight",
+            "motion_encoder.enc.net_app.convs.0.1.bias": "motion_encoder.conv_in.act_fn.bias",
+            "motion_encoder.enc.net_app.convs.8.weight": "motion_encoder.conv_out.weight",
+            "motion_encoder.enc.fc": "motion_encoder.motion_network",
+        }
+
+        for i in range(7):
+            conv_idx = i + 1
+            mappings.update(
+                {
+                    f"motion_encoder.enc.net_app.convs.{conv_idx}.conv1.0.weight": f"motion_encoder.res_blocks.{i}.conv1.weight",
+                    f"motion_encoder.enc.net_app.convs.{conv_idx}.conv1.1.bias": f"motion_encoder.res_blocks.{i}.conv1.act_fn.bias",
+                    f"motion_encoder.enc.net_app.convs.{conv_idx}.conv2.1.weight": f"motion_encoder.res_blocks.{i}.conv2.weight",
+                    f"motion_encoder.enc.net_app.convs.{conv_idx}.conv2.2.bias": f"motion_encoder.res_blocks.{i}.conv2.act_fn.bias",
+                    f"motion_encoder.enc.net_app.convs.{conv_idx}.skip.1.weight": f"motion_encoder.res_blocks.{i}.conv_skip.weight",
+                }
+            )
+
+        return mappings
+
+    def generate_face_adapter_mappings():
+        return {
+            "face_adapter.fuser_blocks": "face_adapter",
+            ".k_norm.": ".norm_k.",
+            ".q_norm.": ".norm_q.",
+            ".linear1_q.": ".to_q.",
+            ".linear2.": ".to_out.",
+            "conv1_local.conv": "conv1_local",
+            "conv2.conv": "conv2",
+            "conv3.conv": "conv3",
+        }
+
+    def split_tensor_handler(key, state_dict, split_pattern, target_keys):
+        tensor = state_dict.pop(key)
+        split_idx = tensor.shape[0] // 2
+
+        new_key_1 = key.replace(split_pattern, target_keys[0])
+        new_key_2 = key.replace(split_pattern, target_keys[1])
+
+        state_dict[new_key_1] = tensor[:split_idx]
+        state_dict[new_key_2] = tensor[split_idx:]
+
+    def reshape_bias_handler(key, state_dict):
+        if "motion_encoder.enc.net_app.convs." in key and ".bias" in key:
+            state_dict[key] = state_dict[key][0, :, 0, 0]
+
     converted_state_dict = {}

     # Strip model.diffusion_model prefix
     keys = list(checkpoint.keys())
     for k in keys:
         if "model.diffusion_model." in k:
             checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

     # Base transformer mappings
     TRANSFORMER_KEYS_RENAME_DICT = {
         "time_embedding.0": "condition_embedder.time_embedder.linear_1",
         "time_embedding.2": "condition_embedder.time_embedder.linear_2",
@@ -3182,28 +3238,43 @@ def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
         "ffn.0": "ffn.net.0.proj",
         "ffn.2": "ffn.net.2",
         # Hack to swap the layer names
         # The original model calls the norms in following order: norm1, norm3, norm2
         # We convert it to: norm1, norm2, norm3
         "norm2": "norm__placeholder",
         "norm3": "norm2",
         "norm__placeholder": "norm3",
-        # For the I2V model
+        # I2V model
         "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
         "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
         "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
         "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
-        # For the VACE model
+        # VACE model
         "before_proj": "proj_in",
         "after_proj": "proj_out",
     }

+    SPECIAL_KEYS_HANDLERS = {}
+    if any("face_adapter" in k for k in checkpoint.keys()):
+        TRANSFORMER_KEYS_RENAME_DICT.update(generate_face_adapter_mappings())
+        SPECIAL_KEYS_HANDLERS[".linear1_kv."] = (split_tensor_handler, [".to_k.", ".to_v."])
+
+    if any("motion_encoder" in k for k in checkpoint.keys()):
+        TRANSFORMER_KEYS_RENAME_DICT.update(generate_motion_encoder_mappings())
+
     for key in list(checkpoint.keys()):
-        new_key = key[:]
+        reshape_bias_handler(key, checkpoint)
+
+    for key in list(checkpoint.keys()):
+        new_key = key
         for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
             new_key = new_key.replace(replace_key, rename_key)

         converted_state_dict[new_key] = checkpoint.pop(key)

+    for key in list(converted_state_dict.keys()):
+        for pattern, (handler_fn, target_keys) in SPECIAL_KEYS_HANDLERS.items():
+            if pattern not in key:
+                continue
+            handler_fn(key, converted_state_dict, pattern, target_keys)
+            break
+
     return converted_state_dict
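A minimal standalone sketch of the rename-dict pass above, on a toy checkpoint with made-up keys and shapes. It shows why the dict maps substrings rather than full keys: every match anywhere in the key is rewritten, so one entry covers all blocks.

```python
import torch

# Two entries from the real rename dict, applied to a fake two-key checkpoint.
TRANSFORMER_KEYS_RENAME_DICT = {
    "time_embedding.0": "condition_embedder.time_embedder.linear_1",
    "ffn.0": "ffn.net.0.proj",
}

checkpoint = {
    "blocks.0.ffn.0.weight": torch.zeros(4, 4),
    "time_embedding.0.bias": torch.zeros(4),
}

converted = {}
for key in list(checkpoint.keys()):
    new_key = key
    # every substring match is rewritten; dict entry order matters
    for old, new in TRANSFORMER_KEYS_RENAME_DICT.items():
        new_key = new_key.replace(old, new)
    converted[new_key] = checkpoint.pop(key)

print(sorted(converted))
# ['blocks.0.ffn.net.0.proj.weight', 'condition_embedder.time_embedder.linear_1.bias']
```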
@@ -11,11 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 from typing import Dict, List, Optional, Union

 import safetensors
 import torch
 from huggingface_hub.utils import validate_hf_hub_args
+from tokenizers import Tokenizer as TokenizerFast
 from torch import nn

 from ..models.modeling_utils import load_state_dict
@@ -547,23 +549,39 @@ class TextualInversionLoaderMixin:
         else:
             last_special_token_id = added_token_id

-        # Delete from tokenizer
-        for token_id, token_to_remove in zip(token_ids, tokens):
-            del tokenizer._added_tokens_decoder[token_id]
-            del tokenizer._added_tokens_encoder[token_to_remove]
-
-        # Make all token ids sequential in tokenizer
-        key_id = 1
-        for token_id in tokenizer.added_tokens_decoder:
-            if token_id > last_special_token_id and token_id > last_special_token_id + key_id:
-                token = tokenizer._added_tokens_decoder[token_id]
-                tokenizer._added_tokens_decoder[last_special_token_id + key_id] = token
-                del tokenizer._added_tokens_decoder[token_id]
-                tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id
-                key_id += 1
-        tokenizer._update_trie()
-        # set correct total vocab size after removing tokens
-        tokenizer._update_total_vocab_size()
+        # Fast tokenizers (v5+)
+        if hasattr(tokenizer, "_tokenizer"):
+            # Fast tokenizers: serialize, filter tokens, reload
+            tokenizer_json = json.loads(tokenizer._tokenizer.to_str())
+            new_id = last_special_token_id + 1
+            filtered = []
+            for tok in tokenizer_json.get("added_tokens", []):
+                if tok.get("content") in set(tokens):
+                    continue
+                if not tok.get("special", False):
+                    tok["id"] = new_id
+                    new_id += 1
+                filtered.append(tok)
+            tokenizer_json["added_tokens"] = filtered
+            tokenizer._tokenizer = TokenizerFast.from_str(json.dumps(tokenizer_json))
+        else:
+            # Slow tokenizers
+            for token_id, token_to_remove in zip(token_ids, tokens):
+                del tokenizer._added_tokens_decoder[token_id]
+                del tokenizer._added_tokens_encoder[token_to_remove]
+
+            # Make all token ids sequential in tokenizer
+            key_id = 1
+            for token_id in list(tokenizer.added_tokens_decoder.keys()):
+                if token_id > last_special_token_id and token_id > last_special_token_id + key_id:
+                    token = tokenizer._added_tokens_decoder[token_id]
+                    tokenizer._added_tokens_decoder[last_special_token_id + key_id] = token
+                    del tokenizer._added_tokens_decoder[token_id]
+                    tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id
+                    key_id += 1
+            if hasattr(tokenizer, "_update_trie"):
+                tokenizer._update_trie()
+            if hasattr(tokenizer, "_update_total_vocab_size"):
+                tokenizer._update_total_vocab_size()

         # Delete from text encoder
         text_embedding_dim = text_encoder.get_input_embeddings().embedding_dim
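A standalone sketch of the fast-tokenizer path added above: removing textual-inversion tokens by round-tripping the tokenizers JSON and re-numbering the surviving non-special added tokens so their ids stay contiguous. The token names and ids are made up for the demo.

```python
import json

tokenizer_json = {
    "added_tokens": [
        {"id": 49406, "content": "<|startoftext|>", "special": True},
        {"id": 49408, "content": "<cat-toy>", "special": False},   # token to remove
        {"id": 49409, "content": "<style-x>", "special": False},   # keep, re-number
    ]
}
tokens_to_remove = {"<cat-toy>"}
last_special_token_id = 49407

new_id = last_special_token_id + 1
filtered = []
for tok in tokenizer_json["added_tokens"]:
    if tok["content"] in tokens_to_remove:
        continue  # drop the embedding token entirely
    if not tok["special"]:
        tok["id"] = new_id  # keep ids contiguous after the last special token
        new_id += 1
    filtered.append(tok)
tokenizer_json["added_tokens"] = filtered

print(json.dumps(tokenizer_json["added_tokens"], indent=2))
# "<style-x>" now has id 49408; "<cat-toy>" is gone.
```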
@@ -143,41 +143,86 @@ class GlmImageAdaLayerNormZero(nn.Module):


 class GlmImageLayerKVCache:
-    """KV cache for GlmImage model."""
+    """KV cache for GlmImage model.
+
+    Supports per-sample caching for batch processing where each sample may have different condition images.
+    """

     def __init__(self):
-        self.k_cache = None
-        self.v_cache = None
+        self.k_caches: List[Optional[torch.Tensor]] = []
+        self.v_caches: List[Optional[torch.Tensor]] = []
+        self.mode: Optional[str] = None  # "write", "read", "skip"
+        self.current_sample_idx: int = 0  # Current sample index for writing

     def store(self, k: torch.Tensor, v: torch.Tensor):
-        if self.k_cache is None:
-            self.k_cache = k
-            self.v_cache = v
+        """Store KV cache for the current sample."""
+        # k, v shape: (1, seq_len, num_heads, head_dim)
+        if len(self.k_caches) <= self.current_sample_idx:
+            # First time storing for this sample
+            self.k_caches.append(k)
+            self.v_caches.append(v)
         else:
-            self.k_cache = torch.cat([self.k_cache, k], dim=1)
-            self.v_cache = torch.cat([self.v_cache, v], dim=1)
+            # Append to existing cache for this sample (multiple condition images)
+            self.k_caches[self.current_sample_idx] = torch.cat([self.k_caches[self.current_sample_idx], k], dim=1)
+            self.v_caches[self.current_sample_idx] = torch.cat([self.v_caches[self.current_sample_idx], v], dim=1)

     def get(self, k: torch.Tensor, v: torch.Tensor):
-        if self.k_cache.shape[0] != k.shape[0]:
-            k_cache_expanded = self.k_cache.expand(k.shape[0], -1, -1, -1)
-            v_cache_expanded = self.v_cache.expand(v.shape[0], -1, -1, -1)
-        else:
-            k_cache_expanded = self.k_cache
-            v_cache_expanded = self.v_cache
-
-        k_cache = torch.cat([k_cache_expanded, k], dim=1)
-        v_cache = torch.cat([v_cache_expanded, v], dim=1)
-        return k_cache, v_cache
+        """Get combined KV cache for all samples in the batch.
+
+        Args:
+            k: Current key tensor, shape (batch_size, seq_len, num_heads, head_dim)
+            v: Current value tensor, shape (batch_size, seq_len, num_heads, head_dim)
+        Returns:
+            Combined key and value tensors with cached values prepended.
+        """
+        batch_size = k.shape[0]
+        num_cached_samples = len(self.k_caches)
+        if num_cached_samples == 0:
+            return k, v
+        if num_cached_samples == 1:
+            # Single cache, expand for all batch samples (shared condition images)
+            k_cache_expanded = self.k_caches[0].expand(batch_size, -1, -1, -1)
+            v_cache_expanded = self.v_caches[0].expand(batch_size, -1, -1, -1)
+        elif num_cached_samples == batch_size:
+            # Per-sample cache, concatenate along batch dimension
+            k_cache_expanded = torch.cat(self.k_caches, dim=0)
+            v_cache_expanded = torch.cat(self.v_caches, dim=0)
+        else:
+            # Mismatch: try to handle by repeating the caches
+            # This handles cases like num_images_per_prompt > 1
+            repeat_factor = batch_size // num_cached_samples
+            if batch_size % num_cached_samples == 0:
+                k_cache_list = []
+                v_cache_list = []
+                for i in range(num_cached_samples):
+                    k_cache_list.append(self.k_caches[i].expand(repeat_factor, -1, -1, -1))
+                    v_cache_list.append(self.v_caches[i].expand(repeat_factor, -1, -1, -1))
+                k_cache_expanded = torch.cat(k_cache_list, dim=0)
+                v_cache_expanded = torch.cat(v_cache_list, dim=0)
+            else:
+                raise ValueError(
+                    f"Cannot match {num_cached_samples} cached samples to batch size {batch_size}. "
+                    f"Batch size must be a multiple of the number of cached samples."
+                )
+
+        k_combined = torch.cat([k_cache_expanded, k], dim=1)
+        v_combined = torch.cat([v_cache_expanded, v], dim=1)
+        return k_combined, v_combined

     def clear(self):
-        self.k_cache = None
-        self.v_cache = None
+        self.k_caches = []
+        self.v_caches = []
+        self.mode = None
+        self.current_sample_idx = 0
+
+    def next_sample(self):
+        """Move to the next sample for writing."""
+        self.current_sample_idx += 1


 class GlmImageKVCache:
-    """Container for all layers' KV caches."""
+    """Container for all layers' KV caches.
+
+    Supports per-sample caching for batch processing where each sample may have different condition images.
+    """

     def __init__(self, num_layers: int):
         self.num_layers = num_layers
@@ -192,6 +237,12 @@ class GlmImageKVCache:
         for cache in self.caches:
             cache.mode = mode

+    def next_sample(self):
+        """Move to the next sample for writing. Call this after processing
+        all condition images for one batch sample."""
+        for cache in self.caches:
+            cache.next_sample()
+
     def clear(self):
         for cache in self.caches:
             cache.clear()
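A standalone sketch (hypothetical shapes) of the single-cache branch in `GlmImageLayerKVCache.get`: one sample's cached condition-image tokens are broadcast across a larger batch, e.g. when `num_images_per_prompt > 1`, then prepended along the sequence dimension.

```python
import torch

k_cached = torch.randn(1, 10, 8, 64)  # one sample's condition-image keys
k_new = torch.randn(4, 32, 8, 64)     # batch of 4 denoising samples

# num_cached_samples == 1: broadcast the shared cache over the batch,
# then prepend it along the sequence dimension (dim=1).
k_expanded = k_cached.expand(k_new.shape[0], -1, -1, -1)
k_combined = torch.cat([k_expanded, k_new], dim=1)
print(k_combined.shape)  # torch.Size([4, 42, 8, 64])
```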
@@ -166,9 +166,11 @@ class MotionConv2d(nn.Module):
         # NOTE: the original implementation uses a 2D upfirdn operation with the upsampling and downsampling rates
         # set to 1, which should be equivalent to a 2D convolution
         expanded_kernel = self.blur_kernel[None, None, :, :].expand(self.in_channels, 1, -1, -1)
+        x = x.to(expanded_kernel.dtype)
         x = F.conv2d(x, expanded_kernel, padding=self.blur_padding, groups=self.in_channels)

         # Main Conv2D with scaling
+        x = x.to(self.weight.dtype)
         x = F.conv2d(x, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)

         # Activation with fused bias, if using
@@ -338,8 +340,7 @@ class WanAnimateMotionEncoder(nn.Module):
         weight = self.motion_synthesis_weight + 1e-8
-        # Upcast the QR orthogonalization operation to FP32
-        original_motion_dtype = motion_feat.dtype
-        motion_feat = motion_feat.to(torch.float32)
-        weight = weight.to(torch.float32)
+        motion_feat = motion_feat.to(weight.dtype)

         Q = torch.linalg.qr(weight)[0].to(device=motion_feat.device)

@@ -769,7 +770,7 @@ class WanImageEmbedding(torch.nn.Module):
         return hidden_states


-# Copied from diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding
+# Modified from diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding
 class WanTimeTextImageEmbedding(nn.Module):
     def __init__(
         self,
@@ -803,10 +804,12 @@ class WanTimeTextImageEmbedding(nn.Module):
         if timestep_seq_len is not None:
             timestep = timestep.unflatten(0, (-1, timestep_seq_len))

-        time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
-        if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
-            timestep = timestep.to(time_embedder_dtype)
-        temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
+        if self.time_embedder.linear_1.weight.dtype.is_floating_point:
+            time_embedder_dtype = self.time_embedder.linear_1.weight.dtype
+        else:
+            time_embedder_dtype = encoder_hidden_states.dtype
+
+        temb = self.time_embedder(timestep.to(time_embedder_dtype)).type_as(encoder_hidden_states)
         timestep_proj = self.time_proj(self.act_fn(temb))

         encoder_hidden_states = self.text_embedder(encoder_hidden_states)
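A standalone sketch of why the new dtype selection checks `is_floating_point`: with quantized (e.g. int8) weights, casting the timestep to the weight dtype would truncate it, so the code falls back to the dtype of the text embeddings instead.

```python
import torch

encoder_dtype = torch.bfloat16

for weight_dtype in (torch.bfloat16, torch.int8):
    # Mirror of the guard in WanTimeTextImageEmbedding.forward
    time_embedder_dtype = weight_dtype if weight_dtype.is_floating_point else encoder_dtype
    print(weight_dtype, "->", time_embedder_dtype)
# torch.bfloat16 -> torch.bfloat16
# torch.int8 -> torch.bfloat16  (quantized weights: fall back to embedding dtype)
```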
@@ -54,7 +54,10 @@ else:
     ]
     _import_structure["flux2"] = [
         "Flux2AutoBlocks",
+        "Flux2KleinAutoBlocks",
+        "Flux2KleinBaseAutoBlocks",
         "Flux2ModularPipeline",
+        "Flux2KleinModularPipeline",
     ]
     _import_structure["qwenimage"] = [
         "QwenImageAutoBlocks",
@@ -81,7 +84,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
 else:
     from .components_manager import ComponentsManager
     from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
-    from .flux2 import Flux2AutoBlocks, Flux2ModularPipeline
+    from .flux2 import (
+        Flux2AutoBlocks,
+        Flux2KleinAutoBlocks,
+        Flux2KleinBaseAutoBlocks,
+        Flux2KleinModularPipeline,
+        Flux2ModularPipeline,
+    )
     from .modular_pipeline import (
         AutoPipelineBlocks,
         BlockState,
@@ -43,7 +43,7 @@ else:
         "Flux2ProcessImagesInputStep",
         "Flux2TextInputStep",
     ]
-    _import_structure["modular_blocks"] = [
+    _import_structure["modular_blocks_flux2"] = [
         "ALL_BLOCKS",
         "AUTO_BLOCKS",
         "REMOTE_AUTO_BLOCKS",
@@ -51,10 +51,11 @@ else:
         "IMAGE_CONDITIONED_BLOCKS",
         "Flux2AutoBlocks",
         "Flux2AutoVaeEncoderStep",
         "Flux2BeforeDenoiseStep",
         "Flux2CoreDenoiseStep",
         "Flux2VaeEncoderSequentialStep",
     ]
-    _import_structure["modular_pipeline"] = ["Flux2ModularPipeline"]
+    _import_structure["modular_blocks_flux2_klein"] = ["Flux2KleinAutoBlocks", "Flux2KleinBaseAutoBlocks"]
+    _import_structure["modular_pipeline"] = ["Flux2ModularPipeline", "Flux2KleinModularPipeline"]

 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try:
@@ -85,7 +86,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
             Flux2ProcessImagesInputStep,
             Flux2TextInputStep,
         )
-        from .modular_blocks import (
+        from .modular_blocks_flux2 import (
             ALL_BLOCKS,
             AUTO_BLOCKS,
             IMAGE_CONDITIONED_BLOCKS,
@@ -93,10 +94,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
             TEXT2IMAGE_BLOCKS,
             Flux2AutoBlocks,
             Flux2AutoVaeEncoderStep,
             Flux2BeforeDenoiseStep,
             Flux2CoreDenoiseStep,
             Flux2VaeEncoderSequentialStep,
         )
-        from .modular_pipeline import Flux2ModularPipeline
+        from .modular_blocks_flux2_klein import (
+            Flux2KleinAutoBlocks,
+            Flux2KleinBaseAutoBlocks,
+        )
+        from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
 else:
     import sys
@@ -129,17 +129,9 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
             InputParam("num_inference_steps", default=50),
             InputParam("timesteps"),
             InputParam("sigmas"),
-            InputParam("guidance_scale", default=4.0),
             InputParam("latents", type_hint=torch.Tensor),
-            InputParam("num_images_per_prompt", default=1),
             InputParam("height", type_hint=int),
             InputParam("width", type_hint=int),
-            InputParam(
-                "batch_size",
-                required=True,
-                type_hint=int,
-                description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
-            ),
         ]

     @property
@@ -151,13 +143,12 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
                 type_hint=int,
                 description="The number of denoising steps to perform at inference time",
             ),
-            OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
         ]

     @torch.no_grad()
     def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
-        block_state.device = components._execution_device
+        device = components._execution_device

         scheduler = components.scheduler

@@ -183,7 +174,7 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
         timesteps, num_inference_steps = retrieve_timesteps(
             scheduler,
             num_inference_steps,
-            block_state.device,
+            device,
             timesteps=timesteps,
             sigmas=sigmas,
             mu=mu,
@@ -191,11 +182,6 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
         block_state.timesteps = timesteps
         block_state.num_inference_steps = num_inference_steps

-        batch_size = block_state.batch_size * block_state.num_images_per_prompt
-        guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32)
-        guidance = guidance.expand(batch_size)
-        block_state.guidance = guidance
-
         components.scheduler.set_begin_index(0)

         self.set_block_state(state, block_state)
@@ -353,7 +339,6 @@ class Flux2RoPEInputsStep(ModularPipelineBlocks):
     def inputs(self) -> List[InputParam]:
         return [
             InputParam(name="prompt_embeds", required=True),
             InputParam(name="latent_ids"),
         ]

     @property
@@ -365,12 +350,6 @@ class Flux2RoPEInputsStep(ModularPipelineBlocks):
                 type_hint=torch.Tensor,
                 description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.",
             ),
-            OutputParam(
-                name="latent_ids",
-                kwargs_type="denoiser_input_fields",
-                type_hint=torch.Tensor,
-                description="4D position IDs (T, H, W, L) for image latents, used for RoPE calculation.",
-            ),
         ]

     @staticmethod
@@ -403,6 +382,72 @@ class Flux2RoPEInputsStep(ModularPipelineBlocks):
         return components, state


+class Flux2KleinBaseRoPEInputsStep(ModularPipelineBlocks):
+    model_name = "flux2-klein"
+
+    @property
+    def description(self) -> str:
+        return "Step that prepares the 4D RoPE position IDs for Flux2-Klein base model denoising. Should be placed after text encoder and latent preparation steps."
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam(name="prompt_embeds", required=True),
+            InputParam(name="negative_prompt_embeds", required=False),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                name="txt_ids",
+                kwargs_type="denoiser_input_fields",
+                type_hint=torch.Tensor,
+                description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.",
+            ),
+            OutputParam(
+                name="negative_txt_ids",
+                kwargs_type="denoiser_input_fields",
+                type_hint=torch.Tensor,
+                description="4D position IDs (T, H, W, L) for negative text tokens, used for RoPE calculation.",
+            ),
+        ]
+
+    @staticmethod
+    def _prepare_text_ids(x: torch.Tensor, t_coord: Optional[torch.Tensor] = None):
+        """Prepare 4D position IDs for text tokens."""
+        B, L, _ = x.shape
+        out_ids = []
+
+        for i in range(B):
+            t = torch.arange(1) if t_coord is None else t_coord[i]
+            h = torch.arange(1)
+            w = torch.arange(1)
+            seq_l = torch.arange(L)
+
+            coords = torch.cartesian_prod(t, h, w, seq_l)
+            out_ids.append(coords)
+
+        return torch.stack(out_ids)
+
+    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        prompt_embeds = block_state.prompt_embeds
+        device = prompt_embeds.device
+
+        block_state.txt_ids = self._prepare_text_ids(prompt_embeds)
+        block_state.txt_ids = block_state.txt_ids.to(device)
+
+        block_state.negative_txt_ids = None
+        if block_state.negative_prompt_embeds is not None:
+            block_state.negative_txt_ids = self._prepare_text_ids(block_state.negative_prompt_embeds)
+            block_state.negative_txt_ids = block_state.negative_txt_ids.to(device)
+
+        self.set_block_state(state, block_state)
+        return components, state
+
+
 class Flux2PrepareImageLatentsStep(ModularPipelineBlocks):
     model_name = "flux2"

@@ -506,3 +551,42 @@ class Flux2PrepareImageLatentsStep(ModularPipelineBlocks):

         self.set_block_state(state, block_state)
         return components, state
+
+
+class Flux2PrepareGuidanceStep(ModularPipelineBlocks):
+    model_name = "flux2"
+
+    @property
+    def description(self) -> str:
+        return "Step that prepares the guidance scale tensor for Flux2 inference"
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("guidance_scale", default=4.0),
+            InputParam("num_images_per_prompt", default=1),
+            InputParam(
+                "batch_size",
+                required=True,
+                type_hint=int,
+                description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
+            ),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+        device = components._execution_device
+        batch_size = block_state.batch_size * block_state.num_images_per_prompt
+        guidance = torch.full([1], block_state.guidance_scale, device=device, dtype=torch.float32)
+        guidance = guidance.expand(batch_size)
+        block_state.guidance = guidance
+
+        self.set_block_state(state, block_state)
+        return components, state
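A standalone sketch of the guidance tensor that `Flux2PrepareGuidanceStep` now builds (the logic moved out of `Flux2SetTimestepsStep`): one scalar guidance value expanded to the full model batch of `batch_size * num_images_per_prompt`.

```python
import torch

guidance_scale = 4.0
batch_size = 2           # number of prompts
num_images_per_prompt = 3

full_batch = batch_size * num_images_per_prompt
guidance = torch.full([1], guidance_scale, dtype=torch.float32).expand(full_batch)
print(guidance)  # tensor([4., 4., 4., 4., 4., 4.])
```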
@@ -29,29 +29,16 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-class Flux2DecodeStep(ModularPipelineBlocks):
+class Flux2UnpackLatentsStep(ModularPipelineBlocks):
     model_name = "flux2"

-    @property
-    def expected_components(self) -> List[ComponentSpec]:
-        return [
-            ComponentSpec("vae", AutoencoderKLFlux2),
-            ComponentSpec(
-                "image_processor",
-                Flux2ImageProcessor,
-                config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
-                default_creation_method="from_config",
-            ),
-        ]
-
     @property
     def description(self) -> str:
-        return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm denormalization"
+        return "Step that unpacks the latents from the denoising step"

     @property
     def inputs(self) -> List[Tuple[str, Any]]:
         return [
-            InputParam("output_type", default="pil"),
             InputParam(
                 "latents",
                 required=True,
@@ -70,9 +57,9 @@ class Flux2DecodeStep(ModularPipelineBlocks):
     def intermediate_outputs(self) -> List[str]:
         return [
             OutputParam(
-                "images",
-                type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray],
-                description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array",
+                "latents",
+                type_hint=torch.Tensor,
+                description="The denoised latents from the denoising step, unpacked with position IDs.",
             )
         ]

@@ -107,6 +94,62 @@ class Flux2DecodeStep(ModularPipelineBlocks):

         return torch.stack(x_list, dim=0)

+    @torch.no_grad()
+    def __call__(self, components, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        latents = block_state.latents
+        latent_ids = block_state.latent_ids
+
+        latents = self._unpack_latents_with_ids(latents, latent_ids)
+
+        block_state.latents = latents
+
+        self.set_block_state(state, block_state)
+        return components, state
+
+
+class Flux2DecodeStep(ModularPipelineBlocks):
+    model_name = "flux2"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("vae", AutoencoderKLFlux2),
+            ComponentSpec(
+                "image_processor",
+                Flux2ImageProcessor,
+                config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
+                default_creation_method="from_config",
+            ),
+        ]
+
+    @property
+    def description(self) -> str:
+        return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm denormalization"
+
+    @property
+    def inputs(self) -> List[Tuple[str, Any]]:
+        return [
+            InputParam("output_type", default="pil"),
+            InputParam(
+                "latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The denoised latents from the denoising step",
+            ),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[str]:
+        return [
+            OutputParam(
+                "images",
+                type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray],
+                description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array",
+            )
+        ]
+
+    @staticmethod
+    def _unpatchify_latents(latents):
+        """Convert patchified latents back to regular format."""
@@ -121,26 +164,20 @@ class Flux2DecodeStep(ModularPipelineBlocks):
         block_state = self.get_block_state(state)
         vae = components.vae

         if block_state.output_type == "latent":
             block_state.images = block_state.latents
         else:
             latents = block_state.latents
-            latent_ids = block_state.latent_ids
-
-            latents = self._unpack_latents_with_ids(latents, latent_ids)

             latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
             latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(
                 latents.device, latents.dtype
             )
             latents = latents * latents_bn_std + latents_bn_mean

             latents = self._unpatchify_latents(latents)

             block_state.images = vae.decode(latents, return_dict=False)[0]
             block_state.images = components.image_processor.postprocess(
                 block_state.images, output_type=block_state.output_type
             )

         self.set_block_state(state, block_state)
         return components, state
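A standalone sketch (toy tensors, stand-in buffer values) of the batch-norm denormalization `Flux2DecodeStep` applies before VAE decoding: latents are mapped back to the VAE's latent distribution using the running statistics of `vae.bn`, i.e. the inverse of `(x - mean) / std`.

```python
import torch

latents = torch.randn(1, 32, 64, 64)  # normalized latents from the denoising loop
running_mean = torch.zeros(32)        # stand-ins for vae.bn.running_mean / running_var
running_var = torch.ones(32)
batch_norm_eps = 1e-5

mean = running_mean.view(1, -1, 1, 1)
std = torch.sqrt(running_var.view(1, -1, 1, 1) + batch_norm_eps)
denormalized = latents * std + mean   # inverse of the batch-norm normalization
print(denormalized.shape)  # torch.Size([1, 32, 64, 64])
```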
@@ -16,6 +16,8 @@ from typing import Any, List, Tuple

 import torch

+from ...configuration_utils import FrozenDict
+from ...guiders import ClassifierFreeGuidance
 from ...models import Flux2Transformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import is_torch_xla_available, logging
@@ -25,8 +27,8 @@ from ..modular_pipeline import (
     ModularPipelineBlocks,
     PipelineState,
 )
-from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
-from .modular_pipeline import Flux2ModularPipeline
+from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
+from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline


 if is_torch_xla_available():
@@ -134,6 +136,229 @@ class Flux2LoopDenoiser(ModularPipelineBlocks):
         return components, block_state


+# same as Flux2LoopDenoiser but guidance=None
+class Flux2KleinLoopDenoiser(ModularPipelineBlocks):
+    model_name = "flux2-klein"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [ComponentSpec("transformer", Flux2Transformer2DModel)]
+
+    @property
+    def description(self) -> str:
+        return (
+            "Step within the denoising loop that denoises the latents for Flux2. "
+            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+            "object (e.g. `Flux2DenoiseLoopWrapper`)"
+        )
+
+    @property
+    def inputs(self) -> List[Tuple[str, Any]]:
+        return [
+            InputParam("joint_attention_kwargs"),
+            InputParam(
+                "latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The latents to denoise. Shape: (B, seq_len, C)",
+            ),
+            InputParam(
+                "image_latents",
+                type_hint=torch.Tensor,
+                description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)",
+            ),
+            InputParam(
+                "image_latent_ids",
+                type_hint=torch.Tensor,
+                description="Position IDs for image latents. Shape: (B, img_seq_len, 4)",
+            ),
+            InputParam(
+                "prompt_embeds",
+                required=True,
+                type_hint=torch.Tensor,
+                description="Text embeddings from Qwen3",
+            ),
+            InputParam(
+                "txt_ids",
+                required=True,
+                type_hint=torch.Tensor,
+                description="4D position IDs for text tokens (T, H, W, L)",
+            ),
+            InputParam(
+                "latent_ids",
+                required=True,
+                type_hint=torch.Tensor,
+                description="4D position IDs for latent tokens (T, H, W, L)",
+            ),
+        ]
+
+    @torch.no_grad()
+    def __call__(
+        self, components: Flux2KleinModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
+    ) -> PipelineState:
+        latents = block_state.latents
+        latent_model_input = latents.to(components.transformer.dtype)
+        img_ids = block_state.latent_ids
+
+        image_latents = getattr(block_state, "image_latents", None)
+        if image_latents is not None:
+            latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype)
+            image_latent_ids = block_state.image_latent_ids
+            img_ids = torch.cat([img_ids, image_latent_ids], dim=1)
+
+        timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+        noise_pred = components.transformer(
+            hidden_states=latent_model_input,
+            timestep=timestep / 1000,
+            guidance=None,
+            encoder_hidden_states=block_state.prompt_embeds,
+            txt_ids=block_state.txt_ids,
+            img_ids=img_ids,
+            joint_attention_kwargs=block_state.joint_attention_kwargs,
+            return_dict=False,
+        )[0]
+
+        noise_pred = noise_pred[:, : latents.size(1)]
+        block_state.noise_pred = noise_pred
+
+        return components, block_state
+
+
+# support CFG for Flux2-Klein base model
+class Flux2KleinBaseLoopDenoiser(ModularPipelineBlocks):
+    model_name = "flux2-klein"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("transformer", Flux2Transformer2DModel),
+            ComponentSpec(
+                "guider",
+                ClassifierFreeGuidance,
+                config=FrozenDict({"guidance_scale": 4.0}),
+                default_creation_method="from_config",
+            ),
+        ]
+
+    @property
+    def expected_configs(self) -> List[ConfigSpec]:
+        return [
+            ConfigSpec(name="is_distilled", default=False),
+        ]
+
+    @property
+    def description(self) -> str:
+        return (
+            "Step within the denoising loop that denoises the latents for Flux2. "
+            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+            "object (e.g. `Flux2DenoiseLoopWrapper`)"
+        )
+
+    @property
+    def inputs(self) -> List[Tuple[str, Any]]:
+        return [
+            InputParam("joint_attention_kwargs"),
+            InputParam(
+                "latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The latents to denoise. Shape: (B, seq_len, C)",
+            ),
+            InputParam(
+                "image_latents",
+                type_hint=torch.Tensor,
+                description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)",
+            ),
+            InputParam(
+                "image_latent_ids",
+                type_hint=torch.Tensor,
+                description="Position IDs for image latents. Shape: (B, img_seq_len, 4)",
+            ),
+            InputParam(
+                "prompt_embeds",
+                required=True,
+                type_hint=torch.Tensor,
+                description="Text embeddings from Qwen3",
+            ),
+            InputParam(
+                "negative_prompt_embeds",
+                required=False,
+                type_hint=torch.Tensor,
+                description="Negative text embeddings from Qwen3",
+            ),
+            InputParam(
+                "txt_ids",
+                required=True,
+                type_hint=torch.Tensor,
+                description="4D position IDs for text tokens (T, H, W, L)",
+            ),
+            InputParam(
+                "negative_txt_ids",
+                required=False,
+                type_hint=torch.Tensor,
+                description="4D position IDs for negative text tokens (T, H, W, L)",
+            ),
+            InputParam(
+                "latent_ids",
+                required=True,
+                type_hint=torch.Tensor,
+                description="4D position IDs for latent tokens (T, H, W, L)",
+            ),
+        ]
+
+    @torch.no_grad()
+    def __call__(
+        self, components: Flux2KleinModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
+    ) -> PipelineState:
+        latents = block_state.latents
+        latent_model_input = latents.to(components.transformer.dtype)
+        img_ids = block_state.latent_ids
+
+        image_latents = getattr(block_state, "image_latents", None)
+        if image_latents is not None:
+            latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype)
+            image_latent_ids = block_state.image_latent_ids
+            img_ids = torch.cat([img_ids, image_latent_ids], dim=1)
+
+        timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+        guider_inputs = {
+            "encoder_hidden_states": (
+                getattr(block_state, "prompt_embeds", None),
+                getattr(block_state, "negative_prompt_embeds", None),
+            ),
+            "txt_ids": (
+                getattr(block_state, "txt_ids", None),
+                getattr(block_state, "negative_txt_ids", None),
+            ),
+        }
+
+        components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
+        guider_state = components.guider.prepare_inputs(guider_inputs)
+
+        for guider_state_batch in guider_state:
+            components.guider.prepare_models(components.transformer)
+            cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
+
+            noise_pred = components.transformer(
+                hidden_states=latent_model_input,
+                timestep=timestep / 1000,
+                guidance=None,
+                img_ids=img_ids,
+                joint_attention_kwargs=block_state.joint_attention_kwargs,
+                return_dict=False,
+                **cond_kwargs,
+            )[0]
+            guider_state_batch.noise_pred = noise_pred[:, : latents.size(1)]
+            components.guider.cleanup_models(components.transformer)
+
+        # perform guidance
+        block_state.noise_pred = components.guider(guider_state)[0]
+
+        return components, block_state
+
+
 class Flux2LoopAfterDenoiser(ModularPipelineBlocks):
     model_name = "flux2"

@@ -250,3 +475,35 @@ class Flux2DenoiseStep(Flux2DenoiseLoopWrapper):
             " - `Flux2LoopAfterDenoiser`\n"
             "This block supports both text-to-image and image-conditioned generation."
         )
+
+
+class Flux2KleinDenoiseStep(Flux2DenoiseLoopWrapper):
+    block_classes = [Flux2KleinLoopDenoiser, Flux2LoopAfterDenoiser]
+    block_names = ["denoiser", "after_denoiser"]
+
+    @property
+    def description(self) -> str:
+        return (
+            "Denoise step that iteratively denoises the latents for Flux2. \n"
+            "Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n"
+            "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
+            " - `Flux2KleinLoopDenoiser`\n"
+            " - `Flux2LoopAfterDenoiser`\n"
+            "This block supports both text-to-image and image-conditioned generation."
+        )
+
+
+class Flux2KleinBaseDenoiseStep(Flux2DenoiseLoopWrapper):
+    block_classes = [Flux2KleinBaseLoopDenoiser, Flux2LoopAfterDenoiser]
+    block_names = ["denoiser", "after_denoiser"]
+
+    @property
+    def description(self) -> str:
+        return (
+            "Denoise step that iteratively denoises the latents for Flux2. \n"
+            "Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n"
+            "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
+            " - `Flux2KleinBaseLoopDenoiser`\n"
+            " - `Flux2LoopAfterDenoiser`\n"
+            "This block supports both text-to-image and image-conditioned generation."
+        )
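A standalone sketch of the combination step the `guider` component performs after `Flux2KleinBaseLoopDenoiser` has produced conditional and unconditional predictions. This is the standard classifier-free-guidance formula; the actual combination lives inside `ClassifierFreeGuidance` (configured here with `guidance_scale` 4.0), and the shapes below are illustrative.

```python
import torch

guidance_scale = 4.0
noise_pred_cond = torch.randn(1, 4096, 64)    # prediction with prompt_embeds
noise_pred_uncond = torch.randn(1, 4096, 64)  # prediction with negative_prompt_embeds

# Standard CFG: push the prediction away from the unconditional direction.
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
print(noise_pred.shape)  # torch.Size([1, 4096, 64])
```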
@@ -15,13 +15,15 @@
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import AutoProcessor, Mistral3ForConditionalGeneration
|
||||
from transformers import AutoProcessor, Mistral3ForConditionalGeneration, Qwen2TokenizerFast, Qwen3ForCausalLM
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...guiders import ClassifierFreeGuidance
|
||||
from ...models import AutoencoderKLFlux2
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import Flux2ModularPipeline
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -79,10 +81,8 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("prompt"),
|
||||
InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
|
||||
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
|
||||
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(10, 20, 30), required=False),
|
||||
InputParam("joint_attention_kwargs"),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -99,14 +99,7 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
|
||||
@staticmethod
|
||||
def check_inputs(block_state):
|
||||
prompt = block_state.prompt
|
||||
prompt_embeds = getattr(block_state, "prompt_embeds", None)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
|
||||
"Please make sure to only forward one of the two."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
@staticmethod
|
||||
@@ -165,10 +158,6 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
|
||||
|
||||
block_state.device = components._execution_device
|
||||
|
||||
if block_state.prompt_embeds is not None:
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
prompt = block_state.prompt
|
||||
if prompt is None:
|
||||
prompt = ""
|
||||
@@ -205,7 +194,6 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("prompt"),
|
||||
InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -222,15 +210,8 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
|
||||
@staticmethod
|
||||
def check_inputs(block_state):
|
||||
prompt = block_state.prompt
|
||||
prompt_embeds = getattr(block_state, "prompt_embeds", None)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
|
||||
"Please make sure to only forward one of the two."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
|
||||
@@ -244,10 +225,6 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
|
||||
|
||||
block_state.device = components._execution_device
|
||||
|
||||
if block_state.prompt_embeds is not None:
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
prompt = block_state.prompt
|
||||
if prompt is None:
|
||||
prompt = ""
|
||||
@@ -270,6 +247,289 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||


class Flux2KleinTextEncoderStep(ModularPipelineBlocks):
model_name = "flux2-klein"

@property
def description(self) -> str:
return "Text Encoder step that generates text embeddings using Qwen3 to guide the image generation"

@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("text_encoder", Qwen3ForCausalLM),
ComponentSpec("tokenizer", Qwen2TokenizerFast),
]

@property
def expected_configs(self) -> List[ConfigSpec]:
return [
ConfigSpec(name="is_distilled", default=True),
]

@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(9, 18, 27), required=False),
]

@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"prompt_embeds",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Text embeddings from qwen3 used to guide the image generation",
),
]

@staticmethod
def check_inputs(block_state):
prompt = block_state.prompt

if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline._get_qwen3_prompt_embeds
def _get_qwen3_prompt_embeds(
text_encoder: Qwen3ForCausalLM,
tokenizer: Qwen2TokenizerFast,
prompt: Union[str, List[str]],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
max_sequence_length: int = 512,
hidden_states_layers: List[int] = (9, 18, 27),
):
dtype = text_encoder.dtype if dtype is None else dtype
device = text_encoder.device if device is None else device

prompt = [prompt] if isinstance(prompt, str) else prompt

all_input_ids = []
all_attention_masks = []

for single_prompt in prompt:
messages = [{"role": "user", "content": single_prompt}]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
inputs = tokenizer(
text,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_sequence_length,
)

all_input_ids.append(inputs["input_ids"])
all_attention_masks.append(inputs["attention_mask"])

input_ids = torch.cat(all_input_ids, dim=0).to(device)
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)

# Forward pass through the model
output = text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)

# Only use outputs from intermediate layers and stack them
out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
out = out.to(dtype=dtype, device=device)

batch_size, num_channels, seq_len, hidden_dim = out.shape
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)

return prompt_embeds

@torch.no_grad()
def __call__(self, components: Flux2KleinModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
self.check_inputs(block_state)

device = components._execution_device

prompt = block_state.prompt
if prompt is None:
prompt = ""
prompt = [prompt] if isinstance(prompt, str) else prompt

block_state.prompt_embeds = self._get_qwen3_prompt_embeds(
text_encoder=components.text_encoder,
tokenizer=components.tokenizer,
prompt=prompt,
device=device,
max_sequence_length=block_state.max_sequence_length,
hidden_states_layers=block_state.text_encoder_out_layers,
)

self.set_block_state(state, block_state)
return components, state
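
# A worked shape example for `_get_qwen3_prompt_embeds` (sizes are
# illustrative, not taken from the source): with 2 prompts,
# max_sequence_length=512, a hidden size of 2048 and
# hidden_states_layers=(9, 18, 27), the stacked `out` tensor has shape
# (2, 3, 512, 2048); the permute/reshape concatenates the three layers per
# token, giving `prompt_embeds` of shape (2, 512, 3 * 2048) = (2, 512, 6144).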


class Flux2KleinBaseTextEncoderStep(ModularPipelineBlocks):
model_name = "flux2-klein"

@property
def description(self) -> str:
return "Text Encoder step that generates text embeddings using Qwen3 to guide the image generation"

@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("text_encoder", Qwen3ForCausalLM),
ComponentSpec("tokenizer", Qwen2TokenizerFast),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 4.0}),
default_creation_method="from_config",
),
]

@property
def expected_configs(self) -> List[ConfigSpec]:
return [
ConfigSpec(name="is_distilled", default=False),
]

@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(9, 18, 27), required=False),
]

@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"prompt_embeds",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Text embeddings from qwen3 used to guide the image generation",
),
OutputParam(
"negative_prompt_embeds",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Negative text embeddings from qwen3 used to guide the image generation",
),
]

@staticmethod
def check_inputs(block_state):
prompt = block_state.prompt

if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline._get_qwen3_prompt_embeds
def _get_qwen3_prompt_embeds(
text_encoder: Qwen3ForCausalLM,
tokenizer: Qwen2TokenizerFast,
prompt: Union[str, List[str]],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
max_sequence_length: int = 512,
hidden_states_layers: List[int] = (9, 18, 27),
):
dtype = text_encoder.dtype if dtype is None else dtype
device = text_encoder.device if device is None else device

prompt = [prompt] if isinstance(prompt, str) else prompt

all_input_ids = []
all_attention_masks = []

for single_prompt in prompt:
messages = [{"role": "user", "content": single_prompt}]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
inputs = tokenizer(
text,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_sequence_length,
)

all_input_ids.append(inputs["input_ids"])
all_attention_masks.append(inputs["attention_mask"])

input_ids = torch.cat(all_input_ids, dim=0).to(device)
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)

# Forward pass through the model
output = text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)

# Only use outputs from intermediate layers and stack them
out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
out = out.to(dtype=dtype, device=device)

batch_size, num_channels, seq_len, hidden_dim = out.shape
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)

return prompt_embeds

@torch.no_grad()
def __call__(self, components: Flux2KleinModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
self.check_inputs(block_state)

device = components._execution_device

prompt = block_state.prompt
if prompt is None:
prompt = ""
prompt = [prompt] if isinstance(prompt, str) else prompt

block_state.prompt_embeds = self._get_qwen3_prompt_embeds(
text_encoder=components.text_encoder,
tokenizer=components.tokenizer,
prompt=prompt,
device=device,
max_sequence_length=block_state.max_sequence_length,
hidden_states_layers=block_state.text_encoder_out_layers,
)

if components.requires_unconditional_embeds:
negative_prompt = [""] * len(prompt)
block_state.negative_prompt_embeds = self._get_qwen3_prompt_embeds(
text_encoder=components.text_encoder,
tokenizer=components.tokenizer,
prompt=negative_prompt,
device=device,
max_sequence_length=block_state.max_sequence_length,
hidden_states_layers=block_state.text_encoder_out_layers,
)
else:
block_state.negative_prompt_embeds = None

self.set_block_state(state, block_state)
return components, state


class Flux2VaeEncoderStep(ModularPipelineBlocks):
model_name = "flux2"


@@ -47,7 +47,7 @@ class Flux2TextInputStep(ModularPipelineBlocks):
required=True,
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Pre-generated text embeddings from Mistral3. Can be generated from text_encoder step.",
description="Pre-generated text embeddings. Can be generated from text_encoder step.",
),
]

@@ -89,6 +89,90 @@ class Flux2TextInputStep(ModularPipelineBlocks):
return components, state


class Flux2KleinBaseTextInputStep(ModularPipelineBlocks):
model_name = "flux2-klein"

@property
def description(self) -> str:
return (
"This step:\n"
" 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
" 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
)

@property
def inputs(self) -> List[InputParam]:
return [
InputParam("num_images_per_prompt", default=1),
InputParam(
"prompt_embeds",
required=True,
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Pre-generated text embeddings. Can be generated from text_encoder step.",
),
InputParam(
"negative_prompt_embeds",
required=False,
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Pre-generated negative text embeddings. Can be generated from text_encoder step.",
),
]

@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"batch_size",
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
),
OutputParam(
"dtype",
type_hint=torch.dtype,
description="Data type of model tensor inputs (determined by `prompt_embeds`)",
),
OutputParam(
"prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="denoiser_input_fields",
description="Text embeddings used to guide the image generation",
),
OutputParam(
"negative_prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="denoiser_input_fields",
description="Negative text embeddings used to guide the image generation",
),
]

@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)

block_state.batch_size = block_state.prompt_embeds.shape[0]
block_state.dtype = block_state.prompt_embeds.dtype

_, seq_len, _ = block_state.prompt_embeds.shape
block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
block_state.prompt_embeds = block_state.prompt_embeds.view(
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
)

if block_state.negative_prompt_embeds is not None:
_, seq_len, _ = block_state.negative_prompt_embeds.shape
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(
1, block_state.num_images_per_prompt, 1
)
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
)

self.set_block_state(state, block_state)
return components, state
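
# A worked example of the batch expansion above (illustrative sizes): with
# prompt_embeds of shape (2, 512, 6144) and num_images_per_prompt=3, the
# repeat(1, 3, 1) gives (2, 1536, 6144) and the view yields (6, 512, 6144),
# i.e. each prompt's embeddings are duplicated once per requested image.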


class Flux2ProcessImagesInputStep(ModularPipelineBlocks):
model_name = "flux2"


@@ -12,16 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import PIL.Image
import torch

from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from ..modular_pipeline_utils import InsertableDict, OutputParam
from .before_denoise import (
Flux2PrepareGuidanceStep,
Flux2PrepareImageLatentsStep,
Flux2PrepareLatentsStep,
Flux2RoPEInputsStep,
Flux2SetTimestepsStep,
)
from .decoders import Flux2DecodeStep
from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep
from .denoise import Flux2DenoiseStep
from .encoders import (
Flux2RemoteTextEncoderStep,
@@ -41,7 +47,6 @@ Flux2VaeEncoderBlocks = InsertableDict(
[
("preprocess", Flux2ProcessImagesInputStep()),
("encode", Flux2VaeEncoderStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
]
)

@@ -72,33 +77,56 @@ class Flux2AutoVaeEncoderStep(AutoPipelineBlocks):
)


Flux2BeforeDenoiseBlocks = InsertableDict(
Flux2CoreDenoiseBlocks = InsertableDict(
[
("input", Flux2TextInputStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_guidance", Flux2PrepareGuidanceStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2DenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
]
)


class Flux2BeforeDenoiseStep(SequentialPipelineBlocks):
class Flux2CoreDenoiseStep(SequentialPipelineBlocks):
model_name = "flux2"

block_classes = Flux2BeforeDenoiseBlocks.values()
block_names = Flux2BeforeDenoiseBlocks.keys()
block_classes = Flux2CoreDenoiseBlocks.values()
block_names = Flux2CoreDenoiseBlocks.keys()

@property
def description(self):
return "Before denoise step that prepares the inputs for the denoise step in Flux2 generation."
return (
"Core denoise step that performs the denoising process for Flux2-dev.\n"
" - `Flux2TextInputStep` (input) standardizes the text inputs (prompt_embeds) for the denoising step.\n"
" - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
" - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
" - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
" - `Flux2PrepareGuidanceStep` (prepare_guidance) prepares the guidance tensor for the denoising step.\n"
" - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n"
" - `Flux2DenoiseStep` (denoise) iteratively denoises the latents.\n"
" - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
)

@property
def outputs(self):
return [
OutputParam(
name="latents",
type_hint=torch.Tensor,
description="The latents from the denoising step.",
)
]


AUTO_BLOCKS = InsertableDict(
[
("text_encoder", Flux2TextEncoderStep()),
("text_input", Flux2TextInputStep()),
("vae_image_encoder", Flux2AutoVaeEncoderStep()),
("before_denoise", Flux2BeforeDenoiseStep()),
("denoise", Flux2DenoiseStep()),
("vae_encoder", Flux2AutoVaeEncoderStep()),
("denoise", Flux2CoreDenoiseStep()),
("decode", Flux2DecodeStep()),
]
)
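
# A minimal customization sketch (assuming `InsertableDict` copies like a
# regular dict and that `SequentialPipelineBlocks.from_blocks_dict` is
# available, as in the modular pipelines API): presets such as
# Flux2CoreDenoiseBlocks can be copied and edited before building a step.
#
#   blocks_dict = Flux2CoreDenoiseBlocks.copy()
#   blocks_dict.pop("prepare_guidance")  # e.g. drop the guidance tensor step
#   custom_denoise = SequentialPipelineBlocks.from_blocks_dict(blocks_dict)
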
@@ -107,10 +135,8 @@ AUTO_BLOCKS = InsertableDict(
REMOTE_AUTO_BLOCKS = InsertableDict(
[
("text_encoder", Flux2RemoteTextEncoderStep()),
("text_input", Flux2TextInputStep()),
("vae_image_encoder", Flux2AutoVaeEncoderStep()),
("before_denoise", Flux2BeforeDenoiseStep()),
("denoise", Flux2DenoiseStep()),
("vae_encoder", Flux2AutoVaeEncoderStep()),
("denoise", Flux2CoreDenoiseStep()),
("decode", Flux2DecodeStep()),
]
)
@@ -130,6 +156,16 @@ class Flux2AutoBlocks(SequentialPipelineBlocks):
"- For image-conditioned generation, you need to provide `image` (list of PIL images)."
)

@property
def outputs(self):
return [
OutputParam(
name="images",
type_hint=List[PIL.Image.Image],
description="The images from the decoding step.",
)
]


TEXT2IMAGE_BLOCKS = InsertableDict(
[
@@ -137,8 +173,10 @@ TEXT2IMAGE_BLOCKS = InsertableDict(
("text_input", Flux2TextInputStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_guidance", Flux2PrepareGuidanceStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2DenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
("decode", Flux2DecodeStep()),
]
)
@@ -152,8 +190,10 @@ IMAGE_CONDITIONED_BLOCKS = InsertableDict(
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_guidance", Flux2PrepareGuidanceStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2DenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
("decode", Flux2DecodeStep()),
]
)
@@ -0,0 +1,232 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import PIL.Image
import torch

from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict, OutputParam
from .before_denoise import (
Flux2KleinBaseRoPEInputsStep,
Flux2PrepareImageLatentsStep,
Flux2PrepareLatentsStep,
Flux2RoPEInputsStep,
Flux2SetTimestepsStep,
)
from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep
from .denoise import Flux2KleinBaseDenoiseStep, Flux2KleinDenoiseStep
from .encoders import (
Flux2KleinBaseTextEncoderStep,
Flux2KleinTextEncoderStep,
Flux2VaeEncoderStep,
)
from .inputs import (
Flux2KleinBaseTextInputStep,
Flux2ProcessImagesInputStep,
Flux2TextInputStep,
)


logger = logging.get_logger(__name__) # pylint: disable=invalid-name

################
# VAE encoder
################

Flux2KleinVaeEncoderBlocks = InsertableDict(
[
("preprocess", Flux2ProcessImagesInputStep()),
("encode", Flux2VaeEncoderStep()),
]
)


class Flux2KleinVaeEncoderSequentialStep(SequentialPipelineBlocks):
model_name = "flux2"

block_classes = Flux2KleinVaeEncoderBlocks.values()
block_names = Flux2KleinVaeEncoderBlocks.keys()

@property
def description(self) -> str:
return "VAE encoder step that preprocesses and encodes the image inputs into their latent representations."


class Flux2KleinAutoVaeEncoderStep(AutoPipelineBlocks):
block_classes = [Flux2KleinVaeEncoderSequentialStep]
block_names = ["img_conditioning"]
block_trigger_inputs = ["image"]

@property
def description(self):
return (
"VAE encoder step that encodes the image inputs into their latent representations.\n"
"This is an auto pipeline block that works for image conditioning tasks.\n"
" - `Flux2KleinVaeEncoderSequentialStep` is used when `image` is provided.\n"
" - If `image` is not provided, step will be skipped."
)


###
### Core denoise
###

Flux2KleinCoreDenoiseBlocks = InsertableDict(
[
("input", Flux2TextInputStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2KleinDenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
]
)


class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "flux2-klein"

block_classes = Flux2KleinCoreDenoiseBlocks.values()
block_names = Flux2KleinCoreDenoiseBlocks.keys()

@property
def description(self):
return (
"Core denoise step that performs the denoising process for Flux2-Klein (distilled model).\n"
" - `Flux2TextInputStep` (input) standardizes the text inputs (prompt_embeds) for the denoising step.\n"
" - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
" - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
" - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
" - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n"
" - `Flux2KleinDenoiseStep` (denoise) iteratively denoises the latents.\n"
" - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
)

@property
def outputs(self):
return [
OutputParam(
name="latents",
type_hint=torch.Tensor,
description="The latents from the denoising step.",
)
]


Flux2KleinBaseCoreDenoiseBlocks = InsertableDict(
[
("input", Flux2KleinBaseTextInputStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_rope_inputs", Flux2KleinBaseRoPEInputsStep()),
("denoise", Flux2KleinBaseDenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
]
)


class Flux2KleinBaseCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "flux2-klein"
block_classes = Flux2KleinBaseCoreDenoiseBlocks.values()
block_names = Flux2KleinBaseCoreDenoiseBlocks.keys()

@property
def description(self):
return "Core denoise step that performs the denoising process for Flux2-Klein (base model)."
return (
"Core denoise step that performs the denoising process for Flux2-Klein (base model).\n"
" - `Flux2KleinBaseTextInputStep` (input) standardizes the text inputs (prompt_embeds + negative_prompt_embeds) for the denoising step.\n"
" - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
" - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
" - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
" - `Flux2KleinBaseRoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids + negative_txt_ids) for the denoising step.\n"
" - `Flux2KleinBaseDenoiseStep` (denoise) iteratively denoises the latents using Classifier-Free Guidance.\n"
" - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
)

@property
def outputs(self):
return [
OutputParam(
name="latents",
type_hint=torch.Tensor,
description="The latents from the denoising step.",
)
]


###
### Auto blocks
###
class Flux2KleinAutoBlocks(SequentialPipelineBlocks):
model_name = "flux2-klein"
block_classes = [
Flux2KleinTextEncoderStep(),
Flux2KleinAutoVaeEncoderStep(),
Flux2KleinCoreDenoiseStep(),
Flux2DecodeStep(),
]
block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]

@property
def description(self):
return (
"Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein.\n"
+ " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n"
+ " - for text-to-image generation, all you need to provide is `prompt`.\n"
)

@property
def outputs(self):
return [
OutputParam(
name="images",
type_hint=List[PIL.Image.Image],
description="The images from the decoding step.",
)
]


class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks):
model_name = "flux2-klein"
block_classes = [
Flux2KleinBaseTextEncoderStep(),
Flux2KleinAutoVaeEncoderStep(),
Flux2KleinBaseCoreDenoiseStep(),
Flux2DecodeStep(),
]
block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]

@property
def description(self):
return (
"Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein (base model).\n"
+ " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n"
+ " - for text-to-image generation, all you need to provide is `prompt`.\n"
)

@property
def outputs(self):
return [
OutputParam(
name="images",
type_hint=List[PIL.Image.Image],
description="The images from the decoding step.",
)
]
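
# A minimal usage sketch (the repo id is a placeholder and the method names
# follow the modular pipelines API as an assumption): the auto blocks are
# turned into a runnable pipeline, components are loaded, and the pipeline is
# called with a prompt.
#
#   blocks = Flux2KleinAutoBlocks()
#   pipe = blocks.init_pipeline("<modular-repo-id>")
#   pipe.load_components(torch_dtype=torch.bfloat16)
#   images = pipe(prompt="a photo of a cat", output="images")
#
# For image-conditioned generation, additionally pass `image` (a list of PIL
# images) so that `Flux2KleinAutoVaeEncoderStep` is triggered.
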
@@ -13,6 +13,8 @@
# limitations under the License.


from typing import Any, Dict, Optional

from ...loaders import Flux2LoraLoaderMixin
from ...utils import logging
from ..modular_pipeline import ModularPipeline
@@ -55,3 +57,56 @@ class Flux2ModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
if getattr(self, "transformer", None):
num_channels_latents = self.transformer.config.in_channels // 4
return num_channels_latents


class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
"""
A ModularPipeline for Flux2-Klein.

> [!WARNING] > This is an experimental feature and is likely to change in the future.
"""

default_blocks_name = "Flux2KleinBaseAutoBlocks"

def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
if config_dict is not None and "is_distilled" in config_dict and config_dict["is_distilled"]:
return "Flux2KleinAutoBlocks"
else:
return "Flux2KleinBaseAutoBlocks"
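
# Example: a repo config containing {"is_distilled": true} resolves to
# "Flux2KleinAutoBlocks" (the distilled path without CFG), while any other
# config, or no config at all, falls back to "Flux2KleinBaseAutoBlocks".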

@property
def default_height(self):
return self.default_sample_size * self.vae_scale_factor

@property
def default_width(self):
return self.default_sample_size * self.vae_scale_factor

@property
def default_sample_size(self):
return 128

@property
def vae_scale_factor(self):
vae_scale_factor = 8
if getattr(self, "vae", None) is not None:
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
return vae_scale_factor

@property
def num_channels_latents(self):
num_channels_latents = 32
if getattr(self, "transformer", None):
num_channels_latents = self.transformer.config.in_channels // 4
return num_channels_latents

@property
def requires_unconditional_embeds(self):
if hasattr(self.config, "is_distilled") and self.config.is_distilled:
return False

requires_unconditional_embeds = False
if hasattr(self, "guider") and self.guider is not None:
requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1

return requires_unconditional_embeds
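
# Example: on the base (non-distilled) checkpoint with the default
# ClassifierFreeGuidance guider (guidance_scale=4.0), the guider is enabled
# and num_conditions == 2, so this property returns True and the text encoder
# step also produces negative_prompt_embeds; on a distilled checkpoint the
# property short-circuits to False regardless of the guider.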

@@ -59,6 +59,7 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
("flux", "FluxModularPipeline"),
("flux-kontext", "FluxKontextModularPipeline"),
("flux2", "Flux2ModularPipeline"),
("flux2-klein", "Flux2KleinModularPipeline"),
("qwenimage", "QwenImageModularPipeline"),
("qwenimage-edit", "QwenImageEditModularPipeline"),
("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),

@@ -18,6 +18,7 @@ from collections import OrderedDict
from dataclasses import dataclass, field, fields
from typing import Any, Dict, List, Literal, Optional, Type, Union

import PIL.Image
import torch

from ..configuration_utils import ConfigMixin, FrozenDict
@@ -323,11 +324,192 @@ class ConfigSpec:
description: Optional[str] = None


# YiYi Notes: both inputs and intermediate_inputs are InputParam objects
# however some fields are not relevant for intermediate_inputs
# e.g. unlike inputs, required only used in docstring for intermediate_inputs, we do not check if a required intermediate inputs is passed
# default is not used for intermediate_inputs, we only use default from inputs, so it is ignored if it is set for intermediate_inputs
# -> should we use different class for inputs and intermediate_inputs?
# ======================================================
# InputParam and OutputParam templates
# ======================================================

INPUT_PARAM_TEMPLATES = {
"prompt": {
"type_hint": str,
"required": True,
"description": "The prompt or prompts to guide image generation.",
},
"negative_prompt": {
"type_hint": str,
"description": "The prompt or prompts not to guide the image generation.",
},
"max_sequence_length": {
"type_hint": int,
"default": 512,
"description": "Maximum sequence length for prompt encoding.",
},
"height": {
"type_hint": int,
"description": "The height in pixels of the generated image.",
},
"width": {
"type_hint": int,
"description": "The width in pixels of the generated image.",
},
"num_inference_steps": {
"type_hint": int,
"default": 50,
"description": "The number of denoising steps.",
},
"num_images_per_prompt": {
"type_hint": int,
"default": 1,
"description": "The number of images to generate per prompt.",
},
"generator": {
"type_hint": torch.Generator,
"description": "Torch generator for deterministic generation.",
},
"sigmas": {
"type_hint": List[float],
"description": "Custom sigmas for the denoising process.",
},
"strength": {
"type_hint": float,
"default": 0.9,
"description": "Strength for img2img/inpainting.",
},
"image": {
"type_hint": Union[PIL.Image.Image, List[PIL.Image.Image]],
"required": True,
"description": "Reference image(s) for denoising. Can be a single image or list of images.",
},
"latents": {
"type_hint": torch.Tensor,
"description": "Pre-generated noisy latents for image generation.",
},
"timesteps": {
"type_hint": torch.Tensor,
"description": "Timesteps for the denoising process.",
},
"output_type": {
"type_hint": str,
"default": "pil",
"description": "Output format: 'pil', 'np', 'pt'.",
},
"attention_kwargs": {
"type_hint": Dict[str, Any],
"description": "Additional kwargs for attention processors.",
},
"denoiser_input_fields": {
"name": None,
"kwargs_type": "denoiser_input_fields",
"description": "conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
},
# inpainting
"mask_image": {
"type_hint": PIL.Image.Image,
"required": True,
"description": "Mask image for inpainting.",
},
"padding_mask_crop": {
"type_hint": int,
"description": "Padding for mask cropping in inpainting.",
},
# controlnet
"control_image": {
"type_hint": PIL.Image.Image,
"required": True,
"description": "Control image for ControlNet conditioning.",
},
"control_guidance_start": {
"type_hint": float,
"default": 0.0,
"description": "When to start applying ControlNet.",
},
"control_guidance_end": {
"type_hint": float,
"default": 1.0,
"description": "When to stop applying ControlNet.",
},
"controlnet_conditioning_scale": {
"type_hint": float,
"default": 1.0,
"description": "Scale for ControlNet conditioning.",
},
"layers": {
"type_hint": int,
"default": 4,
"description": "Number of layers to extract from the image",
},
# common intermediate inputs
"prompt_embeds": {
"type_hint": torch.Tensor,
"required": True,
"description": "text embeddings used to guide the image generation. Can be generated from text_encoder step.",
},
"prompt_embeds_mask": {
"type_hint": torch.Tensor,
"required": True,
"description": "mask for the text embeddings. Can be generated from text_encoder step.",
},
"negative_prompt_embeds": {
"type_hint": torch.Tensor,
"description": "negative text embeddings used to guide the image generation. Can be generated from text_encoder step.",
},
"negative_prompt_embeds_mask": {
"type_hint": torch.Tensor,
"description": "mask for the negative text embeddings. Can be generated from text_encoder step.",
},
"image_latents": {
"type_hint": torch.Tensor,
"required": True,
"description": "image latents used to guide the image generation. Can be generated from vae_encoder step.",
},
"batch_size": {
"type_hint": int,
"default": 1,
"description": "Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
},
"dtype": {
"type_hint": torch.dtype,
"default": torch.float32,
"description": "The dtype of the model inputs, can be generated in input step.",
},
}

OUTPUT_PARAM_TEMPLATES = {
"images": {
"type_hint": List[PIL.Image.Image],
"description": "Generated images.",
},
"latents": {
"type_hint": torch.Tensor,
"description": "Denoised latents.",
},
# intermediate outputs
"prompt_embeds": {
"type_hint": torch.Tensor,
"kwargs_type": "denoiser_input_fields",
"description": "The prompt embeddings.",
},
"prompt_embeds_mask": {
"type_hint": torch.Tensor,
"kwargs_type": "denoiser_input_fields",
"description": "The encoder attention mask.",
},
"negative_prompt_embeds": {
"type_hint": torch.Tensor,
"kwargs_type": "denoiser_input_fields",
"description": "The negative prompt embeddings.",
},
"negative_prompt_embeds_mask": {
"type_hint": torch.Tensor,
"kwargs_type": "denoiser_input_fields",
"description": "The negative prompt embeddings mask.",
},
"image_latents": {
"type_hint": torch.Tensor,
"description": "The latent representation of the input image.",
},
}


@dataclass
class InputParam:
"""Specification for an input parameter."""
@@ -337,11 +519,31 @@ class InputParam:
default: Any = None
required: bool = False
description: str = ""
kwargs_type: str = None # YiYi Notes: remove this feature (maybe)
kwargs_type: str = None

def __repr__(self):
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"

@classmethod
def template(cls, template_name: str, note: str = None, **overrides) -> "InputParam":
"""Get template for name if exists, otherwise raise ValueError."""
if template_name not in INPUT_PARAM_TEMPLATES:
raise ValueError(f"InputParam template for {template_name} not found")

template_kwargs = INPUT_PARAM_TEMPLATES[template_name].copy()

# Determine the actual param name:
# 1. From overrides if provided
# 2. From template if present
# 3. Fall back to template_name
name = overrides.pop("name", template_kwargs.pop("name", template_name))

if note and "description" in template_kwargs:
template_kwargs["description"] = f"{template_kwargs['description']} ({note})"

template_kwargs.update(overrides)
return cls(name=name, **template_kwargs)

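# A minimal usage sketch of the template mechanism (parameter choices are
# illustrative, not from the source): `template()` starts from the registered
# kwargs, appends an optional `note` to the description, then lets explicit
# overrides win.
#
#   InputParam.template("height")                         # plain template
#   InputParam.template("height", required=True)          # override a field
#   InputParam.template("strength", note="img2img only")  # note appended to the description
#   InputParam.template("latents", name="initial_noise")  # rename the parameter
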

@dataclass
class OutputParam:
@@ -350,13 +552,33 @@ class OutputParam:
name: str
type_hint: Any = None
description: str = ""
kwargs_type: str = None # YiYi notes: remove this feature (maybe)
kwargs_type: str = None

def __repr__(self):
return (
f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>"
)

@classmethod
def template(cls, template_name: str, note: str = None, **overrides) -> "OutputParam":
"""Get template for name if exists, otherwise raise ValueError."""
if template_name not in OUTPUT_PARAM_TEMPLATES:
raise ValueError(f"OutputParam template for {template_name} not found")

template_kwargs = OUTPUT_PARAM_TEMPLATES[template_name].copy()

# Determine the actual param name:
# 1. From overrides if provided
# 2. From template if present
# 3. Fall back to template_name
name = overrides.pop("name", template_kwargs.pop("name", template_name))

if note and "description" in template_kwargs:
template_kwargs["description"] = f"{template_kwargs['description']} ({note})"

template_kwargs.update(overrides)
return cls(name=name, **template_kwargs)


def format_inputs_short(inputs):
"""
@@ -509,10 +731,12 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115):
desc = re.sub(r"\[(.*?)\]\((https?://[^\s\)]+)\)", r"[\1](\2)", param.description)
wrapped_desc = wrap_text(desc, desc_indent, max_line_length)
param_str += f"\n{desc_indent}{wrapped_desc}"
else:
param_str += f"\n{desc_indent}TODO: Add description."

formatted_params.append(param_str)

return "\n\n".join(formatted_params)
return "\n".join(formatted_params)


def format_input_params(input_params, indent_level=4, max_line_length=115):
@@ -582,7 +806,7 @@ def format_components(components, indent_level=4, max_line_length=115, add_empty
loading_field_values = []
for field_name in component.loading_fields():
field_value = getattr(component, field_name)
if field_value is not None:
if field_value:
loading_field_values.append(f"{field_name}={field_value}")

# Add loading field information if available
@@ -669,17 +893,17 @@ def make_doc_string(
# Add description
if description:
desc_lines = description.strip().split("\n")
aligned_desc = "\n".join(" " + line for line in desc_lines)
aligned_desc = "\n".join(" " + line.rstrip() for line in desc_lines)
output += aligned_desc + "\n\n"

# Add components section if provided
if expected_components and len(expected_components) > 0:
components_str = format_components(expected_components, indent_level=2)
components_str = format_components(expected_components, indent_level=2, add_empty_lines=False)
output += components_str + "\n\n"

# Add configs section if provided
if expected_configs and len(expected_configs) > 0:
configs_str = format_configs(expected_configs, indent_level=2)
configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
output += configs_str + "\n\n"

# Add inputs section

@@ -118,7 +118,40 @@ def get_timesteps(scheduler, num_inference_steps, strength):
# ====================

# auto_docstring
class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
"""
Prepare initial random noise for the generation process

Components:
pachifier (`QwenImagePachifier`)

Inputs:
latents (`Tensor`, *optional*):
Pre-generated noisy latents for image generation.
height (`int`, *optional*):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
batch_size (`int`, *optional*, defaults to 1):
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
be generated in input step.
dtype (`dtype`, *optional*, defaults to torch.float32):
The dtype of the model inputs, can be generated in input step.

Outputs:
height (`int`):
if not set, updated to default value
width (`int`):
if not set, updated to default value
latents (`Tensor`):
The initial latents to use for the denoising process
"""

model_name = "qwenimage"

@property
@@ -134,28 +167,20 @@ class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("latents"),
InputParam(name="height"),
InputParam(name="width"),
InputParam(name="num_images_per_prompt", default=1),
InputParam(name="generator"),
InputParam(
name="batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
),
InputParam(
name="dtype",
required=True,
type_hint=torch.dtype,
description="The dtype of the model inputs, can be generated in input step.",
),
InputParam.template("latents"),
InputParam.template("height"),
InputParam.template("width"),
InputParam.template("num_images_per_prompt"),
InputParam.template("generator"),
InputParam.template("batch_size"),
InputParam.template("dtype"),
]

@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(name="height", type_hint=int, description="if not set, updated to default value"),
OutputParam(name="width", type_hint=int, description="if not set, updated to default value"),
OutputParam(
name="latents",
type_hint=torch.Tensor,
@@ -209,7 +234,42 @@ class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
return components, state
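
# A worked shape example (illustrative sizes): for a 1024x1024 image with a
# VAE scale factor of 8, the latent grid is 128x128; with QwenImage's 2x2
# patchification the packed sequence has 64 * 64 = 4096 tokens, so the noise
# prepared here has shape
# (batch_size * num_images_per_prompt, 4096, num_channels_latents * 4).
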
# auto_docstring
class QwenImageLayeredPrepareLatentsStep(ModularPipelineBlocks):
"""
Prepare initial random noise (B, layers+1, C, H, W) for the generation process

Components:
pachifier (`QwenImageLayeredPachifier`)

Inputs:
latents (`Tensor`, *optional*):
Pre-generated noisy latents for image generation.
height (`int`, *optional*):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image.
layers (`int`, *optional*, defaults to 4):
Number of layers to extract from the image
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
batch_size (`int`, *optional*, defaults to 1):
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
be generated in input step.
dtype (`dtype`, *optional*, defaults to torch.float32):
The dtype of the model inputs, can be generated in input step.

Outputs:
height (`int`):
if not set, updated to default value
width (`int`):
if not set, updated to default value
latents (`Tensor`):
The initial latents to use for the denoising process
"""

model_name = "qwenimage-layered"

@property
@@ -225,29 +285,21 @@ class QwenImageLayeredPrepareLatentsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("latents"),
InputParam(name="height"),
InputParam(name="width"),
InputParam(name="layers", default=4),
InputParam(name="num_images_per_prompt", default=1),
InputParam(name="generator"),
InputParam(
name="batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
),
InputParam(
name="dtype",
required=True,
type_hint=torch.dtype,
description="The dtype of the model inputs, can be generated in input step.",
),
InputParam.template("latents"),
InputParam.template("height"),
InputParam.template("width"),
InputParam.template("layers"),
InputParam.template("num_images_per_prompt"),
InputParam.template("generator"),
InputParam.template("batch_size"),
InputParam.template("dtype"),
]

@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(name="height", type_hint=int, description="if not set, updated to default value"),
OutputParam(name="width", type_hint=int, description="if not set, updated to default value"),
OutputParam(
name="latents",
type_hint=torch.Tensor,
@@ -301,7 +353,31 @@ class QwenImageLayeredPrepareLatentsStep(ModularPipelineBlocks):
return components, state
# auto_docstring
class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks):
"""
Step that adds noise to image latents for image-to-image/inpainting. Should be run after set_timesteps,
prepare_latents. Both noise and image latents should already be patchified.

Components:
scheduler (`FlowMatchEulerDiscreteScheduler`)

Inputs:
latents (`Tensor`):
The initial random noise, can be generated in the prepare latents step.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step. (Can be
generated from vae encoder and updated in input step.)
timesteps (`Tensor`):
The timesteps to use for the denoising process. Can be generated in set_timesteps step.

Outputs:
initial_noise (`Tensor`):
The initial random noise used for inpainting denoising.
latents (`Tensor`):
The scaled noisy latents to use for inpainting/image-to-image denoising.
"""

model_name = "qwenimage"

@property
@@ -323,12 +399,7 @@ class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks):
type_hint=torch.Tensor,
description="The initial random noise, can be generated in the prepare latents step.",
),
InputParam(
name="image_latents",
required=True,
type_hint=torch.Tensor,
description="The image latents to use for the denoising process. Can be generated in vae encoder and packed in input step.",
),
InputParam.template("image_latents", note="Can be generated from vae encoder and updated in input step."),
InputParam(
name="timesteps",
required=True,
@@ -345,6 +416,11 @@ class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks):
type_hint=torch.Tensor,
description="The initial random noise used for inpainting denoising.",
),
OutputParam(
name="latents",
type_hint=torch.Tensor,
description="The scaled noisy latents to use for inpainting/image-to-image denoising.",
),
]

@staticmethod
@@ -382,7 +458,29 @@ class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks):
return components, state
|
||||
class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks):
|
||||
"""
|
||||
Step that creates mask latents from preprocessed mask_image by interpolating to latent space.
|
||||
|
||||
Components:
|
||||
pachifier (`QwenImagePachifier`)
|
||||
|
||||
Inputs:
|
||||
processed_mask_image (`Tensor`):
|
||||
The processed mask to use for the inpainting process.
|
||||
height (`int`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`):
|
||||
The width in pixels of the generated image.
|
||||
dtype (`dtype`, *optional*, defaults to torch.float32):
|
||||
The dtype of the model inputs, can be generated in input step.
|
||||
|
||||
Outputs:
|
||||
mask (`Tensor`):
|
||||
The mask to use for the inpainting process.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
@@ -404,9 +502,9 @@ class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks):
|
||||
type_hint=torch.Tensor,
|
||||
description="The processed mask to use for the inpainting process.",
|
||||
),
|
||||
InputParam(name="height", required=True),
|
||||
InputParam(name="width", required=True),
|
||||
InputParam(name="dtype", required=True),
|
||||
InputParam.template("height", required=True),
|
||||
InputParam.template("width", required=True),
|
||||
InputParam.template("dtype"),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -450,7 +548,27 @@ class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks):
|
||||
# ====================
|
||||
|
||||
|
||||
# auto_docstring
class QwenImageSetTimestepsStep(ModularPipelineBlocks):
"""
Step that sets the scheduler's timesteps for text-to-image generation. Should be run after prepare latents step.

Components:
scheduler (`FlowMatchEulerDiscreteScheduler`)

Inputs:
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
latents (`Tensor`):
The initial random noisy latents for the denoising process. Can be generated in prepare latents step.

Outputs:
timesteps (`Tensor`):
The timesteps to use for the denoising process
"""

model_name = "qwenimage"

@property
@@ -466,13 +584,13 @@ class QwenImageSetTimestepsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="num_inference_steps", default=50),
InputParam(name="sigmas"),
InputParam.template("num_inference_steps"),
InputParam.template("sigmas"),
InputParam(
name="latents",
required=True,
type_hint=torch.Tensor,
description="The latents to use for the denoising process, used to calculate the image sequence length.",
description="The initial random noisy latents for the denoising process. Can be generated in prepare latents step.",
),
]

@@ -516,7 +634,27 @@ class QwenImageSetTimestepsStep(ModularPipelineBlocks):
return components, state
# auto_docstring
class QwenImageLayeredSetTimestepsStep(ModularPipelineBlocks):
"""
Set timesteps step for QwenImage Layered with custom mu calculation based on image_latents.

Components:
scheduler (`FlowMatchEulerDiscreteScheduler`)

Inputs:
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.

Outputs:
timesteps (`Tensor`):
The timesteps to use for the denoising process.
"""

model_name = "qwenimage-layered"

@property
@@ -532,15 +670,17 @@ class QwenImageLayeredSetTimestepsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("num_inference_steps", default=50, type_hint=int),
InputParam("sigmas", type_hint=List[float]),
InputParam("image_latents", required=True, type_hint=torch.Tensor),
InputParam.template("num_inference_steps"),
InputParam.template("sigmas"),
InputParam.template("image_latents"),
]

@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(name="timesteps", type_hint=torch.Tensor),
OutputParam(
name="timesteps", type_hint=torch.Tensor, description="The timesteps to use for the denoising process."
),
]

@torch.no_grad()
@@ -574,7 +714,32 @@ class QwenImageLayeredSetTimestepsStep(ModularPipelineBlocks):
return components, state
# auto_docstring
|
||||
class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
|
||||
"""
|
||||
Step that sets the scheduler's timesteps for image-to-image generation, and inpainting. Should be run after prepare
|
||||
latents step.
|
||||
|
||||
Components:
|
||||
scheduler (`FlowMatchEulerDiscreteScheduler`)
|
||||
|
||||
Inputs:
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
latents (`Tensor`):
|
||||
The latents to use for the denoising process. Can be generated in prepare latents step.
|
||||
strength (`float`, *optional*, defaults to 0.9):
|
||||
Strength for img2img/inpainting.
|
||||
|
||||
Outputs:
|
||||
timesteps (`Tensor`):
|
||||
The timesteps to use for the denoising process.
|
||||
num_inference_steps (`int`):
|
||||
The number of denoising steps to perform at inference time. Updated based on strength.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
@@ -590,15 +755,15 @@ class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="num_inference_steps", default=50),
|
||||
InputParam(name="sigmas"),
|
||||
InputParam.template("num_inference_steps"),
|
||||
InputParam.template("sigmas"),
|
||||
InputParam(
|
||||
name="latents",
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents to use for the denoising process, used to calculate the image sequence length.",
|
||||
description="The latents to use for the denoising process. Can be generated in prepare latents step.",
|
||||
),
|
||||
InputParam(name="strength", default=0.9),
|
||||
InputParam.template("strength", default=0.9),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -607,7 +772,12 @@ class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
|
||||
OutputParam(
|
||||
name="timesteps",
|
||||
type_hint=torch.Tensor,
|
||||
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||
description="The timesteps to use for the denoising process.",
|
||||
),
|
||||
OutputParam(
|
||||
name="num_inference_steps",
|
||||
type_hint=int,
|
||||
description="The number of denoising steps to perform at inference time. Updated based on strength.",
|
||||
),
|
||||
]
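The `strength` input follows the usual img2img convention: it determines how much of the noise schedule is actually run. A hedged sketch of the typical computation (the standard diffusers pattern; this block's exact logic may differ):

```python
def get_timesteps(timesteps, num_inference_steps: int, strength: float):
    # keep only the last `strength` fraction of the schedule: strength=0.9 skips
    # the first 10% of noise levels, so the input image is only partially renoised
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    return timesteps[t_start:], num_inference_steps - t_start
```

This is also why `num_inference_steps` is re-emitted as an output: the effective step count shrinks whenever `strength < 1.0`.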

@@ -654,7 +824,29 @@ class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
## RoPE inputs for denoiser


# auto_docstring
class QwenImageRoPEInputsStep(ModularPipelineBlocks):
"""
Step that prepares the RoPE inputs for the denoising process. Should be placed after prepare_latents step

Inputs:
batch_size (`int`, *optional*, defaults to 1):
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
be generated in input step.
height (`int`):
The height in pixels of the generated image.
width (`int`):
The width in pixels of the generated image.
prompt_embeds_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.

Outputs:
img_shapes (`List`):
The shapes of the image latents, used for RoPE calculation
"""

model_name = "qwenimage"

@property
@@ -666,11 +858,11 @@ class QwenImageRoPEInputsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="batch_size", required=True),
InputParam(name="height", required=True),
InputParam(name="width", required=True),
InputParam(name="prompt_embeds_mask"),
InputParam(name="negative_prompt_embeds_mask"),
InputParam.template("batch_size"),
InputParam.template("height", required=True),
InputParam.template("width", required=True),
InputParam.template("prompt_embeds_mask"),
InputParam.template("negative_prompt_embeds_mask"),
]

@property
@@ -702,7 +894,34 @@ class QwenImageRoPEInputsStep(ModularPipelineBlocks):
return components, state
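For orientation, `img_shapes` in the QwenImage pipelines is a list of per-prompt latent-grid shapes derived from the pixel dimensions. A sketch under the usual assumptions (vae_scale_factor=8 with 2x2 patchification; exact factors may vary per checkpoint):

```python
def compute_img_shapes(batch_size: int, height: int, width: int, vae_scale_factor: int = 8):
    # each latent token covers a (vae_scale_factor * 2) x (vae_scale_factor * 2) pixel patch
    latent_h = height // (vae_scale_factor * 2)
    latent_w = width // (vae_scale_factor * 2)
    return [(1, latent_h, latent_w)] * batch_size  # one (frames, h, w) triple per prompt
```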


# auto_docstring
class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
"""
Step that prepares the RoPE inputs for the denoising process. This is used in QwenImage Edit. Should be placed after
prepare_latents step

Inputs:
batch_size (`int`, *optional*, defaults to 1):
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
be generated in input step.
image_height (`int`):
The height of the reference image. Can be generated in input step.
image_width (`int`):
The width of the reference image. Can be generated in input step.
height (`int`):
The height in pixels of the generated image.
width (`int`):
The width in pixels of the generated image.
prompt_embeds_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.

Outputs:
img_shapes (`List`):
The shapes of the image latents, used for RoPE calculation
"""

model_name = "qwenimage"

@property
@@ -712,13 +931,23 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="batch_size", required=True),
InputParam(name="image_height", required=True),
InputParam(name="image_width", required=True),
InputParam(name="height", required=True),
InputParam(name="width", required=True),
InputParam(name="prompt_embeds_mask"),
InputParam(name="negative_prompt_embeds_mask"),
InputParam.template("batch_size"),
InputParam(
name="image_height",
required=True,
type_hint=int,
description="The height of the reference image. Can be generated in input step.",
),
InputParam(
name="image_width",
required=True,
type_hint=int,
description="The width of the reference image. Can be generated in input step.",
),
InputParam.template("height", required=True),
InputParam.template("width", required=True),
InputParam.template("prompt_embeds_mask"),
InputParam.template("negative_prompt_embeds_mask"),
]

@property
@@ -756,7 +985,39 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
return components, state


# auto_docstring
class QwenImageEditPlusRoPEInputsStep(ModularPipelineBlocks):
"""
Step that prepares the RoPE inputs for the denoising process. This is used in QwenImage Edit Plus.
Unlike Edit, Edit Plus handles lists of image_height/image_width for multiple reference images. Should be placed
after prepare_latents step.

Inputs:
batch_size (`int`, *optional*, defaults to 1):
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
be generated in input step.
image_height (`List`):
The heights of the reference images. Can be generated in input step.
image_width (`List`):
The widths of the reference images. Can be generated in input step.
height (`int`):
The height in pixels of the generated image.
width (`int`):
The width in pixels of the generated image.
prompt_embeds_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.

Outputs:
img_shapes (`List`):
The shapes of the image latents, used for RoPE calculation
txt_seq_lens (`List`):
The sequence lengths of the prompt embeds, used for RoPE calculation
negative_txt_seq_lens (`List`):
The sequence lengths of the negative prompt embeds, used for RoPE calculation
"""

model_name = "qwenimage-edit-plus"

@property
@@ -770,13 +1031,23 @@ class QwenImageEditPlusRoPEInputsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="batch_size", required=True),
InputParam(name="image_height", required=True, type_hint=List[int]),
InputParam(name="image_width", required=True, type_hint=List[int]),
InputParam(name="height", required=True),
InputParam(name="width", required=True),
InputParam(name="prompt_embeds_mask"),
InputParam(name="negative_prompt_embeds_mask"),
InputParam.template("batch_size"),
InputParam(
name="image_height",
required=True,
type_hint=List[int],
description="The heights of the reference images. Can be generated in input step.",
),
InputParam(
name="image_width",
required=True,
type_hint=List[int],
description="The widths of the reference images. Can be generated in input step.",
),
InputParam.template("height", required=True),
InputParam.template("width", required=True),
InputParam.template("prompt_embeds_mask"),
InputParam.template("negative_prompt_embeds_mask"),
]

@property
@@ -832,7 +1103,37 @@ class QwenImageEditPlusRoPEInputsStep(ModularPipelineBlocks):
return components, state


# auto_docstring
class QwenImageLayeredRoPEInputsStep(ModularPipelineBlocks):
"""
Step that prepares the RoPE inputs for the denoising process. Should be placed after prepare_latents step

Inputs:
batch_size (`int`, *optional*, defaults to 1):
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
be generated in input step.
layers (`int`, *optional*, defaults to 4):
Number of layers to extract from the image
height (`int`):
The height in pixels of the generated image.
width (`int`):
The width in pixels of the generated image.
prompt_embeds_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.

Outputs:
img_shapes (`List`):
The shapes of the image latents, used for RoPE calculation
txt_seq_lens (`List`):
The sequence lengths of the prompt embeds, used for RoPE calculation
negative_txt_seq_lens (`List`):
The sequence lengths of the negative prompt embeds, used for RoPE calculation
additional_t_cond (`Tensor`):
The additional t cond, used for RoPE calculation
"""

model_name = "qwenimage-layered"

@property
@@ -844,12 +1145,12 @@ class QwenImageLayeredRoPEInputsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="batch_size", required=True),
InputParam(name="layers", required=True),
InputParam(name="height", required=True),
InputParam(name="width", required=True),
InputParam(name="prompt_embeds_mask"),
InputParam(name="negative_prompt_embeds_mask"),
InputParam.template("batch_size"),
InputParam.template("layers"),
InputParam.template("height", required=True),
InputParam.template("width", required=True),
InputParam.template("prompt_embeds_mask"),
InputParam.template("negative_prompt_embeds_mask"),
]

@property
@@ -914,7 +1215,34 @@ class QwenImageLayeredRoPEInputsStep(ModularPipelineBlocks):

## ControlNet inputs for denoiser


# auto_docstring
class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
"""
Step that prepares inputs for controlnet. Insert before the Denoise Step, after set_timesteps step.

Components:
controlnet (`QwenImageControlNetModel`)

Inputs:
control_guidance_start (`float`, *optional*, defaults to 0.0):
When to start applying ControlNet.
control_guidance_end (`float`, *optional*, defaults to 1.0):
When to stop applying ControlNet.
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
Scale for ControlNet conditioning.
control_image_latents (`Tensor`):
The control image latents to use for the denoising process. Can be generated in controlnet vae encoder
step.
timesteps (`Tensor`):
The timesteps to use for the denoising process. Can be generated in set_timesteps step.

Outputs:
controlnet_keep (`List`):
The controlnet keep values
"""

model_name = "qwenimage"

@property
@@ -930,12 +1258,17 @@ class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("control_guidance_start", default=0.0),
InputParam("control_guidance_end", default=1.0),
InputParam("controlnet_conditioning_scale", default=1.0),
InputParam("control_image_latents", required=True),
InputParam.template("control_guidance_start"),
InputParam.template("control_guidance_end"),
InputParam.template("controlnet_conditioning_scale"),
InputParam(
"timesteps",
name="control_image_latents",
required=True,
type_hint=torch.Tensor,
description="The control image latents to use for the denoising process. Can be generated in controlnet vae encoder step.",
),
InputParam(
name="timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",

@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Union
from typing import Any, Dict, List

import numpy as np
import PIL
import torch

from ...configuration_utils import FrozenDict
@@ -31,7 +29,30 @@ logger = logging.get_logger(__name__)


# after denoising loop (unpack latents)


# auto_docstring
class QwenImageAfterDenoiseStep(ModularPipelineBlocks):
"""
Step that unpacks the latents from 3D tensor (batch_size, sequence_length, channels) into 5D tensor (batch_size,
channels, 1, height, width)

Components:
pachifier (`QwenImagePachifier`)

Inputs:
height (`int`):
The height in pixels of the generated image.
width (`int`):
The width in pixels of the generated image.
latents (`Tensor`):
The latents to decode, can be generated in the denoise step.

Outputs:
latents (`Tensor`):
The denoised latents unpacked to B, C, 1, H, W
"""

model_name = "qwenimage"

@property
@@ -49,13 +70,21 @@ class QwenImageAfterDenoiseStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="height", required=True),
InputParam(name="width", required=True),
InputParam.template("height", required=True),
InputParam.template("width", required=True),
InputParam(
name="latents",
required=True,
type_hint=torch.Tensor,
description="The latents to decode, can be generated in the denoise step",
description="The latents to decode, can be generated in the denoise step.",
),
]

@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name="latents", type_hint=torch.Tensor, description="The denoised latents unpacked to B, C, 1, H, W"
),
]

@@ -72,7 +101,29 @@ class QwenImageAfterDenoiseStep(ModularPipelineBlocks):
return components, state
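The 3D-to-5D unpacking is the inverse of the 2x2 patchification applied when the latents were packed. A plausible sketch of the reshape (the axis ordering is an assumption; the real pachifier may differ):

```python
import torch

def unpack_latents(latents: torch.Tensor, height: int, width: int, vae_scale_factor: int = 8) -> torch.Tensor:
    b, seq, channels = latents.shape                     # (B, h*w, C*4)
    h = height // (vae_scale_factor * 2)
    w = width // (vae_scale_factor * 2)
    latents = latents.view(b, h, w, channels // 4, 2, 2)
    latents = latents.permute(0, 3, 1, 4, 2, 5)          # (B, C, h, 2, w, 2)
    return latents.reshape(b, channels // 4, 1, h * 2, w * 2)  # (B, C, 1, H/8, W/8)
```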


# auto_docstring
class QwenImageLayeredAfterDenoiseStep(ModularPipelineBlocks):
"""
Unpack latents from (B, seq, C*4) to (B, C, layers+1, H, W) after denoising.

Components:
pachifier (`QwenImageLayeredPachifier`)

Inputs:
latents (`Tensor`):
The denoised latents to decode, can be generated in the denoise step.
height (`int`):
The height in pixels of the generated image.
width (`int`):
The width in pixels of the generated image.
layers (`int`, *optional*, defaults to 4):
Number of layers to extract from the image

Outputs:
latents (`Tensor`):
Denoised latents. (unpacked to B, C, layers+1, H, W)
"""

model_name = "qwenimage-layered"

@property
@@ -88,10 +139,21 @@ class QwenImageLayeredAfterDenoiseStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("latents", required=True, type_hint=torch.Tensor),
InputParam("height", required=True, type_hint=int),
InputParam("width", required=True, type_hint=int),
InputParam("layers", required=True, type_hint=int),
InputParam(
name="latents",
required=True,
type_hint=torch.Tensor,
description="The denoised latents to decode, can be generated in the denoise step.",
),
InputParam.template("height", required=True),
InputParam.template("width", required=True),
InputParam.template("layers"),
]

@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam.template("latents", note="unpacked to B, C, layers+1, H, W"),
]

@torch.no_grad()
@@ -112,7 +174,26 @@ class QwenImageLayeredAfterDenoiseStep(ModularPipelineBlocks):


# decode step


# auto_docstring
class QwenImageDecoderStep(ModularPipelineBlocks):
"""
Step that decodes the latents to images

Components:
vae (`AutoencoderKLQwenImage`)

Inputs:
latents (`Tensor`):
The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise
step.

Outputs:
images (`List`):
Generated images. (tensor output of the vae decoder.)
"""

model_name = "qwenimage"

@property
@@ -134,19 +215,13 @@ class QwenImageDecoderStep(ModularPipelineBlocks):
name="latents",
required=True,
type_hint=torch.Tensor,
description="The latents to decode, can be generated in the denoise step",
description="The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise step.",
),
]

@property
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(
"images",
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array",
)
]
def intermediate_outputs(self) -> List[OutputParam]:
return [OutputParam.template("images", note="tensor output of the vae decoder.")]

@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
@@ -176,7 +251,26 @@ class QwenImageDecoderStep(ModularPipelineBlocks):
return components, state
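For orientation, the decode itself is a thin wrapper around the VAE. A hedged sketch (QwenImage's autoencoder is video-shaped, so a singleton frame axis is dropped; normalization details vary by checkpoint and are simplified here):

```python
import torch

@torch.no_grad()
def decode_latents(vae, latents: torch.Tensor) -> torch.Tensor:
    # undo the training-time latent normalization; shown with a plain
    # scaling_factor for brevity (QwenImage checkpoints may instead use
    # per-channel latents_mean / latents_std)
    latents = latents.to(vae.dtype) / vae.config.scaling_factor
    frames = vae.decode(latents, return_dict=False)[0]  # (B, C, T=1, H, W)
    return frames[:, :, 0]                              # (B, C, H, W)
```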


# auto_docstring
class QwenImageLayeredDecoderStep(ModularPipelineBlocks):
"""
Decode unpacked latents (B, C, layers+1, H, W) into layer images.

Components:
vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`)

Inputs:
latents (`Tensor`):
The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise
step.
output_type (`str`, *optional*, defaults to pil):
Output format: 'pil', 'np', 'pt'.

Outputs:
images (`List`):
Generated images.
"""

model_name = "qwenimage-layered"

@property
@@ -198,14 +292,19 @@ class QwenImageLayeredDecoderStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("latents", required=True, type_hint=torch.Tensor),
InputParam("output_type", default="pil", type_hint=str),
InputParam(
name="latents",
required=True,
type_hint=torch.Tensor,
description="The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise step.",
),
InputParam.template("output_type"),
]

@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(name="images", type_hint=List[List[PIL.Image.Image]]),
OutputParam.template("images"),
]

@torch.no_grad()
@@ -251,7 +350,27 @@ class QwenImageLayeredDecoderStep(ModularPipelineBlocks):


# postprocess the decoded images


# auto_docstring
class QwenImageProcessImagesOutputStep(ModularPipelineBlocks):
"""
postprocess the generated image

Components:
image_processor (`VaeImageProcessor`)

Inputs:
images (`Tensor`):
the generated image tensor from decoders step
output_type (`str`, *optional*, defaults to pil):
Output format: 'pil', 'np', 'pt'.

Outputs:
images (`List`):
Generated images.
"""

model_name = "qwenimage"

@property
@@ -272,15 +391,19 @@ class QwenImageProcessImagesOutputStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("images", required=True, description="the generated image from decoders step"),
InputParam(
name="output_type",
default="pil",
type_hint=str,
description="The type of the output images, can be 'pil', 'np', 'pt'",
name="images",
required=True,
type_hint=torch.Tensor,
description="the generated image tensor from decoders step",
),
InputParam.template("output_type"),
]

@property
def intermediate_outputs(self) -> List[OutputParam]:
return [OutputParam.template("images")]

@staticmethod
def check_inputs(output_type):
if output_type not in ["pil", "np", "pt"]:
@@ -301,7 +424,28 @@ class QwenImageProcessImagesOutputStep(ModularPipelineBlocks):
return components, state
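The postprocessing contract is small enough to sketch end to end (the `postprocess` call follows the diffusers `VaeImageProcessor` API; the surrounding plumbing is simplified):

```python
def postprocess_images(image_processor, images, output_type: str = "pil"):
    # mirror of check_inputs: fail fast on unsupported formats
    if output_type not in ["pil", "np", "pt"]:
        raise ValueError(f"output_type must be one of 'pil', 'np', 'pt', got {output_type}")
    # denormalizes from [-1, 1] and converts to the requested representation
    return image_processor.postprocess(images, output_type=output_type)
```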


# auto_docstring
class QwenImageInpaintProcessImagesOutputStep(ModularPipelineBlocks):
"""
postprocess the generated image, optionally applying the mask overlay to the original image.

Components:
image_mask_processor (`InpaintProcessor`)

Inputs:
images (`Tensor`):
the generated image tensor from decoders step
output_type (`str`, *optional*, defaults to pil):
Output format: 'pil', 'np', 'pt'.
mask_overlay_kwargs (`Dict`, *optional*):
The kwargs for the postprocess step to apply the mask overlay. generated in
InpaintProcessImagesInputStep.

Outputs:
images (`List`):
Generated images.
"""

model_name = "qwenimage"

@property
@@ -322,16 +466,24 @@ class QwenImageInpaintProcessImagesOutputStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("images", required=True, description="the generated image from decoders step"),
InputParam(
name="output_type",
default="pil",
type_hint=str,
description="The type of the output images, can be 'pil', 'np', 'pt'",
name="images",
required=True,
type_hint=torch.Tensor,
description="the generated image tensor from decoders step",
),
InputParam.template("output_type"),
InputParam(
name="mask_overlay_kwargs",
type_hint=Dict[str, Any],
description="The kwargs for the postprocess step to apply the mask overlay. generated in InpaintProcessImagesInputStep.",
),
InputParam("mask_overlay_kwargs"),
]

@property
def intermediate_outputs(self) -> List[OutputParam]:
return [OutputParam.template("images")]

@staticmethod
def check_inputs(output_type, mask_overlay_kwargs):
if output_type not in ["pil", "np", "pt"]:

@@ -50,7 +50,7 @@ class QwenImageLoopBeforeDenoiser(ModularPipelineBlocks):
def inputs(self) -> List[InputParam]:
return [
InputParam(
"latents",
name="latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
@@ -80,17 +80,12 @@ class QwenImageEditLoopBeforeDenoiser(ModularPipelineBlocks):
def inputs(self) -> List[InputParam]:
return [
InputParam(
"latents",
name="latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
),
InputParam(
"image_latents",
required=True,
type_hint=torch.Tensor,
description="The initial image latents to use for the denoising process. Can be encoded in vae_encoder step and packed in prepare_image_latents step.",
),
InputParam.template("image_latents"),
]

@torch.no_grad()
@@ -134,29 +129,12 @@ class QwenImageLoopBeforeDenoiserControlNet(ModularPipelineBlocks):
type_hint=torch.Tensor,
description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
),
InputParam.template("controlnet_conditioning_scale", note="updated in prepare_controlnet_inputs step."),
InputParam(
"controlnet_conditioning_scale",
type_hint=float,
description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
),
InputParam(
"controlnet_keep",
name="controlnet_keep",
required=True,
type_hint=List[float],
description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
kwargs_type="denoiser_input_fields",
description=(
"All conditional model inputs for the denoiser. "
"It should contain prompt_embeds/negative_prompt_embeds."
),
description="The controlnet keep values. Can be generated in prepare_controlnet_inputs step.",
),
]
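`controlnet_keep` gates how strongly ControlNet is applied at each step. A sketch of the usual computation from `control_guidance_start`/`control_guidance_end` (the standard diffusers pattern; this block may differ in details):

```python
def compute_controlnet_keep(num_steps: int, start: float = 0.0, end: float = 1.0):
    # 1.0 while the step's position in the schedule falls inside [start, end], else 0.0
    return [
        1.0 - float(i / num_steps < start or (i + 1) / num_steps > end)
        for i in range(num_steps)
    ]
```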

@@ -217,28 +195,13 @@ class QwenImageLoopDenoiser(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("attention_kwargs"),
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The latents to use for the denoising process. Can be generated in prepare_latents step.",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
kwargs_type="denoiser_input_fields",
description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
),
InputParam.template("attention_kwargs"),
InputParam.template("denoiser_input_fields"),
InputParam(
"img_shapes",
required=True,
type_hint=List[Tuple[int, int]],
description="The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.",
description="The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step.",
),
]

@@ -317,23 +280,8 @@ class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("attention_kwargs"),
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The latents to use for the denoising process. Can be generated in prepare_latents step.",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
kwargs_type="denoiser_input_fields",
description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
),
InputParam.template("attention_kwargs"),
InputParam.template("denoiser_input_fields"),
InputParam(
"img_shapes",
required=True,
@@ -415,7 +363,7 @@ class QwenImageLoopAfterDenoiser(ModularPipelineBlocks):
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents."),
OutputParam.template("latents"),
]

@torch.no_grad()
@@ -456,24 +404,19 @@ class QwenImageLoopAfterDenoiserInpaint(ModularPipelineBlocks):
type_hint=torch.Tensor,
description="The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.",
),
InputParam(
"image_latents",
required=True,
type_hint=torch.Tensor,
description="The image latents to use for the inpainting process. Can be generated in inpaint prepare latents step.",
),
InputParam.template("image_latents"),
InputParam(
"initial_noise",
required=True,
type_hint=torch.Tensor,
description="The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step.",
),
InputParam(
"timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
),
]

@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam.template("latents"),
]

@torch.no_grad()
@@ -515,17 +458,12 @@ class QwenImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
def loop_inputs(self) -> List[InputParam]:
return [
InputParam(
"timesteps",
name="timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam.template("num_inference_steps", required=True),
]
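`QwenImageDenoiseLoopWrapper` is the piece that turns these per-iteration blocks into a full sampling loop. In pseudocode against an assumed interface (not the actual implementation):

```python
# each registered sub-block sees the current step index and timestep, and the
# wrapper threads (components, block_state) through all of them, once per step
for i, t in enumerate(block_state.timesteps):
    for sub_block in self.sub_blocks.values():
        components, block_state = sub_block(components, block_state, i, t)
```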

@torch.no_grad()
@@ -557,7 +495,42 @@ class QwenImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks):


# Qwen Image (text2image, image2image)


# auto_docstring
class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper):
"""
Denoise step that iteratively denoises the latents.
Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method. At each iteration, it runs blocks
defined in `sub_blocks` sequentially:
- `QwenImageLoopBeforeDenoiser`
- `QwenImageLoopDenoiser`
- `QwenImageLoopAfterDenoiser`
This block supports text2image and image2image tasks for QwenImage.

Components:
guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler
(`FlowMatchEulerDiscreteScheduler`)

Inputs:
timesteps (`Tensor`):
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
num_inference_steps (`int`):
The number of denoising steps.
latents (`Tensor`):
The initial latents to use for the denoising process. Can be generated in prepare_latent step.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
img_shapes (`List`):
The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.

Outputs:
latents (`Tensor`):
Denoised latents.
"""

model_name = "qwenimage"

block_classes = [
@@ -570,8 +543,8 @@ class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper):
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. \n"
"Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
"Denoise step that iteratively denoise the latents.\n"
"Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method\n"
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
" - `QwenImageLoopBeforeDenoiser`\n"
" - `QwenImageLoopDenoiser`\n"
@@ -581,7 +554,47 @@ class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper):


# Qwen Image (inpainting)
# auto_docstring
class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
"""
Denoise step that iteratively denoises the latents.
Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method. At each iteration, it runs blocks
defined in `sub_blocks` sequentially:
- `QwenImageLoopBeforeDenoiser`
- `QwenImageLoopDenoiser`
- `QwenImageLoopAfterDenoiser`
- `QwenImageLoopAfterDenoiserInpaint`
This block supports inpainting tasks for QwenImage.

Components:
guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler
(`FlowMatchEulerDiscreteScheduler`)

Inputs:
timesteps (`Tensor`):
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
num_inference_steps (`int`):
The number of denoising steps.
latents (`Tensor`):
The initial latents to use for the denoising process. Can be generated in prepare_latent step.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
img_shapes (`List`):
The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.
mask (`Tensor`):
The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
initial_noise (`Tensor`):
The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step.

Outputs:
latents (`Tensor`):
Denoised latents.
"""

model_name = "qwenimage"
block_classes = [
QwenImageLoopBeforeDenoiser,
@@ -606,7 +619,47 @@ class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):


# Qwen Image (text2image, image2image) with controlnet
# auto_docstring
class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
"""
Denoise step that iteratively denoises the latents.
Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method. At each iteration, it runs blocks
defined in `sub_blocks` sequentially:
- `QwenImageLoopBeforeDenoiser`
- `QwenImageLoopBeforeDenoiserControlNet`
- `QwenImageLoopDenoiser`
- `QwenImageLoopAfterDenoiser`
This block supports text2img/img2img tasks with controlnet for QwenImage.

Components:
guider (`ClassifierFreeGuidance`) controlnet (`QwenImageControlNetModel`) transformer
(`QwenImageTransformer2DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`)

Inputs:
timesteps (`Tensor`):
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
num_inference_steps (`int`):
The number of denoising steps.
latents (`Tensor`):
The initial latents to use for the denoising process. Can be generated in prepare_latent step.
control_image_latents (`Tensor`):
The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
Scale for ControlNet conditioning. (updated in prepare_controlnet_inputs step.)
controlnet_keep (`List`):
The controlnet keep values. Can be generated in prepare_controlnet_inputs step.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
img_shapes (`List`):
The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.

Outputs:
latents (`Tensor`):
Denoised latents.
"""

model_name = "qwenimage"
block_classes = [
QwenImageLoopBeforeDenoiser,
@@ -631,7 +684,54 @@ class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):


# Qwen Image (inpainting) with controlnet
# auto_docstring
class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
"""
Denoise step that iteratively denoises the latents.
Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method. At each iteration, it runs blocks
defined in `sub_blocks` sequentially:
- `QwenImageLoopBeforeDenoiser`
- `QwenImageLoopBeforeDenoiserControlNet`
- `QwenImageLoopDenoiser`
- `QwenImageLoopAfterDenoiser`
- `QwenImageLoopAfterDenoiserInpaint`
This block supports inpainting tasks with controlnet for QwenImage.

Components:
guider (`ClassifierFreeGuidance`) controlnet (`QwenImageControlNetModel`) transformer
(`QwenImageTransformer2DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`)

Inputs:
timesteps (`Tensor`):
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
num_inference_steps (`int`):
The number of denoising steps.
latents (`Tensor`):
The initial latents to use for the denoising process. Can be generated in prepare_latent step.
control_image_latents (`Tensor`):
The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
Scale for ControlNet conditioning. (updated in prepare_controlnet_inputs step.)
controlnet_keep (`List`):
The controlnet keep values. Can be generated in prepare_controlnet_inputs step.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
img_shapes (`List`):
The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.
mask (`Tensor`):
The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
initial_noise (`Tensor`):
The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step.

Outputs:
latents (`Tensor`):
Denoised latents.
"""

model_name = "qwenimage"
block_classes = [
QwenImageLoopBeforeDenoiser,
@@ -664,7 +764,42 @@ class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):


# Qwen Image Edit (image2image)
# auto_docstring
class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper):
"""
Denoise step that iteratively denoises the latents.
Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method. At each iteration, it runs blocks
defined in `sub_blocks` sequentially:
- `QwenImageEditLoopBeforeDenoiser`
- `QwenImageEditLoopDenoiser`
- `QwenImageLoopAfterDenoiser`
This block supports QwenImage Edit.

Components:
guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler
(`FlowMatchEulerDiscreteScheduler`)

Inputs:
timesteps (`Tensor`):
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
num_inference_steps (`int`):
The number of denoising steps.
latents (`Tensor`):
The initial latents to use for the denoising process. Can be generated in prepare_latent step.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
img_shapes (`List`):
The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.

Outputs:
latents (`Tensor`):
Denoised latents.
"""

model_name = "qwenimage-edit"
block_classes = [
QwenImageEditLoopBeforeDenoiser,
@@ -687,7 +822,47 @@ class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper):


# Qwen Image Edit (inpainting)
# auto_docstring
class QwenImageEditInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
"""
Denoise step that iteratively denoises the latents.
Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method. At each iteration, it runs blocks
defined in `sub_blocks` sequentially:
- `QwenImageEditLoopBeforeDenoiser`
- `QwenImageEditLoopDenoiser`
- `QwenImageLoopAfterDenoiser`
- `QwenImageLoopAfterDenoiserInpaint`
This block supports inpainting tasks for QwenImage Edit.

Components:
guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler
(`FlowMatchEulerDiscreteScheduler`)

Inputs:
timesteps (`Tensor`):
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
num_inference_steps (`int`):
The number of denoising steps.
latents (`Tensor`):
The initial latents to use for the denoising process. Can be generated in prepare_latent step.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
img_shapes (`List`):
The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.
mask (`Tensor`):
The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.
initial_noise (`Tensor`):
The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step.

Outputs:
latents (`Tensor`):
Denoised latents.
"""

model_name = "qwenimage-edit"
block_classes = [
QwenImageEditLoopBeforeDenoiser,
@@ -712,7 +887,42 @@ class QwenImageEditInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):


# Qwen Image Layered (image2image)
# auto_docstring
class QwenImageLayeredDenoiseStep(QwenImageDenoiseLoopWrapper):
"""
Denoise step that iteratively denoises the latents.
Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method. At each iteration, it runs blocks
defined in `sub_blocks` sequentially:
- `QwenImageEditLoopBeforeDenoiser`
- `QwenImageEditLoopDenoiser`
- `QwenImageLoopAfterDenoiser`
This block supports QwenImage Layered.

Components:
guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler
(`FlowMatchEulerDiscreteScheduler`)

Inputs:
timesteps (`Tensor`):
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
num_inference_steps (`int`):
The number of denoising steps.
latents (`Tensor`):
The initial latents to use for the denoising process. Can be generated in prepare_latent step.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
img_shapes (`List`):
The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.

Outputs:
latents (`Tensor`):
Denoised latents.
"""

model_name = "qwenimage-layered"
block_classes = [
QwenImageEditLoopBeforeDenoiser,

File diff suppressed because it is too large
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple
from typing import List, Optional, Tuple

import torch

@@ -109,7 +109,44 @@ def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor: in
return height, width


# auto_docstring
class QwenImageTextInputsStep(ModularPipelineBlocks):
"""
Text input processing step that standardizes text embeddings for the pipeline.
This step:
1. Determines `batch_size` and `dtype` based on `prompt_embeds`
2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)

This block should be placed after all encoder steps to process the text embeddings before they are used in
subsequent pipeline steps.

Inputs:
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
prompt_embeds (`Tensor`):
text embeddings used to guide the image generation. Can be generated from text_encoder step.
prompt_embeds_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`Tensor`, *optional*):
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.

Outputs:
batch_size (`int`):
The batch size of the prompt embeddings
dtype (`dtype`):
The data type of the prompt embeddings
prompt_embeds (`Tensor`):
The prompt embeddings. (batch-expanded)
prompt_embeds_mask (`Tensor`):
The encoder attention mask. (batch-expanded)
negative_prompt_embeds (`Tensor`):
The negative prompt embeddings. (batch-expanded)
negative_prompt_embeds_mask (`Tensor`):
The negative prompt embeddings mask. (batch-expanded)
"""

model_name = "qwenimage"

@property
@@ -129,26 +166,22 @@ class QwenImageTextInputsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="num_images_per_prompt", default=1),
InputParam(name="prompt_embeds", required=True, kwargs_type="denoiser_input_fields"),
InputParam(name="prompt_embeds_mask", required=True, kwargs_type="denoiser_input_fields"),
InputParam(name="negative_prompt_embeds", kwargs_type="denoiser_input_fields"),
InputParam(name="negative_prompt_embeds_mask", kwargs_type="denoiser_input_fields"),
InputParam.template("num_images_per_prompt"),
InputParam.template("prompt_embeds"),
InputParam.template("prompt_embeds_mask"),
InputParam.template("negative_prompt_embeds"),
InputParam.template("negative_prompt_embeds_mask"),
]

@property
def intermediate_outputs(self) -> List[str]:
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"batch_size",
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
),
OutputParam(
"dtype",
type_hint=torch.dtype,
description="Data type of model tensor inputs (determined by `prompt_embeds`)",
),
OutputParam(name="batch_size", type_hint=int, description="The batch size of the prompt embeddings"),
OutputParam(name="dtype", type_hint=torch.dtype, description="The data type of the prompt embeddings"),
OutputParam.template("prompt_embeds", note="batch-expanded"),
OutputParam.template("prompt_embeds_mask", note="batch-expanded"),
OutputParam.template("negative_prompt_embeds", note="batch-expanded"),
OutputParam.template("negative_prompt_embeds_mask", note="batch-expanded"),
]

@staticmethod
@@ -221,20 +254,76 @@ class QwenImageTextInputsStep(ModularPipelineBlocks):
return components, state
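The batch expansion the docstring describes is a one-liner per tensor. A sketch (the usual diffusers convention; the helper name is an assumption):

```python
import torch

def expand_batch(embeds: torch.Tensor, num_images_per_prompt: int) -> torch.Tensor:
    # repeat_interleave keeps all copies of a prompt adjacent, so the final batch
    # is ordered [p0, p0, ..., p1, p1, ...] of size batch_size * num_images_per_prompt
    return embeds.repeat_interleave(num_images_per_prompt, dim=0)
```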


# auto_docstring
class QwenImageAdditionalInputsStep(ModularPipelineBlocks):
"""Input step for QwenImage: update height/width, expand batch, patchify."""
"""
Input processing step that:
1. For image latent inputs: Updates height/width if None, patchifies, and expands batch size
2. For additional batch inputs: Expands batch dimensions to match final batch size

Configured inputs:
- Image latent inputs: ['image_latents']

This block should be placed after the encoder steps and the text input step.

Components:
pachifier (`QwenImagePachifier`)

Inputs:
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
batch_size (`int`, *optional*, defaults to 1):
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
be generated in input step.
height (`int`, *optional*):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.

Outputs:
image_height (`int`):
The image height calculated from the image latents dimension
image_width (`int`):
The image width calculated from the image latents dimension
height (`int`):
if not provided, updated to image height
width (`int`):
if not provided, updated to image width
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and
batch-expanded)
"""

model_name = "qwenimage"

def __init__(
self,
image_latent_inputs: List[str] = ["image_latents"],
additional_batch_inputs: List[str] = [],
image_latent_inputs: Optional[List[InputParam]] = None,
additional_batch_inputs: Optional[List[InputParam]] = None,
):
# by default, process `image_latents`
if image_latent_inputs is None:
image_latent_inputs = [InputParam.template("image_latents")]
if additional_batch_inputs is None:
additional_batch_inputs = []

if not isinstance(image_latent_inputs, list):
image_latent_inputs = [image_latent_inputs]
raise ValueError(f"image_latent_inputs must be a list, but got {type(image_latent_inputs)}")
else:
for input_param in image_latent_inputs:
if not isinstance(input_param, InputParam):
raise ValueError(f"image_latent_inputs must be a list of InputParam, but got {type(input_param)}")

if not isinstance(additional_batch_inputs, list):
additional_batch_inputs = [additional_batch_inputs]
raise ValueError(f"additional_batch_inputs must be a list, but got {type(additional_batch_inputs)}")
else:
for input_param in additional_batch_inputs:
if not isinstance(input_param, InputParam):
raise ValueError(
f"additional_batch_inputs must be a list of InputParam, but got {type(input_param)}"
)

self._image_latent_inputs = image_latent_inputs
self._additional_batch_inputs = additional_batch_inputs
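With the new signature the step is configured with `InputParam` objects rather than bare names, so type hints and descriptions travel with the input. A hypothetical construction (the "mask" template name is an assumption for illustration):

```python
step = QwenImageAdditionalInputsStep(
    image_latent_inputs=[InputParam.template("image_latents")],
    additional_batch_inputs=[InputParam.template("mask")],  # assumed template name
)
```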
|
||||
@@ -252,9 +341,9 @@ class QwenImageAdditionalInputsStep(ModularPipelineBlocks):
        if self._image_latent_inputs or self._additional_batch_inputs:
            inputs_info = "\n\nConfigured inputs:"
            if self._image_latent_inputs:
                inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
                inputs_info += f"\n - Image latent inputs: {[p.name for p in self._image_latent_inputs]}"
            if self._additional_batch_inputs:
                inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
                inputs_info += f"\n - Additional batch inputs: {[p.name for p in self._additional_batch_inputs]}"

        placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."

@@ -269,23 +358,19 @@ class QwenImageAdditionalInputsStep(ModularPipelineBlocks):
    @property
    def inputs(self) -> List[InputParam]:
        inputs = [
            InputParam(name="num_images_per_prompt", default=1),
            InputParam(name="batch_size", required=True),
            InputParam(name="height"),
            InputParam(name="width"),
            InputParam.template("num_images_per_prompt"),
            InputParam.template("batch_size"),
            InputParam.template("height"),
            InputParam.template("width"),
        ]

        for image_latent_input_name in self._image_latent_inputs:
            inputs.append(InputParam(name=image_latent_input_name))

        for input_name in self._additional_batch_inputs:
            inputs.append(InputParam(name=input_name))
        # default is `image_latents`
        inputs += self._image_latent_inputs + self._additional_batch_inputs

        return inputs

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
        outputs = [
            OutputParam(
                name="image_height",
                type_hint=int,
@@ -298,11 +383,43 @@ class QwenImageAdditionalInputsStep(ModularPipelineBlocks):
            ),
        ]

        # `height`/`width` are not new outputs, but they will be updated if any image latent inputs are provided
        if len(self._image_latent_inputs) > 0:
            outputs.append(
                OutputParam(name="height", type_hint=int, description="if not provided, updated to image height")
            )
            outputs.append(
                OutputParam(name="width", type_hint=int, description="if not provided, updated to image width")
            )

        # image latent inputs are modified in place (patchified and batch-expanded)
        for input_param in self._image_latent_inputs:
            outputs.append(
                OutputParam(
                    name=input_param.name,
                    type_hint=input_param.type_hint,
                    description=input_param.description + " (patchified and batch-expanded)",
                )
            )

        # additional batch inputs (batch-expanded only)
        for input_param in self._additional_batch_inputs:
            outputs.append(
                OutputParam(
                    name=input_param.name,
                    type_hint=input_param.type_hint,
                    description=input_param.description + " (batch-expanded)",
                )
            )

        return outputs

    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        # Process image latent inputs
        for image_latent_input_name in self._image_latent_inputs:
        for input_param in self._image_latent_inputs:
            image_latent_input_name = input_param.name
            image_latent_tensor = getattr(block_state, image_latent_input_name)
            if image_latent_tensor is None:
                continue
@@ -331,7 +448,8 @@ class QwenImageAdditionalInputsStep(ModularPipelineBlocks):
            setattr(block_state, image_latent_input_name, image_latent_tensor)

        # Process additional batch inputs (only batch expansion)
        for input_name in self._additional_batch_inputs:
        for input_param in self._additional_batch_inputs:
            input_name = input_param.name
            input_tensor = getattr(block_state, input_name)
            if input_tensor is None:
                continue
@@ -349,20 +467,76 @@ class QwenImageAdditionalInputsStep(ModularPipelineBlocks):
        return components, state

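Illustrative only (shapes and the `repeat` call are assumptions, not taken from this diff): "patchified" means the spatial latents have been packed into a token sequence by the pachifier component, and "batch-expanded" means tiling the tensor along dim 0 to the final batch size:

    import torch

    batch_size, num_images_per_prompt = 2, 2
    image_latents = torch.randn(1, 4096, 64)                      # already patchified: (1, seq_len, channels)
    final_batch_size = batch_size * num_images_per_prompt
    image_latents = image_latents.repeat(final_batch_size, 1, 1)  # batch-expanded to (4, 4096, 64)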
# auto_docstring
class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks):
    """Input step for QwenImage Edit Plus: handles list of latents with different sizes."""
    """
    Input processing step for Edit Plus that:
        1. For image latent inputs (list): Collects heights/widths, patchifies each, concatenates, expands batch
        2. For additional batch inputs: Expands batch dimensions to match final batch size
    Height/width defaults to last image in the list.

    Configured inputs:
      - Image latent inputs: ['image_latents']

    This block should be placed after the encoder steps and the text input step.

    Components:
        pachifier (`QwenImagePachifier`)

    Inputs:
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        batch_size (`int`, *optional*, defaults to 1):
            Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
            be generated in input step.
        height (`int`, *optional*):
            The height in pixels of the generated image.
        width (`int`, *optional*):
            The width in pixels of the generated image.
        image_latents (`Tensor`):
            image latents used to guide the image generation. Can be generated from vae_encoder step.

    Outputs:
        image_height (`List`):
            The image heights calculated from the image latents dimension
        image_width (`List`):
            The image widths calculated from the image latents dimension
        height (`int`):
            if not provided, updated to image height
        width (`int`):
            if not provided, updated to image width
        image_latents (`Tensor`):
            image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified,
            concatenated, and batch-expanded)
    """

    model_name = "qwenimage-edit-plus"

    def __init__(
        self,
        image_latent_inputs: List[str] = ["image_latents"],
        additional_batch_inputs: List[str] = [],
        image_latent_inputs: Optional[List[InputParam]] = None,
        additional_batch_inputs: Optional[List[InputParam]] = None,
    ):
        if image_latent_inputs is None:
            image_latent_inputs = [InputParam.template("image_latents")]
        if additional_batch_inputs is None:
            additional_batch_inputs = []

        if not isinstance(image_latent_inputs, list):
            image_latent_inputs = [image_latent_inputs]
            raise ValueError(f"image_latent_inputs must be a list, but got {type(image_latent_inputs)}")
        else:
            for input_param in image_latent_inputs:
                if not isinstance(input_param, InputParam):
                    raise ValueError(f"image_latent_inputs must be a list of InputParam, but got {type(input_param)}")

        if not isinstance(additional_batch_inputs, list):
            additional_batch_inputs = [additional_batch_inputs]
            raise ValueError(f"additional_batch_inputs must be a list, but got {type(additional_batch_inputs)}")
        else:
            for input_param in additional_batch_inputs:
                if not isinstance(input_param, InputParam):
                    raise ValueError(
                        f"additional_batch_inputs must be a list of InputParam, but got {type(input_param)}"
                    )

        self._image_latent_inputs = image_latent_inputs
        self._additional_batch_inputs = additional_batch_inputs
@@ -381,9 +555,9 @@ class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks):
        if self._image_latent_inputs or self._additional_batch_inputs:
            inputs_info = "\n\nConfigured inputs:"
            if self._image_latent_inputs:
                inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
                inputs_info += f"\n - Image latent inputs: {[p.name for p in self._image_latent_inputs]}"
            if self._additional_batch_inputs:
                inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
                inputs_info += f"\n - Additional batch inputs: {[p.name for p in self._additional_batch_inputs]}"

        placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."

@@ -398,23 +572,20 @@ class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks):
    @property
    def inputs(self) -> List[InputParam]:
        inputs = [
            InputParam(name="num_images_per_prompt", default=1),
            InputParam(name="batch_size", required=True),
            InputParam(name="height"),
            InputParam(name="width"),
            InputParam.template("num_images_per_prompt"),
            InputParam.template("batch_size"),
            InputParam.template("height"),
            InputParam.template("width"),
        ]

        for image_latent_input_name in self._image_latent_inputs:
            inputs.append(InputParam(name=image_latent_input_name))

        for input_name in self._additional_batch_inputs:
            inputs.append(InputParam(name=input_name))
        # default is `image_latents`
        inputs += self._image_latent_inputs + self._additional_batch_inputs

        return inputs

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
        outputs = [
            OutputParam(
                name="image_height",
                type_hint=List[int],
@@ -427,11 +598,43 @@ class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks):
            ),
        ]

        # `height`/`width` are updated if any image latent inputs are provided
        if len(self._image_latent_inputs) > 0:
            outputs.append(
                OutputParam(name="height", type_hint=int, description="if not provided, updated to image height")
            )
            outputs.append(
                OutputParam(name="width", type_hint=int, description="if not provided, updated to image width")
            )

        # image latent inputs are modified in place (patchified, concatenated, and batch-expanded)
        for input_param in self._image_latent_inputs:
            outputs.append(
                OutputParam(
                    name=input_param.name,
                    type_hint=input_param.type_hint,
                    description=input_param.description + " (patchified, concatenated, and batch-expanded)",
                )
            )

        # additional batch inputs (batch-expanded only)
        for input_param in self._additional_batch_inputs:
            outputs.append(
                OutputParam(
                    name=input_param.name,
                    type_hint=input_param.type_hint,
                    description=input_param.description + " (batch-expanded)",
                )
            )

        return outputs

    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        # Process image latent inputs
        for image_latent_input_name in self._image_latent_inputs:
        for input_param in self._image_latent_inputs:
            image_latent_input_name = input_param.name
            image_latent_tensor = getattr(block_state, image_latent_input_name)
            if image_latent_tensor is None:
                continue
@@ -476,7 +679,8 @@ class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks):
            setattr(block_state, image_latent_input_name, packed_image_latent_tensors)

        # Process additional batch inputs (only batch expansion)
        for input_name in self._additional_batch_inputs:
        for input_param in self._additional_batch_inputs:
            input_name = input_param.name
            input_tensor = getattr(block_state, input_name)
            if input_tensor is None:
                continue
@@ -494,22 +698,75 @@ class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks):
        return components, state

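A sketch of the Edit Plus list handling described in the docstring above (tensor shapes are invented for illustration; the real patchify call lives in the pachifier component):

    import torch

    # per-image latents of different spatial sizes, each already patchified to (1, seq_len_i, channels)
    latents_list = [torch.randn(1, 1024, 64), torch.randn(1, 4096, 64)]
    packed = torch.cat(latents_list, dim=1)  # concatenated along the sequence dimension -> (1, 5120, 64)
    packed = packed.repeat(4, 1, 1)          # then batch-expanded; height/width default to the last image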
# YiYi TODO: support define config default component from the ModularPipeline level.
# it is same as QwenImageAdditionalInputsStep, but with layered pachifier.
# same as QwenImageAdditionalInputsStep, but with layered pachifier.


# auto_docstring
class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks):
    """Input step for QwenImage Layered: update height/width, expand batch, patchify with layered pachifier."""
    """
    Input processing step for Layered that:
        1. For image latent inputs: Updates height/width if None, patchifies with layered pachifier, and expands batch
           size
        2. For additional batch inputs: Expands batch dimensions to match final batch size

    Configured inputs:
      - Image latent inputs: ['image_latents']

    This block should be placed after the encoder steps and the text input step.

    Components:
        pachifier (`QwenImageLayeredPachifier`)

    Inputs:
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        batch_size (`int`, *optional*, defaults to 1):
            Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
            be generated in input step.
        image_latents (`Tensor`):
            image latents used to guide the image generation. Can be generated from vae_encoder step.

    Outputs:
        image_height (`int`):
            The image height calculated from the image latents dimension
        image_width (`int`):
            The image width calculated from the image latents dimension
        height (`int`):
            if not provided, updated to image height
        width (`int`):
            if not provided, updated to image width
        image_latents (`Tensor`):
            image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified
            with layered pachifier and batch-expanded)
    """

    model_name = "qwenimage-layered"

    def __init__(
        self,
        image_latent_inputs: List[str] = ["image_latents"],
        additional_batch_inputs: List[str] = [],
        image_latent_inputs: Optional[List[InputParam]] = None,
        additional_batch_inputs: Optional[List[InputParam]] = None,
    ):
        if image_latent_inputs is None:
            image_latent_inputs = [InputParam.template("image_latents")]
        if additional_batch_inputs is None:
            additional_batch_inputs = []

        if not isinstance(image_latent_inputs, list):
            image_latent_inputs = [image_latent_inputs]
            raise ValueError(f"image_latent_inputs must be a list, but got {type(image_latent_inputs)}")
        else:
            for input_param in image_latent_inputs:
                if not isinstance(input_param, InputParam):
                    raise ValueError(f"image_latent_inputs must be a list of InputParam, but got {type(input_param)}")

        if not isinstance(additional_batch_inputs, list):
            additional_batch_inputs = [additional_batch_inputs]
            raise ValueError(f"additional_batch_inputs must be a list, but got {type(additional_batch_inputs)}")
        else:
            for input_param in additional_batch_inputs:
                if not isinstance(input_param, InputParam):
                    raise ValueError(
                        f"additional_batch_inputs must be a list of InputParam, but got {type(input_param)}"
                    )

        self._image_latent_inputs = image_latent_inputs
        self._additional_batch_inputs = additional_batch_inputs
@@ -527,9 +784,9 @@ class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks):
        if self._image_latent_inputs or self._additional_batch_inputs:
            inputs_info = "\n\nConfigured inputs:"
            if self._image_latent_inputs:
                inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
                inputs_info += f"\n - Image latent inputs: {[p.name for p in self._image_latent_inputs]}"
            if self._additional_batch_inputs:
                inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
                inputs_info += f"\n - Additional batch inputs: {[p.name for p in self._additional_batch_inputs]}"

        placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."

@@ -544,21 +801,18 @@ class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks):
    @property
    def inputs(self) -> List[InputParam]:
        inputs = [
            InputParam(name="num_images_per_prompt", default=1),
            InputParam(name="batch_size", required=True),
            InputParam.template("num_images_per_prompt"),
            InputParam.template("batch_size"),
        ]
        # default is `image_latents`

        for image_latent_input_name in self._image_latent_inputs:
            inputs.append(InputParam(name=image_latent_input_name))

        for input_name in self._additional_batch_inputs:
            inputs.append(InputParam(name=input_name))
        inputs += self._image_latent_inputs + self._additional_batch_inputs

        return inputs

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
        outputs = [
            OutputParam(
                name="image_height",
                type_hint=int,
@@ -569,15 +823,44 @@ class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks):
                type_hint=int,
                description="The image width calculated from the image latents dimension",
            ),
            OutputParam(name="height", type_hint=int, description="The height of the image output"),
            OutputParam(name="width", type_hint=int, description="The width of the image output"),
        ]

        if len(self._image_latent_inputs) > 0:
            outputs.append(
                OutputParam(name="height", type_hint=int, description="if not provided, updated to image height")
            )
            outputs.append(
                OutputParam(name="width", type_hint=int, description="if not provided, updated to image width")
            )

        # Add outputs for image latent inputs (patchified with layered pachifier and batch-expanded)
        for input_param in self._image_latent_inputs:
            outputs.append(
                OutputParam(
                    name=input_param.name,
                    type_hint=input_param.type_hint,
                    description=input_param.description + " (patchified with layered pachifier and batch-expanded)",
                )
            )

        # Add outputs for additional batch inputs (batch-expanded only)
        for input_param in self._additional_batch_inputs:
            outputs.append(
                OutputParam(
                    name=input_param.name,
                    type_hint=input_param.type_hint,
                    description=input_param.description + " (batch-expanded)",
                )
            )

        return outputs

    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        # Process image latent inputs
        for image_latent_input_name in self._image_latent_inputs:
        for input_param in self._image_latent_inputs:
            image_latent_input_name = input_param.name
            image_latent_tensor = getattr(block_state, image_latent_input_name)
            if image_latent_tensor is None:
                continue
@@ -608,7 +891,8 @@ class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks):
            setattr(block_state, image_latent_input_name, image_latent_tensor)

        # Process additional batch inputs (only batch expansion)
        for input_name in self._additional_batch_inputs:
        for input_param in self._additional_batch_inputs:
            input_name = input_param.name
            input_tensor = getattr(block_state, input_name)
            if input_tensor is None:
                continue
@@ -626,7 +910,34 @@ class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks):
        return components, state

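The three AdditionalInputsStep variants above share the same constructor contract, so they can be configured interchangeably; a sketch mirroring the InputParam usage shown later in this diff (the argument values are illustrative only):

    step = QwenImageLayeredAdditionalInputsStep(
        additional_batch_inputs=[
            InputParam(name="processed_mask_image", type_hint=torch.Tensor, description="The processed mask image")
        ]
    )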
# auto_docstring
class QwenImageControlNetInputsStep(ModularPipelineBlocks):
    """
    prepare the `control_image_latents` for controlnet. Insert after all the other inputs steps.

    Inputs:
        control_image_latents (`Tensor`):
            The control image latents to use for the denoising process. Can be generated in controlnet vae encoder
            step.
        batch_size (`int`, *optional*, defaults to 1):
            Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
            be generated in input step.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        height (`int`, *optional*):
            The height in pixels of the generated image.
        width (`int`, *optional*):
            The width in pixels of the generated image.

    Outputs:
        control_image_latents (`Tensor`):
            The control image latents (patchified and batch-expanded).
        height (`int`):
            if not provided, updated to control image height
        width (`int`):
            if not provided, updated to control image width
    """

    model_name = "qwenimage"

    @property
@@ -636,11 +947,28 @@ class QwenImageControlNetInputsStep(ModularPipelineBlocks):
    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(name="control_image_latents", required=True),
            InputParam(name="batch_size", required=True),
            InputParam(name="num_images_per_prompt", default=1),
            InputParam(name="height"),
            InputParam(name="width"),
            InputParam(
                name="control_image_latents",
                required=True,
                type_hint=torch.Tensor,
                description="The control image latents to use for the denoising process. Can be generated in controlnet vae encoder step.",
            ),
            InputParam.template("batch_size"),
            InputParam.template("num_images_per_prompt"),
            InputParam.template("height"),
            InputParam.template("width"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                name="control_image_latents",
                type_hint=torch.Tensor,
                description="The control image latents (patchified and batch-expanded).",
            ),
            OutputParam(name="height", type_hint=int, description="if not provided, updated to control image height"),
            OutputParam(name="width", type_hint=int, description="if not provided, updated to control image width"),
        ]

    @torch.no_grad()

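Since `inputs` and `intermediate_outputs` are plain properties (as shown above), a block's declared I/O can be inspected directly; a sketch, assuming nothing beyond the properties in this diff:

    step = QwenImageControlNetInputsStep()
    print([p.name for p in step.inputs])                # control_image_latents, batch_size, ...
    print([p.name for p in step.intermediate_outputs])  # control_image_latents, height, width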
@@ -12,14 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import PIL.Image
import torch

from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict, OutputParam
from ..modular_pipeline_utils import InputParam, InsertableDict, OutputParam
from .before_denoise import (
    QwenImageControlNetBeforeDenoiserStep,
    QwenImageCreateMaskLatentsStep,
@@ -59,11 +56,91 @@ logger = logging.get_logger(__name__)


# ====================
# 1. VAE ENCODER
# 1. TEXT ENCODER
# ====================


# auto_docstring
class QwenImageAutoTextEncoderStep(AutoPipelineBlocks):
    """
    Text encoder step that encodes the text prompt into a text embedding. This is an auto pipeline block.

    Components:
        text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use tokenizer (`Qwen2Tokenizer`):
        The tokenizer to use guider (`ClassifierFreeGuidance`)

    Inputs:
        prompt (`str`, *optional*):
            The prompt or prompts to guide image generation.
        negative_prompt (`str`, *optional*):
            The prompt or prompts not to guide the image generation.
        max_sequence_length (`int`, *optional*, defaults to 1024):
            Maximum sequence length for prompt encoding.

    Outputs:
        prompt_embeds (`Tensor`):
            The prompt embeddings.
        prompt_embeds_mask (`Tensor`):
            The encoder attention mask.
        negative_prompt_embeds (`Tensor`):
            The negative prompt embeddings.
        negative_prompt_embeds_mask (`Tensor`):
            The negative prompt embeddings mask.
    """

    model_name = "qwenimage"
    block_classes = [QwenImageTextEncoderStep()]
    block_names = ["text_encoder"]
    block_trigger_inputs = ["prompt"]

    @property
    def description(self) -> str:
        return "Text encoder step that encodes the text prompt into a text embedding. This is an auto pipeline block."
        " - `QwenImageTextEncoderStep` (text_encoder) is used when `prompt` is provided."
        " - if `prompt` is not provided, step will be skipped."

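The trigger wiring above is the whole dispatch mechanism: `block_trigger_inputs` pairs positionally with `block_classes`. A sketch of the semantics (no new API assumed):

    auto_text = QwenImageAutoTextEncoderStep()
    # at runtime: `prompt` present in the pipeline state -> QwenImageTextEncoderStep runs;
    # `prompt` absent -> the whole step is skipped.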
# ====================
# 2. VAE ENCODER
# ====================


# auto_docstring
class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks):
    """
    This step is used for processing image and mask inputs for inpainting tasks. It:
      - Resizes the image to the target size, based on `height` and `width`.
      - Processes and updates `image` and `mask_image`.
      - Creates `image_latents`.

    Components:
        image_mask_processor (`InpaintProcessor`) vae (`AutoencoderKLQwenImage`)

    Inputs:
        mask_image (`Image`):
            Mask image for inpainting.
        image (`Union[Image, List]`):
            Reference image(s) for denoising. Can be a single image or list of images.
        height (`int`, *optional*):
            The height in pixels of the generated image.
        width (`int`, *optional*):
            The width in pixels of the generated image.
        padding_mask_crop (`int`, *optional*):
            Padding for mask cropping in inpainting.
        generator (`Generator`, *optional*):
            Torch generator for deterministic generation.

    Outputs:
        processed_image (`Tensor`):
            The processed image
        processed_mask_image (`Tensor`):
            The processed mask image
        mask_overlay_kwargs (`Dict`):
            The kwargs for the postprocess step to apply the mask overlay
        image_latents (`Tensor`):
            The latent representation of the input image.
    """

    model_name = "qwenimage"
    block_classes = [QwenImageInpaintProcessImagesInputStep(), QwenImageVaeEncoderStep()]
    block_names = ["preprocess", "encode"]
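As with other SequentialPipelineBlocks, the two sub-blocks run in order and share pipeline state; the composed block can then be turned into a runnable pipeline. A sketch (`init_pipeline` and the repo id are assumptions, not part of this diff):

    blocks = QwenImageInpaintVaeEncoderStep()
    pipe = blocks.init_pipeline("Qwen/Qwen-Image")  # hypothetical repo id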
@@ -78,7 +155,31 @@ class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks):
        )


# auto_docstring
class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
    """
    Vae encoder step that preprocess and encode the image inputs into their latent representations.

    Components:
        image_processor (`VaeImageProcessor`) vae (`AutoencoderKLQwenImage`)

    Inputs:
        image (`Union[Image, List]`):
            Reference image(s) for denoising. Can be a single image or list of images.
        height (`int`, *optional*):
            The height in pixels of the generated image.
        width (`int`, *optional*):
            The width in pixels of the generated image.
        generator (`Generator`, *optional*):
            Torch generator for deterministic generation.

    Outputs:
        processed_image (`Tensor`):
            The processed image
        image_latents (`Tensor`):
            The latent representation of the input image.
    """

    model_name = "qwenimage"

    block_classes = [QwenImageProcessImagesInputStep(), QwenImageVaeEncoderStep()]
@@ -89,7 +190,6 @@ class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
        return "Vae encoder step that preprocess and encode the image inputs into their latent representations."


# Auto VAE encoder
class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks):
    block_classes = [QwenImageInpaintVaeEncoderStep, QwenImageImg2ImgVaeEncoderStep]
    block_names = ["inpaint", "img2img"]
@@ -107,7 +207,33 @@ class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks):


# optional controlnet vae encoder
# auto_docstring
class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks):
    """
    Vae encoder step that encode the image inputs into their latent representations.
    This is an auto pipeline block.
      - `QwenImageControlNetVaeEncoderStep` (controlnet) is used when `control_image` is provided.
      - if `control_image` is not provided, step will be skipped.

    Components:
        vae (`AutoencoderKLQwenImage`) controlnet (`QwenImageControlNetModel`) control_image_processor
        (`VaeImageProcessor`)

    Inputs:
        control_image (`Image`, *optional*):
            Control image for ControlNet conditioning.
        height (`int`, *optional*):
            The height in pixels of the generated image.
        width (`int`, *optional*):
            The width in pixels of the generated image.
        generator (`Generator`, *optional*):
            Torch generator for deterministic generation.

    Outputs:
        control_image_latents (`Tensor`):
            The latents representing the control image
    """

    block_classes = [QwenImageControlNetVaeEncoderStep]
    block_names = ["controlnet"]
    block_trigger_inputs = ["control_image"]
@@ -123,14 +249,65 @@ class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks):


# ====================
# 2. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise)
# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise)
# ====================


# assemble input steps
# auto_docstring
class QwenImageImg2ImgInputStep(SequentialPipelineBlocks):
    """
    Input step that prepares the inputs for the img2img denoising step. It:

    Components:
        pachifier (`QwenImagePachifier`)

    Inputs:
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        prompt_embeds (`Tensor`):
            text embeddings used to guide the image generation. Can be generated from text_encoder step.
        prompt_embeds_mask (`Tensor`):
            mask for the text embeddings. Can be generated from text_encoder step.
        negative_prompt_embeds (`Tensor`, *optional*):
            negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
        negative_prompt_embeds_mask (`Tensor`, *optional*):
            mask for the negative text embeddings. Can be generated from text_encoder step.
        height (`int`, *optional*):
            The height in pixels of the generated image.
        width (`int`, *optional*):
            The width in pixels of the generated image.
        image_latents (`Tensor`):
            image latents used to guide the image generation. Can be generated from vae_encoder step.

    Outputs:
        batch_size (`int`):
            The batch size of the prompt embeddings
        dtype (`dtype`):
            The data type of the prompt embeddings
        prompt_embeds (`Tensor`):
            The prompt embeddings. (batch-expanded)
        prompt_embeds_mask (`Tensor`):
            The encoder attention mask. (batch-expanded)
        negative_prompt_embeds (`Tensor`):
            The negative prompt embeddings. (batch-expanded)
        negative_prompt_embeds_mask (`Tensor`):
            The negative prompt embeddings mask. (batch-expanded)
        image_height (`int`):
            The image height calculated from the image latents dimension
        image_width (`int`):
            The image width calculated from the image latents dimension
        height (`int`):
            if not provided, updated to image height
        width (`int`):
            if not provided, updated to image width
        image_latents (`Tensor`):
            image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and
            batch-expanded)
    """

    model_name = "qwenimage"
    block_classes = [QwenImageTextInputsStep(), QwenImageAdditionalInputsStep(image_latent_inputs=["image_latents"])]
    block_classes = [QwenImageTextInputsStep(), QwenImageAdditionalInputsStep()]
    block_names = ["text_inputs", "additional_inputs"]

    @property
@@ -140,12 +317,69 @@ class QwenImageImg2ImgInputStep(SequentialPipelineBlocks):
            " - update height/width based `image_latents`, patchify `image_latents`."


# auto_docstring
class QwenImageInpaintInputStep(SequentialPipelineBlocks):
    """
    Input step that prepares the inputs for the inpainting denoising step. It:

    Components:
        pachifier (`QwenImagePachifier`)

    Inputs:
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        prompt_embeds (`Tensor`):
            text embeddings used to guide the image generation. Can be generated from text_encoder step.
        prompt_embeds_mask (`Tensor`):
            mask for the text embeddings. Can be generated from text_encoder step.
        negative_prompt_embeds (`Tensor`, *optional*):
            negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
        negative_prompt_embeds_mask (`Tensor`, *optional*):
            mask for the negative text embeddings. Can be generated from text_encoder step.
        height (`int`, *optional*):
            The height in pixels of the generated image.
        width (`int`, *optional*):
            The width in pixels of the generated image.
        image_latents (`Tensor`, *optional*):
            image latents used to guide the image generation. Can be generated from vae_encoder step.
        processed_mask_image (`Tensor`, *optional*):
            The processed mask image

    Outputs:
        batch_size (`int`):
            The batch size of the prompt embeddings
        dtype (`dtype`):
            The data type of the prompt embeddings
        prompt_embeds (`Tensor`):
            The prompt embeddings. (batch-expanded)
        prompt_embeds_mask (`Tensor`):
            The encoder attention mask. (batch-expanded)
        negative_prompt_embeds (`Tensor`):
            The negative prompt embeddings. (batch-expanded)
        negative_prompt_embeds_mask (`Tensor`):
            The negative prompt embeddings mask. (batch-expanded)
        image_height (`int`):
            The image height calculated from the image latents dimension
        image_width (`int`):
            The image width calculated from the image latents dimension
        height (`int`):
            if not provided, updated to image height
        width (`int`):
            if not provided, updated to image width
        image_latents (`Tensor`):
            image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and
            batch-expanded)
        processed_mask_image (`Tensor`):
            The processed mask image (batch-expanded)
    """

    model_name = "qwenimage"
    block_classes = [
        QwenImageTextInputsStep(),
        QwenImageAdditionalInputsStep(
            image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
            additional_batch_inputs=[
                InputParam(name="processed_mask_image", type_hint=torch.Tensor, description="The processed mask image")
            ]
        ),
    ]
    block_names = ["text_inputs", "additional_inputs"]
@@ -158,7 +392,42 @@ class QwenImageInpaintInputStep(SequentialPipelineBlocks):


# assemble prepare latents steps
# auto_docstring
class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks):
    """
    This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. It:
      - Add noise to the image latents to create the latents input for the denoiser.
      - Create the pachified latents `mask` based on the processed mask image.

    Components:
        scheduler (`FlowMatchEulerDiscreteScheduler`) pachifier (`QwenImagePachifier`)

    Inputs:
        latents (`Tensor`):
            The initial random noise, can be generated in prepare latent step.
        image_latents (`Tensor`):
            image latents used to guide the image generation. Can be generated from vae_encoder step. (Can be
            generated from vae encoder and updated in input step.)
        timesteps (`Tensor`):
            The timesteps to use for the denoising process. Can be generated in set_timesteps step.
        processed_mask_image (`Tensor`):
            The processed mask to use for the inpainting process.
        height (`int`):
            The height in pixels of the generated image.
        width (`int`):
            The width in pixels of the generated image.
        dtype (`dtype`, *optional*, defaults to torch.float32):
            The dtype of the model inputs, can be generated in input step.

    Outputs:
        initial_noise (`Tensor`):
            The initial random noise used for inpainting denoising.
        latents (`Tensor`):
            The scaled noisy latents to use for inpainting/image-to-image denoising.
        mask (`Tensor`):
            The mask to use for the inpainting process.
    """

    model_name = "qwenimage"
    block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()]
    block_names = ["add_noise_to_latents", "create_mask_latents"]
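Conceptually, the "add_noise_to_latents" sub-block does a flow-matching style interpolation between the clean image latents and fresh noise (a sketch; the exact call inside the block is an assumption, though FlowMatchEulerDiscreteScheduler does expose `scale_noise`):

    noise = torch.randn_like(image_latents)
    # scale_noise(sample, timestep, noise) mixes sample and noise according to the sigma at `timestep`
    latents = components.scheduler.scale_noise(image_latents, timesteps[:1], noise)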
@@ -176,7 +445,49 @@ class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks):


# Qwen Image (text2image)
# auto_docstring
class QwenImageCoreDenoiseStep(SequentialPipelineBlocks):
    """
    step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs
    (timesteps, latents, rope inputs etc.).

    Components:
        pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider
        (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`)

    Inputs:
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        prompt_embeds (`Tensor`):
            text embeddings used to guide the image generation. Can be generated from text_encoder step.
        prompt_embeds_mask (`Tensor`):
            mask for the text embeddings. Can be generated from text_encoder step.
        negative_prompt_embeds (`Tensor`, *optional*):
            negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
        negative_prompt_embeds_mask (`Tensor`, *optional*):
            mask for the negative text embeddings. Can be generated from text_encoder step.
        latents (`Tensor`, *optional*):
            Pre-generated noisy latents for image generation.
        height (`int`, *optional*):
            The height in pixels of the generated image.
        width (`int`, *optional*):
            The width in pixels of the generated image.
        generator (`Generator`, *optional*):
            Torch generator for deterministic generation.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps.
        sigmas (`List`, *optional*):
            Custom sigmas for the denoising process.
        attention_kwargs (`Dict`, *optional*):
            Additional kwargs for attention processors.
        **denoiser_input_fields (`None`, *optional*):
            conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.

    Outputs:
        latents (`Tensor`):
            Denoised latents.
    """

    model_name = "qwenimage"
    block_classes = [
        QwenImageTextInputsStep(),
@@ -199,9 +510,63 @@ class QwenImageCoreDenoiseStep(SequentialPipelineBlocks):
    def description(self):
        return "step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.)."

    @property
    def outputs(self):
        return [
            OutputParam.template("latents"),
        ]

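`OutputParam.template("latents")`, used throughout this diff, presumably expands to a fully described parameter; judging from the generated docstring above, a hand-written equivalent would be (sketch):

    OutputParam(name="latents", type_hint=torch.Tensor, description="Denoised latents.")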
# Qwen Image (inpainting)
|
||||
# auto_docstring
|
||||
class QwenImageInpaintCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint
|
||||
task.
|
||||
|
||||
Components:
|
||||
pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider
|
||||
(`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
mask for the text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds_mask (`Tensor`, *optional*):
|
||||
mask for the negative text embeddings. Can be generated from text_encoder step.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image.
|
||||
image_latents (`Tensor`, *optional*):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
processed_mask_image (`Tensor`, *optional*):
|
||||
The processed mask image
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
strength (`float`, *optional*, defaults to 0.9):
|
||||
Strength for img2img/inpainting.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
Denoised latents.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageInpaintInputStep(),
|
||||
@@ -226,9 +591,61 @@ class QwenImageInpaintCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam.template("latents"),
|
||||
]
|
||||
|
||||
|
||||
# Qwen Image (image2image)
|
||||
# auto_docstring
|
||||
class QwenImageImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img
|
||||
task.
|
||||
|
||||
Components:
|
||||
pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider
|
||||
(`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
mask for the text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds_mask (`Tensor`, *optional*):
|
||||
mask for the negative text embeddings. Can be generated from text_encoder step.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image.
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
strength (`float`, *optional*, defaults to 0.9):
|
||||
Strength for img2img/inpainting.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
Denoised latents.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageImg2ImgInputStep(),
|
||||
@@ -253,9 +670,66 @@ class QwenImageImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam.template("latents"),
|
||||
]
|
||||
|
||||
|
||||
# Qwen Image (text2image) with controlnet
|
||||
# auto_docstring
|
||||
class QwenImageControlNetCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs
|
||||
(timesteps, latents, rope inputs etc.).
|
||||
|
||||
Components:
|
||||
pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) controlnet
|
||||
(`QwenImageControlNetModel`) guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
mask for the text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds_mask (`Tensor`, *optional*):
|
||||
mask for the negative text embeddings. Can be generated from text_encoder step.
|
||||
control_image_latents (`Tensor`):
|
||||
The control image latents to use for the denoising process. Can be generated in controlnet vae encoder
|
||||
step.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image.
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
control_guidance_start (`float`, *optional*, defaults to 0.0):
|
||||
When to start applying ControlNet.
|
||||
control_guidance_end (`float`, *optional*, defaults to 1.0):
|
||||
When to stop applying ControlNet.
|
||||
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
|
||||
Scale for ControlNet conditioning.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
Denoised latents.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageTextInputsStep(),
|
||||
@@ -282,9 +756,72 @@ class QwenImageControlNetCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
def description(self):
|
||||
return "step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.)."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam.template("latents"),
|
||||
]
|
||||
|
||||
|
||||
# Qwen Image (inpainting) with controlnet
|
||||
# auto_docstring
|
||||
class QwenImageControlNetInpaintCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint
|
||||
task.
|
||||
|
||||
Components:
|
||||
pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) controlnet
|
||||
(`QwenImageControlNetModel`) guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
mask for the text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds_mask (`Tensor`, *optional*):
|
||||
mask for the negative text embeddings. Can be generated from text_encoder step.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image.
|
||||
image_latents (`Tensor`, *optional*):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
processed_mask_image (`Tensor`, *optional*):
|
||||
The processed mask image
|
||||
control_image_latents (`Tensor`):
|
||||
The control image latents to use for the denoising process. Can be generated in controlnet vae encoder
|
||||
step.
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
strength (`float`, *optional*, defaults to 0.9):
|
||||
Strength for img2img/inpainting.
|
||||
control_guidance_start (`float`, *optional*, defaults to 0.0):
|
||||
When to start applying ControlNet.
|
||||
control_guidance_end (`float`, *optional*, defaults to 1.0):
|
||||
When to stop applying ControlNet.
|
||||
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
|
||||
Scale for ControlNet conditioning.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
Denoised latents.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageInpaintInputStep(),
|
||||
@@ -313,9 +850,70 @@ class QwenImageControlNetInpaintCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam.template("latents"),
|
||||
]
|
||||
|
||||
|
||||
# Qwen Image (image2image) with controlnet
|
||||
# auto_docstring
|
||||
class QwenImageControlNetImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img
|
||||
task.
|
||||
|
||||
Components:
|
||||
pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) controlnet
|
||||
(`QwenImageControlNetModel`) guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
mask for the text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds_mask (`Tensor`, *optional*):
|
||||
mask for the negative text embeddings. Can be generated from text_encoder step.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image.
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
control_image_latents (`Tensor`):
|
||||
The control image latents to use for the denoising process. Can be generated in controlnet vae encoder
|
||||
step.
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
strength (`float`, *optional*, defaults to 0.9):
|
||||
Strength for img2img/inpainting.
|
||||
control_guidance_start (`float`, *optional*, defaults to 0.0):
|
||||
When to start applying ControlNet.
|
||||
control_guidance_end (`float`, *optional*, defaults to 1.0):
|
||||
When to stop applying ControlNet.
|
||||
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
|
||||
Scale for ControlNet conditioning.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
Denoised latents.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageImg2ImgInputStep(),
|
||||
@@ -344,6 +942,12 @@ class QwenImageControlNetImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam.template("latents"),
|
||||
]
|
||||
|
||||
|
||||
# Auto denoise step for QwenImage
|
||||
class QwenImageAutoCoreDenoiseStep(ConditionalPipelineBlocks):
|
||||
@@ -402,19 +1006,36 @@ class QwenImageAutoCoreDenoiseStep(ConditionalPipelineBlocks):
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam(
|
||||
name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step"
|
||||
),
|
||||
OutputParam.template("latents"),
|
||||
]
|
||||
|
||||
|
||||
# ====================
|
||||
# 3. DECODE
|
||||
# 4. DECODE
|
||||
# ====================
|
||||
|
||||
|
||||
# standard decode step works for most tasks except for inpaint
|
||||
# auto_docstring
|
||||
class QwenImageDecodeStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Decode step that decodes the latents to images and postprocess the generated image.
|
||||
|
||||
Components:
|
||||
vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`)
|
||||
|
||||
Inputs:
|
||||
latents (`Tensor`):
|
||||
The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise
|
||||
step.
|
||||
output_type (`str`, *optional*, defaults to pil):
|
||||
Output format: 'pil', 'np', 'pt'.
|
||||
|
||||
Outputs:
|
||||
images (`List`):
|
||||
Generated images. (tensor output of the vae decoder.)
|
||||
"""
|
||||
|
||||
model_name = "qwenimage"
|
||||
block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
|
||||
block_names = ["decode", "postprocess"]
|
||||
@@ -425,7 +1046,30 @@ class QwenImageDecodeStep(SequentialPipelineBlocks):


# Inpaint decode step
# auto_docstring
class QwenImageInpaintDecodeStep(SequentialPipelineBlocks):
    """
    Decode step that decodes the latents to images and postprocesses the generated image, optionally applying the
    mask overlay to the original image.

    Components:
        vae (`AutoencoderKLQwenImage`) image_mask_processor (`InpaintProcessor`)

    Inputs:
        latents (`Tensor`):
            The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise
            step.
        output_type (`str`, *optional*, defaults to pil):
            Output format: 'pil', 'np', 'pt'.
        mask_overlay_kwargs (`Dict`, *optional*):
            The kwargs for the postprocess step to apply the mask overlay. Generated in
            InpaintProcessImagesInputStep.

    Outputs:
        images (`List`):
            Generated images. (tensor output of the vae decoder.)
    """

    model_name = "qwenimage"
    block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()]
    block_names = ["decode", "postprocess"]
@@ -452,11 +1096,11 @@ class QwenImageAutoDecodeStep(AutoPipelineBlocks):


# ====================
# 4. AUTO BLOCKS & PRESETS
# 5. AUTO BLOCKS & PRESETS
# ====================
AUTO_BLOCKS = InsertableDict(
    [
        ("text_encoder", QwenImageTextEncoderStep()),
        ("text_encoder", QwenImageAutoTextEncoderStep()),
        ("vae_encoder", QwenImageAutoVaeEncoderStep()),
        ("controlnet_vae_encoder", QwenImageOptionalControlNetVaeEncoderStep()),
        ("denoise", QwenImageAutoCoreDenoiseStep()),
@@ -465,7 +1109,89 @@ AUTO_BLOCKS = InsertableDict(
)
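
# --- Editor's sketch (not part of the diff): AUTO_BLOCKS is an InsertableDict, an
# ordered mapping whose entries can be swapped or inserted by position, which is how
# these presets are meant to be customized. A hypothetical tweak (the user block
# classes and the insert(key, value, index) signature are assumptions):
#
#     blocks = AUTO_BLOCKS.copy()
#     blocks["denoise"] = MyCustomDenoiseStep()               # swap a step in place
#     blocks.insert("watermark", MyWatermarkStep(), index=4)  # hypothetical extra step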


# auto_docstring
class QwenImageAutoBlocks(SequentialPipelineBlocks):
    """
    Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.
    - for image-to-image generation, you need to provide `image`
    - for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`.
    - to run the controlnet workflow, you need to provide `control_image`
    - for text-to-image generation, all you need to provide is `prompt`

    Components:
        text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use tokenizer (`Qwen2Tokenizer`):
        The tokenizer to use guider (`ClassifierFreeGuidance`) image_mask_processor (`InpaintProcessor`) vae
        (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) controlnet (`QwenImageControlNetModel`)
        control_image_processor (`VaeImageProcessor`) pachifier (`QwenImagePachifier`) scheduler
        (`FlowMatchEulerDiscreteScheduler`) transformer (`QwenImageTransformer2DModel`)

    Inputs:
        prompt (`str`, *optional*):
            The prompt or prompts to guide image generation.
        negative_prompt (`str`, *optional*):
            The prompt or prompts not to guide the image generation.
        max_sequence_length (`int`, *optional*, defaults to 1024):
            Maximum sequence length for prompt encoding.
        mask_image (`Image`, *optional*):
            Mask image for inpainting.
        image (`Union[Image, List]`, *optional*):
            Reference image(s) for denoising. Can be a single image or list of images.
        height (`int`, *optional*):
            The height in pixels of the generated image.
        width (`int`, *optional*):
            The width in pixels of the generated image.
        padding_mask_crop (`int`, *optional*):
            Padding for mask cropping in inpainting.
        generator (`Generator`, *optional*):
            Torch generator for deterministic generation.
        control_image (`Image`, *optional*):
            Control image for ControlNet conditioning.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        prompt_embeds (`Tensor`):
            text embeddings used to guide the image generation. Can be generated from text_encoder step.
        prompt_embeds_mask (`Tensor`):
            mask for the text embeddings. Can be generated from text_encoder step.
        negative_prompt_embeds (`Tensor`, *optional*):
            negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
        negative_prompt_embeds_mask (`Tensor`, *optional*):
            mask for the negative text embeddings. Can be generated from text_encoder step.
        latents (`Tensor`):
            Pre-generated noisy latents for image generation.
        num_inference_steps (`int`):
            The number of denoising steps.
        sigmas (`List`, *optional*):
            Custom sigmas for the denoising process.
        attention_kwargs (`Dict`, *optional*):
            Additional kwargs for attention processors.
        **denoiser_input_fields (`None`, *optional*):
            conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
        image_latents (`Tensor`, *optional*):
            image latents used to guide the image generation. Can be generated from vae_encoder step.
        processed_mask_image (`Tensor`, *optional*):
            The processed mask image
        strength (`float`, *optional*, defaults to 0.9):
            Strength for img2img/inpainting.
        control_image_latents (`Tensor`, *optional*):
            The control image latents to use for the denoising process. Can be generated in controlnet vae encoder
            step.
        control_guidance_start (`float`, *optional*, defaults to 0.0):
            When to start applying ControlNet.
        control_guidance_end (`float`, *optional*, defaults to 1.0):
            When to stop applying ControlNet.
        controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
            Scale for ControlNet conditioning.
        output_type (`str`, *optional*, defaults to pil):
            Output format: 'pil', 'np', 'pt'.
        mask_overlay_kwargs (`Dict`, *optional*):
            The kwargs for the postprocess step to apply the mask overlay. Generated in
            InpaintProcessImagesInputStep.

    Outputs:
        images (`List`):
            Generated images.
    """

    model_name = "qwenimage"

    block_classes = AUTO_BLOCKS.values()
@@ -476,7 +1202,7 @@ class QwenImageAutoBlocks(SequentialPipelineBlocks):
        return (
            "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.\n"
            + "- for image-to-image generation, you need to provide `image`\n"
            + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n"
            + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`.\n"
            + "- to run the controlnet workflow, you need to provide `control_image`\n"
            + "- for text-to-image generation, all you need to provide is `prompt`"
        )
@@ -484,5 +1210,5 @@ class QwenImageAutoBlocks(SequentialPipelineBlocks):
    @property
    def outputs(self):
        return [
            OutputParam(name="images", type_hint=List[List[PIL.Image.Image]]),
            OutputParam.template("images"),
        ]
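
# --- Editor's usage sketch (not part of the diff): turning the auto blocks into a
# runnable pipeline. The repo id and the init_pipeline/load_components API below are
# assumptions based on the modular-pipelines docs, not guaranteed by this diff:
#
#     import torch
#     from diffusers.modular_pipelines import QwenImageAutoBlocks
#
#     pipe = QwenImageAutoBlocks().init_pipeline("Qwen/Qwen-Image")  # assumed repo id
#     pipe.load_components(torch_dtype=torch.bfloat16)
#     pipe.to("cuda")
#     images = pipe(prompt="a red panda reading a book", output="images")  # text-to-image
#
# Passing `image=` switches to img2img, `image=` + `mask_image=` to inpainting, and
# `control_image=` to the controlnet path, per the trigger inputs listed above.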

@@ -12,14 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional
from typing import Optional

import PIL.Image
import torch

from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict, OutputParam
from ..modular_pipeline_utils import InputParam, InsertableDict, OutputParam
from .before_denoise import (
    QwenImageCreateMaskLatentsStep,
    QwenImageEditRoPEInputsStep,
@@ -59,8 +58,35 @@ logger = logging.get_logger(__name__)
# ====================


# auto_docstring
class QwenImageEditVLEncoderStep(SequentialPipelineBlocks):
    """VL encoder that takes both image and text prompts."""
    """
    QwenImage-Edit VL encoder step that encodes the image and text prompts together.

    Components:
        image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor
        (`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`)

    Inputs:
        image (`Union[Image, List]`):
            Reference image(s) for denoising. Can be a single image or list of images.
        prompt (`str`):
            The prompt or prompts to guide image generation.
        negative_prompt (`str`, *optional*):
            The prompt or prompts not to guide the image generation.

    Outputs:
        resized_image (`List`):
            The resized images
        prompt_embeds (`Tensor`):
            The prompt embeddings.
        prompt_embeds_mask (`Tensor`):
            The encoder attention mask.
        negative_prompt_embeds (`Tensor`):
            The negative prompt embeddings.
        negative_prompt_embeds_mask (`Tensor`):
            The negative prompt embeddings mask.
    """

    model_name = "qwenimage-edit"
    block_classes = [
@@ -80,7 +106,30 @@ class QwenImageEditVLEncoderStep(SequentialPipelineBlocks):


# Edit VAE encoder
# auto_docstring
class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks):
    """
    Vae encoder step that encodes the image inputs into their latent representations.

    Components:
        image_resize_processor (`VaeImageProcessor`) image_processor (`VaeImageProcessor`) vae
        (`AutoencoderKLQwenImage`)

    Inputs:
        image (`Union[Image, List]`):
            Reference image(s) for denoising. Can be a single image or list of images.
        generator (`Generator`, *optional*):
            Torch generator for deterministic generation.

    Outputs:
        resized_image (`List`):
            The resized images
        processed_image (`Tensor`):
            The processed image
        image_latents (`Tensor`):
            The latent representation of the input image.
    """

    model_name = "qwenimage-edit"
    block_classes = [
        QwenImageEditResizeStep(),
@@ -95,12 +144,46 @@ class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks):


# Edit Inpaint VAE encoder
# auto_docstring
class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks):
    """
    This step is used for processing image and mask inputs for QwenImage-Edit inpaint tasks. It:
    - resizes the image to the target area (1024 * 1024) while maintaining the aspect ratio.
    - processes the resized image and mask image.
    - creates the image latents.

    Components:
        image_resize_processor (`VaeImageProcessor`) image_mask_processor (`InpaintProcessor`) vae
        (`AutoencoderKLQwenImage`)

    Inputs:
        image (`Union[Image, List]`):
            Reference image(s) for denoising. Can be a single image or list of images.
        mask_image (`Image`):
            Mask image for inpainting.
        padding_mask_crop (`int`, *optional*):
            Padding for mask cropping in inpainting.
        generator (`Generator`, *optional*):
            Torch generator for deterministic generation.

    Outputs:
        resized_image (`List`):
            The resized images
        processed_image (`Tensor`):
            The processed image
        processed_mask_image (`Tensor`):
            The processed mask image
        mask_overlay_kwargs (`Dict`):
            The kwargs for the postprocess step to apply the mask overlay
        image_latents (`Tensor`):
            The latent representation of the input image.
    """

    model_name = "qwenimage-edit"
    block_classes = [
        QwenImageEditResizeStep(),
        QwenImageEditInpaintProcessImagesInputStep(),
        QwenImageVaeEncoderStep(input_name="processed_image", output_name="image_latents"),
        QwenImageVaeEncoderStep(),
    ]
    block_names = ["resize", "preprocess", "encode"]

@@ -137,11 +220,64 @@ class QwenImageEditAutoVaeEncoderStep(AutoPipelineBlocks):


# assemble input steps
# auto_docstring
class QwenImageEditInputStep(SequentialPipelineBlocks):
    """
    Input step that prepares the inputs for the edit denoising step. It:
    - makes sure the text embeddings and the additional inputs have a consistent batch size.
    - updates height/width based on `image_latents`, and patchifies `image_latents`.

    Components:
        pachifier (`QwenImagePachifier`)

    Inputs:
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        prompt_embeds (`Tensor`):
            text embeddings used to guide the image generation. Can be generated from text_encoder step.
        prompt_embeds_mask (`Tensor`):
            mask for the text embeddings. Can be generated from text_encoder step.
        negative_prompt_embeds (`Tensor`, *optional*):
            negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
        negative_prompt_embeds_mask (`Tensor`, *optional*):
            mask for the negative text embeddings. Can be generated from text_encoder step.
        height (`int`, *optional*):
            The height in pixels of the generated image.
        width (`int`, *optional*):
            The width in pixels of the generated image.
        image_latents (`Tensor`):
            image latents used to guide the image generation. Can be generated from vae_encoder step.

    Outputs:
        batch_size (`int`):
            The batch size of the prompt embeddings
        dtype (`dtype`):
            The data type of the prompt embeddings
        prompt_embeds (`Tensor`):
            The prompt embeddings. (batch-expanded)
        prompt_embeds_mask (`Tensor`):
            The encoder attention mask. (batch-expanded)
        negative_prompt_embeds (`Tensor`):
            The negative prompt embeddings. (batch-expanded)
        negative_prompt_embeds_mask (`Tensor`):
            The negative prompt embeddings mask. (batch-expanded)
        image_height (`int`):
            The image height calculated from the image latents dimension
        image_width (`int`):
            The image width calculated from the image latents dimension
        height (`int`):
            if not provided, updated to image height
        width (`int`):
            if not provided, updated to image width
        image_latents (`Tensor`):
            image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and
            batch-expanded)
    """

    model_name = "qwenimage-edit"
    block_classes = [
        QwenImageTextInputsStep(),
        QwenImageAdditionalInputsStep(image_latent_inputs=["image_latents"]),
        QwenImageAdditionalInputsStep(),
    ]
    block_names = ["text_inputs", "additional_inputs"]

@@ -154,12 +290,71 @@ class QwenImageEditInputStep(SequentialPipelineBlocks):
    )


# auto_docstring
class QwenImageEditInpaintInputStep(SequentialPipelineBlocks):
    """
    Input step that prepares the inputs for the edit inpaint denoising step. It:
    - makes sure the text embeddings and the additional inputs have a consistent batch size.
    - updates height/width based on `image_latents`, and patchifies `image_latents`.

    Components:
        pachifier (`QwenImagePachifier`)

    Inputs:
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        prompt_embeds (`Tensor`):
            text embeddings used to guide the image generation. Can be generated from text_encoder step.
        prompt_embeds_mask (`Tensor`):
            mask for the text embeddings. Can be generated from text_encoder step.
        negative_prompt_embeds (`Tensor`, *optional*):
            negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
        negative_prompt_embeds_mask (`Tensor`, *optional*):
            mask for the negative text embeddings. Can be generated from text_encoder step.
        height (`int`, *optional*):
            The height in pixels of the generated image.
        width (`int`, *optional*):
            The width in pixels of the generated image.
        image_latents (`Tensor`):
            image latents used to guide the image generation. Can be generated from vae_encoder step.
        processed_mask_image (`Tensor`, *optional*):
            The processed mask image

    Outputs:
        batch_size (`int`):
            The batch size of the prompt embeddings
        dtype (`dtype`):
            The data type of the prompt embeddings
        prompt_embeds (`Tensor`):
            The prompt embeddings. (batch-expanded)
        prompt_embeds_mask (`Tensor`):
            The encoder attention mask. (batch-expanded)
        negative_prompt_embeds (`Tensor`):
            The negative prompt embeddings. (batch-expanded)
        negative_prompt_embeds_mask (`Tensor`):
            The negative prompt embeddings mask. (batch-expanded)
        image_height (`int`):
            The image height calculated from the image latents dimension
        image_width (`int`):
            The image width calculated from the image latents dimension
        height (`int`):
            if not provided, updated to image height
        width (`int`):
            if not provided, updated to image width
        image_latents (`Tensor`):
            image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and
            batch-expanded)
        processed_mask_image (`Tensor`):
            The processed mask image (batch-expanded)
    """

    model_name = "qwenimage-edit"
    block_classes = [
        QwenImageTextInputsStep(),
        QwenImageAdditionalInputsStep(
            image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
            additional_batch_inputs=[
                InputParam(name="processed_mask_image", type_hint=torch.Tensor, description="The processed mask image")
            ]
        ),
    ]
    block_names = ["text_inputs", "additional_inputs"]
@@ -174,7 +369,42 @@ class QwenImageEditInpaintInputStep(SequentialPipelineBlocks):


# assemble prepare latents steps
# auto_docstring
class QwenImageEditInpaintPrepareLatentsStep(SequentialPipelineBlocks):
    """
    This step prepares the latents/image_latents and mask inputs for the edit inpainting denoising step. It:
    - Adds noise to the image latents to create the latents input for the denoiser.
    - Creates the patchified latents `mask` based on the processed mask image.

    Components:
        scheduler (`FlowMatchEulerDiscreteScheduler`) pachifier (`QwenImagePachifier`)

    Inputs:
        latents (`Tensor`):
            The initial random noise, can be generated in the prepare latents step.
        image_latents (`Tensor`):
            image latents used to guide the image generation. Can be generated from vae_encoder step. (Can be
            generated from vae encoder and updated in input step.)
        timesteps (`Tensor`):
            The timesteps to use for the denoising process. Can be generated in set_timesteps step.
        processed_mask_image (`Tensor`):
            The processed mask to use for the inpainting process.
        height (`int`):
            The height in pixels of the generated image.
        width (`int`):
            The width in pixels of the generated image.
        dtype (`dtype`, *optional*, defaults to torch.float32):
            The dtype of the model inputs, can be generated in input step.

    Outputs:
        initial_noise (`Tensor`):
            The initial random noise used for inpainting denoising.
        latents (`Tensor`):
            The scaled noisy latents to use for inpainting/image-to-image denoising.
        mask (`Tensor`):
            The mask to use for the inpainting process.
    """

    model_name = "qwenimage-edit"
    block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()]
    block_names = ["add_noise_to_latents", "create_mask_latents"]
@@ -189,7 +419,50 @@ class QwenImageEditInpaintPrepareLatentsStep(SequentialPipelineBlocks):


# Qwen Image Edit (image2image) core denoise step
# auto_docstring
class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks):
    """
    Core denoising workflow for QwenImage-Edit edit (img2img) task.

    Components:
        pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider
        (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`)

    Inputs:
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        prompt_embeds (`Tensor`):
            text embeddings used to guide the image generation. Can be generated from text_encoder step.
        prompt_embeds_mask (`Tensor`):
            mask for the text embeddings. Can be generated from text_encoder step.
        negative_prompt_embeds (`Tensor`, *optional*):
            negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
        negative_prompt_embeds_mask (`Tensor`, *optional*):
            mask for the negative text embeddings. Can be generated from text_encoder step.
        height (`int`, *optional*):
            The height in pixels of the generated image.
        width (`int`, *optional*):
            The width in pixels of the generated image.
        image_latents (`Tensor`):
            image latents used to guide the image generation. Can be generated from vae_encoder step.
        latents (`Tensor`, *optional*):
            Pre-generated noisy latents for image generation.
        generator (`Generator`, *optional*):
            Torch generator for deterministic generation.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps.
        sigmas (`List`, *optional*):
            Custom sigmas for the denoising process.
        attention_kwargs (`Dict`, *optional*):
            Additional kwargs for attention processors.
        **denoiser_input_fields (`None`, *optional*):
            conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.

    Outputs:
        latents (`Tensor`):
            Denoised latents.
    """

    model_name = "qwenimage-edit"
    block_classes = [
        QwenImageEditInputStep(),
@@ -212,9 +485,62 @@ class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks):
    def description(self):
        return "Core denoising workflow for QwenImage-Edit edit (img2img) task."

    @property
    def outputs(self):
        return [
            OutputParam.template("latents"),
        ]


# Qwen Image Edit (inpainting) core denoise step
# auto_docstring
class QwenImageEditInpaintCoreDenoiseStep(SequentialPipelineBlocks):
    """
    Core denoising workflow for QwenImage-Edit edit inpaint task.

    Components:
        pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider
        (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`)

    Inputs:
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        prompt_embeds (`Tensor`):
            text embeddings used to guide the image generation. Can be generated from text_encoder step.
        prompt_embeds_mask (`Tensor`):
            mask for the text embeddings. Can be generated from text_encoder step.
        negative_prompt_embeds (`Tensor`, *optional*):
            negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
        negative_prompt_embeds_mask (`Tensor`, *optional*):
            mask for the negative text embeddings. Can be generated from text_encoder step.
        height (`int`, *optional*):
            The height in pixels of the generated image.
        width (`int`, *optional*):
            The width in pixels of the generated image.
        image_latents (`Tensor`):
            image latents used to guide the image generation. Can be generated from vae_encoder step.
        processed_mask_image (`Tensor`, *optional*):
            The processed mask image
        latents (`Tensor`, *optional*):
            Pre-generated noisy latents for image generation.
        generator (`Generator`, *optional*):
            Torch generator for deterministic generation.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps.
        sigmas (`List`, *optional*):
            Custom sigmas for the denoising process.
        strength (`float`, *optional*, defaults to 0.9):
            Strength for img2img/inpainting.
        attention_kwargs (`Dict`, *optional*):
            Additional kwargs for attention processors.
        **denoiser_input_fields (`None`, *optional*):
            conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.

    Outputs:
        latents (`Tensor`):
            Denoised latents.
    """

    model_name = "qwenimage-edit"
    block_classes = [
        QwenImageEditInpaintInputStep(),
@@ -239,6 +565,12 @@ class QwenImageEditInpaintCoreDenoiseStep(SequentialPipelineBlocks):
    def description(self):
        return "Core denoising workflow for QwenImage-Edit edit inpaint task."

    @property
    def outputs(self):
        return [
            OutputParam.template("latents"),
        ]


# Auto core denoise step for QwenImage Edit
class QwenImageEditAutoCoreDenoiseStep(ConditionalPipelineBlocks):
@@ -267,6 +599,12 @@ class QwenImageEditAutoCoreDenoiseStep(ConditionalPipelineBlocks):
            "Supports edit (img2img) and edit inpainting tasks for QwenImage-Edit."
        )

    @property
    def outputs(self):
        return [
            OutputParam.template("latents"),
        ]
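
# --- Editor's sketch (not part of the diff): a ConditionalPipelineBlocks picks its
# sub-workflow from which trigger inputs the caller provides. A hypothetical,
# stand-alone mirror of the dispatch described above (names are placeholders):
#
#     def pick_edit_workflow(provided_inputs: set) -> str:
#         if "mask_image" in provided_inputs or "processed_mask_image" in provided_inputs:
#             return "edit_inpaint"  # a mask routes to the inpainting branch
#         return "edit"              # otherwise plain edit (img2img)
#
#     assert pick_edit_workflow({"image", "mask_image"}) == "edit_inpaint"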


# ====================
# 4. DECODE
@@ -274,7 +612,26 @@ class QwenImageEditAutoCoreDenoiseStep(ConditionalPipelineBlocks):


# Decode step (standard)
# auto_docstring
class QwenImageEditDecodeStep(SequentialPipelineBlocks):
    """
    Decode step that decodes the latents to images and postprocesses the generated image.

    Components:
        vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`)

    Inputs:
        latents (`Tensor`):
            The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise
            step.
        output_type (`str`, *optional*, defaults to pil):
            Output format: 'pil', 'np', 'pt'.

    Outputs:
        images (`List`):
            Generated images. (tensor output of the vae decoder.)
    """

    model_name = "qwenimage-edit"
    block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
    block_names = ["decode", "postprocess"]
@@ -285,7 +642,30 @@ class QwenImageEditDecodeStep(SequentialPipelineBlocks):


# Inpaint decode step
# auto_docstring
class QwenImageEditInpaintDecodeStep(SequentialPipelineBlocks):
    """
    Decode step that decodes the latents to images and postprocesses the generated image, optionally applying the
    mask overlay to the original image.

    Components:
        vae (`AutoencoderKLQwenImage`) image_mask_processor (`InpaintProcessor`)

    Inputs:
        latents (`Tensor`):
            The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise
            step.
        output_type (`str`, *optional*, defaults to pil):
            Output format: 'pil', 'np', 'pt'.
        mask_overlay_kwargs (`Dict`, *optional*):
            The kwargs for the postprocess step to apply the mask overlay. Generated in
            InpaintProcessImagesInputStep.

    Outputs:
        images (`List`):
            Generated images. (tensor output of the vae decoder.)
    """

    model_name = "qwenimage-edit"
    block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()]
    block_names = ["decode", "postprocess"]
@@ -313,9 +693,7 @@ class QwenImageEditAutoDecodeStep(AutoPipelineBlocks):
    @property
    def outputs(self):
        return [
            OutputParam(
                name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step"
            ),
            OutputParam.template("latents"),
        ]


@@ -333,7 +711,66 @@ EDIT_AUTO_BLOCKS = InsertableDict(
)


# auto_docstring
class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
    """
    Auto Modular pipeline for edit (img2img) and edit inpaint tasks using QwenImage-Edit.
    - for edit (img2img) generation, you need to provide `image`
    - for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide
      `padding_mask_crop`

    Components:
        image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor
        (`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`) image_mask_processor (`InpaintProcessor`) vae
        (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) pachifier (`QwenImagePachifier`) scheduler
        (`FlowMatchEulerDiscreteScheduler`) transformer (`QwenImageTransformer2DModel`)

    Inputs:
        image (`Union[Image, List]`):
            Reference image(s) for denoising. Can be a single image or list of images.
        prompt (`str`):
            The prompt or prompts to guide image generation.
        negative_prompt (`str`, *optional*):
            The prompt or prompts not to guide the image generation.
        mask_image (`Image`, *optional*):
            Mask image for inpainting.
        padding_mask_crop (`int`, *optional*):
            Padding for mask cropping in inpainting.
        generator (`Generator`, *optional*):
            Torch generator for deterministic generation.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        height (`int`):
            The height in pixels of the generated image.
        width (`int`):
            The width in pixels of the generated image.
        image_latents (`Tensor`):
            image latents used to guide the image generation. Can be generated from vae_encoder step.
        processed_mask_image (`Tensor`, *optional*):
            The processed mask image
        latents (`Tensor`):
            Pre-generated noisy latents for image generation.
        num_inference_steps (`int`):
            The number of denoising steps.
        sigmas (`List`, *optional*):
            Custom sigmas for the denoising process.
        strength (`float`, *optional*, defaults to 0.9):
            Strength for img2img/inpainting.
        attention_kwargs (`Dict`, *optional*):
            Additional kwargs for attention processors.
        **denoiser_input_fields (`None`, *optional*):
            conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
        output_type (`str`, *optional*, defaults to pil):
            Output format: 'pil', 'np', 'pt'.
        mask_overlay_kwargs (`Dict`, *optional*):
            The kwargs for the postprocess step to apply the mask overlay. Generated in
            InpaintProcessImagesInputStep.

    Outputs:
        images (`List`):
            Generated images.
    """

    model_name = "qwenimage-edit"
    block_classes = EDIT_AUTO_BLOCKS.values()
    block_names = EDIT_AUTO_BLOCKS.keys()
@@ -349,5 +786,5 @@ class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
    @property
    def outputs(self):
        return [
            OutputParam(name="images", type_hint=List[List[PIL.Image.Image]], description="The generated images"),
            OutputParam.template("images"),
        ]
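
# --- Editor's usage sketch (not part of the diff): same pattern as QwenImageAutoBlocks;
# the repo id and loading API are assumptions:
#
#     import torch
#     from diffusers.modular_pipelines import QwenImageEditAutoBlocks
#
#     pipe = QwenImageEditAutoBlocks().init_pipeline("Qwen/Qwen-Image-Edit")  # assumed repo id
#     pipe.load_components(torch_dtype=torch.bfloat16)
#     pipe.to("cuda")
#     images = pipe(image=ref_image, prompt="make the jacket red", output="images")  # edit
#
# Supplying `mask_image=` as well takes the edit-inpaint branch described above.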

@@ -12,11 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import PIL.Image
import torch

from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict, OutputParam
@@ -53,12 +48,41 @@ logger = logging.get_logger(__name__)
# ====================


# auto_docstring
class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks):
    """VL encoder that takes both image and text prompts. Uses 384x384 target area."""
    """
    QwenImage-Edit Plus VL encoder step that encodes the image and text prompts together.

    Components:
        image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor
        (`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`)

    Inputs:
        image (`Union[Image, List]`):
            Reference image(s) for denoising. Can be a single image or list of images.
        prompt (`str`):
            The prompt or prompts to guide image generation.
        negative_prompt (`str`, *optional*):
            The prompt or prompts not to guide the image generation.

    Outputs:
        resized_image (`List`):
            Images resized to 1024x1024 target area for VAE encoding
        resized_cond_image (`List`):
            Images resized to 384x384 target area for VL text encoding
        prompt_embeds (`Tensor`):
            The prompt embeddings.
        prompt_embeds_mask (`Tensor`):
            The encoder attention mask.
        negative_prompt_embeds (`Tensor`):
            The negative prompt embeddings.
        negative_prompt_embeds_mask (`Tensor`):
            The negative prompt embeddings mask.
    """

    model_name = "qwenimage-edit-plus"
    block_classes = [
        QwenImageEditPlusResizeStep(target_area=384 * 384, output_name="resized_cond_image"),
        QwenImageEditPlusResizeStep(),
        QwenImageEditPlusTextEncoderStep(),
    ]
    block_names = ["resize", "encode"]
@@ -73,12 +97,36 @@ class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks):
# ====================


# auto_docstring
class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):
    """VAE encoder that handles multiple images with different sizes. Uses 1024x1024 target area."""
    """
    VAE encoder step that encodes image inputs into latent representations.
    Each image is resized independently based on its own aspect ratio to 1024x1024 target area.

    Components:
        image_resize_processor (`VaeImageProcessor`) image_processor (`VaeImageProcessor`) vae
        (`AutoencoderKLQwenImage`)

    Inputs:
        image (`Union[Image, List]`):
            Reference image(s) for denoising. Can be a single image or list of images.
        generator (`Generator`, *optional*):
            Torch generator for deterministic generation.

    Outputs:
        resized_image (`List`):
            Images resized to 1024x1024 target area for VAE encoding
        resized_cond_image (`List`):
            Images resized to 384x384 target area for VL text encoding
        processed_image (`Tensor`):
            The processed image
        image_latents (`Tensor`):
            The latent representation of the input image.
    """

    model_name = "qwenimage-edit-plus"
    block_classes = [
        QwenImageEditPlusResizeStep(target_area=1024 * 1024, output_name="resized_image"),
        QwenImageEditPlusResizeStep(),
        QwenImageEditPlusProcessImagesInputStep(),
        QwenImageVaeEncoderStep(),
    ]
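
# --- Editor's sketch (not part of the diff): "resized independently based on its own
# aspect ratio to a target area" means each image keeps w/h while w*h is rescaled to
# roughly 1024*1024, with dimensions snapped to a multiple (assumed 32 here) so they
# stay VAE-friendly:
#
#     import math
#
#     def fit_to_area(w, h, area=1024 * 1024, multiple=32):
#         scale = math.sqrt(area / (w * h))  # uniform scale, aspect ratio preserved
#         snap = lambda v: max(multiple, round(v * scale / multiple) * multiple)
#         return snap(w), snap(h)
#
#     fit_to_area(1536, 640)   # wide input -> ~1024*1024 pixels, ratio kept
#     fit_to_area(640, 1536)   # tall input -> same area, transposed shape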
@@ -98,11 +146,66 @@ class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):


# assemble input steps
# auto_docstring
class QwenImageEditPlusInputStep(SequentialPipelineBlocks):
    """
    Input step that prepares the inputs for the Edit Plus denoising step. It:
    - Standardizes text embeddings batch size.
    - Processes the list of image latents: patchifies, concatenates along dim=1, and expands the batch.
    - Outputs lists of image_height/image_width for RoPE calculation.
    - Defaults height/width from the last image in the list.

    Components:
        pachifier (`QwenImagePachifier`)

    Inputs:
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        prompt_embeds (`Tensor`):
            text embeddings used to guide the image generation. Can be generated from text_encoder step.
        prompt_embeds_mask (`Tensor`):
            mask for the text embeddings. Can be generated from text_encoder step.
        negative_prompt_embeds (`Tensor`, *optional*):
            negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
        negative_prompt_embeds_mask (`Tensor`, *optional*):
            mask for the negative text embeddings. Can be generated from text_encoder step.
        height (`int`, *optional*):
            The height in pixels of the generated image.
        width (`int`, *optional*):
            The width in pixels of the generated image.
        image_latents (`Tensor`):
            image latents used to guide the image generation. Can be generated from vae_encoder step.

    Outputs:
        batch_size (`int`):
            The batch size of the prompt embeddings
        dtype (`dtype`):
            The data type of the prompt embeddings
        prompt_embeds (`Tensor`):
            The prompt embeddings. (batch-expanded)
        prompt_embeds_mask (`Tensor`):
            The encoder attention mask. (batch-expanded)
        negative_prompt_embeds (`Tensor`):
            The negative prompt embeddings. (batch-expanded)
        negative_prompt_embeds_mask (`Tensor`):
            The negative prompt embeddings mask. (batch-expanded)
        image_height (`List`):
            The image heights calculated from the image latents dimension
        image_width (`List`):
            The image widths calculated from the image latents dimension
        height (`int`):
            if not provided, updated to image height
        width (`int`):
            if not provided, updated to image width
        image_latents (`Tensor`):
            image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified,
            concatenated, and batch-expanded)
    """

    model_name = "qwenimage-edit-plus"
    block_classes = [
        QwenImageTextInputsStep(),
        QwenImageEditPlusAdditionalInputsStep(image_latent_inputs=["image_latents"]),
        QwenImageEditPlusAdditionalInputsStep(),
    ]
    block_names = ["text_inputs", "additional_inputs"]

@@ -118,7 +221,50 @@ class QwenImageEditPlusInputStep(SequentialPipelineBlocks):


# Qwen Image Edit Plus (image2image) core denoise step
# auto_docstring
class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
    """
    Core denoising workflow for QwenImage-Edit Plus edit (img2img) task.

    Components:
        pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider
        (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`)

    Inputs:
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        prompt_embeds (`Tensor`):
            text embeddings used to guide the image generation. Can be generated from text_encoder step.
        prompt_embeds_mask (`Tensor`):
            mask for the text embeddings. Can be generated from text_encoder step.
        negative_prompt_embeds (`Tensor`, *optional*):
            negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
        negative_prompt_embeds_mask (`Tensor`, *optional*):
            mask for the negative text embeddings. Can be generated from text_encoder step.
        height (`int`, *optional*):
            The height in pixels of the generated image.
        width (`int`, *optional*):
            The width in pixels of the generated image.
        image_latents (`Tensor`):
            image latents used to guide the image generation. Can be generated from vae_encoder step.
        latents (`Tensor`, *optional*):
            Pre-generated noisy latents for image generation.
        generator (`Generator`, *optional*):
            Torch generator for deterministic generation.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps.
        sigmas (`List`, *optional*):
            Custom sigmas for the denoising process.
        attention_kwargs (`Dict`, *optional*):
            Additional kwargs for attention processors.
        **denoiser_input_fields (`None`, *optional*):
            conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.

    Outputs:
        latents (`Tensor`):
            Denoised latents.
    """

    model_name = "qwenimage-edit-plus"
    block_classes = [
        QwenImageEditPlusInputStep(),
@@ -144,9 +290,7 @@ class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
    @property
    def outputs(self):
        return [
            OutputParam(
                name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step"
            ),
            OutputParam.template("latents"),
        ]


@@ -155,7 +299,26 @@ class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
# ====================


# auto_docstring
class QwenImageEditPlusDecodeStep(SequentialPipelineBlocks):
    """
    Decode step that decodes the latents to images and postprocesses the generated image.

    Components:
        vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`)

    Inputs:
        latents (`Tensor`):
            The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise
            step.
        output_type (`str`, *optional*, defaults to pil):
            Output format: 'pil', 'np', 'pt'.

    Outputs:
        images (`List`):
            Generated images. (tensor output of the vae decoder.)
    """

    model_name = "qwenimage-edit-plus"
    block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
    block_names = ["decode", "postprocess"]
@@ -179,7 +342,53 @@ EDIT_PLUS_AUTO_BLOCKS = InsertableDict(
)


# auto_docstring
class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks):
    """
    Auto Modular pipeline for edit (img2img) tasks using QwenImage-Edit Plus.
    - `image` is a required input (can be a single image or a list of images).
    - Each image is resized independently based on its own aspect ratio.
    - VL encoder uses 384x384 target area, VAE encoder uses 1024x1024 target area.

    Components:
        image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor
        (`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`) image_processor (`VaeImageProcessor`) vae
        (`AutoencoderKLQwenImage`) pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`)
        transformer (`QwenImageTransformer2DModel`)

    Inputs:
        image (`Union[Image, List]`):
            Reference image(s) for denoising. Can be a single image or list of images.
        prompt (`str`):
            The prompt or prompts to guide image generation.
        negative_prompt (`str`, *optional*):
            The prompt or prompts not to guide the image generation.
        generator (`Generator`, *optional*):
            Torch generator for deterministic generation.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        height (`int`, *optional*):
            The height in pixels of the generated image.
        width (`int`, *optional*):
            The width in pixels of the generated image.
        latents (`Tensor`, *optional*):
            Pre-generated noisy latents for image generation.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps.
        sigmas (`List`, *optional*):
            Custom sigmas for the denoising process.
        attention_kwargs (`Dict`, *optional*):
            Additional kwargs for attention processors.
        **denoiser_input_fields (`None`, *optional*):
            conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
        output_type (`str`, *optional*, defaults to pil):
            Output format: 'pil', 'np', 'pt'.

    Outputs:
        images (`List`):
            Generated images.
    """

    model_name = "qwenimage-edit-plus"
    block_classes = EDIT_PLUS_AUTO_BLOCKS.values()
    block_names = EDIT_PLUS_AUTO_BLOCKS.keys()
@@ -196,5 +405,5 @@ class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks):
    @property
    def outputs(self):
        return [
            OutputParam(name="images", type_hint=List[List[PIL.Image.Image]], description="The generated images"),
            OutputParam.template("images"),
        ]
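
# --- Editor's usage sketch (not part of the diff): Edit Plus accepts a list of
# reference images, each resized on its own; the repo id and loading API are assumptions:
#
#     from diffusers.modular_pipelines import QwenImageEditPlusAutoBlocks
#
#     pipe = QwenImageEditPlusAutoBlocks().init_pipeline("Qwen/Qwen-Image-Edit-Plus")  # assumed repo id
#     pipe.load_components()
#     images = pipe(
#         image=[subject_img, style_img],  # multiple references, independently resized
#         prompt="put the subject into the style scene",
#         output="images",
#     )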

@@ -12,12 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List

import PIL.Image
import torch

from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict, OutputParam
@@ -55,8 +49,44 @@ logger = logging.get_logger(__name__)
# ====================


# auto_docstring
class QwenImageLayeredTextEncoderStep(SequentialPipelineBlocks):
    """Text encoder that takes text prompt, will generate a prompt based on image if not provided."""
    """
    QwenImage-Layered text encoder step that encodes the text prompt; if no prompt is provided, it generates one
    based on the image.

    Components:
        image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor
        (`Qwen2VLProcessor`) tokenizer (`Qwen2Tokenizer`): The tokenizer to use guider (`ClassifierFreeGuidance`)

    Inputs:
        image (`Union[Image, List]`):
            Reference image(s) for denoising. Can be a single image or list of images.
        resolution (`int`, *optional*, defaults to 640):
            The target area to resize the image to, can be 1024 or 640
        prompt (`str`, *optional*):
            The prompt or prompts to guide image generation.
        use_en_prompt (`bool`, *optional*, defaults to False):
            Whether to use English prompt template
        negative_prompt (`str`, *optional*):
            The prompt or prompts not to guide the image generation.
        max_sequence_length (`int`, *optional*, defaults to 1024):
            Maximum sequence length for prompt encoding.

    Outputs:
        resized_image (`List`):
            The resized images
        prompt (`str`):
            The prompt or prompts to guide image generation. If not provided, updated using image caption
        prompt_embeds (`Tensor`):
            The prompt embeddings.
        prompt_embeds_mask (`Tensor`):
            The encoder attention mask.
        negative_prompt_embeds (`Tensor`):
            The negative prompt embeddings.
        negative_prompt_embeds_mask (`Tensor`):
            The negative prompt embeddings mask.
    """

    model_name = "qwenimage-layered"
    block_classes = [
@@ -77,7 +107,32 @@ class QwenImageLayeredTextEncoderStep(SequentialPipelineBlocks):


# Edit VAE encoder
# auto_docstring
class QwenImageLayeredVaeEncoderStep(SequentialPipelineBlocks):
    """
    Vae encoder step that encodes the image inputs into their latent representations.

    Components:
        image_resize_processor (`VaeImageProcessor`) image_processor (`VaeImageProcessor`) vae
        (`AutoencoderKLQwenImage`)

    Inputs:
        image (`Union[Image, List]`):
            Reference image(s) for denoising. Can be a single image or list of images.
        resolution (`int`, *optional*, defaults to 640):
            The target area to resize the image to, can be 1024 or 640
        generator (`Generator`, *optional*):
            Torch generator for deterministic generation.

    Outputs:
        resized_image (`List`):
            The resized images
        processed_image (`Tensor`):
            The processed image
        image_latents (`Tensor`):
            The latent representation of the input image.
    """

    model_name = "qwenimage-layered"
    block_classes = [
        QwenImageLayeredResizeStep(),
@@ -98,11 +153,60 @@ class QwenImageLayeredVaeEncoderStep(SequentialPipelineBlocks):


# assemble input steps
# auto_docstring
class QwenImageLayeredInputStep(SequentialPipelineBlocks):
    """
    Input step that prepares the inputs for the layered denoising step. It:
    - makes sure the text embeddings and the additional inputs have a consistent batch size.
    - updates height/width based on `image_latents`, and patchifies `image_latents`.

    Components:
        pachifier (`QwenImageLayeredPachifier`)

    Inputs:
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        prompt_embeds (`Tensor`):
            text embeddings used to guide the image generation. Can be generated from text_encoder step.
        prompt_embeds_mask (`Tensor`):
            mask for the text embeddings. Can be generated from text_encoder step.
        negative_prompt_embeds (`Tensor`, *optional*):
            negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
        negative_prompt_embeds_mask (`Tensor`, *optional*):
            mask for the negative text embeddings. Can be generated from text_encoder step.
        image_latents (`Tensor`):
            image latents used to guide the image generation. Can be generated from vae_encoder step.

    Outputs:
        batch_size (`int`):
            The batch size of the prompt embeddings
        dtype (`dtype`):
            The data type of the prompt embeddings
        prompt_embeds (`Tensor`):
            The prompt embeddings. (batch-expanded)
        prompt_embeds_mask (`Tensor`):
            The encoder attention mask. (batch-expanded)
        negative_prompt_embeds (`Tensor`):
            The negative prompt embeddings. (batch-expanded)
        negative_prompt_embeds_mask (`Tensor`):
            The negative prompt embeddings mask. (batch-expanded)
        image_height (`int`):
            The image height calculated from the image latents dimension
        image_width (`int`):
            The image width calculated from the image latents dimension
        height (`int`):
            if not provided, updated to image height
        width (`int`):
            if not provided, updated to image width
        image_latents (`Tensor`):
            image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified
            with layered pachifier and batch-expanded)
    """

    model_name = "qwenimage-layered"
    block_classes = [
        QwenImageTextInputsStep(),
        QwenImageLayeredAdditionalInputsStep(image_latent_inputs=["image_latents"]),
        QwenImageLayeredAdditionalInputsStep(),
    ]
    block_names = ["text_inputs", "additional_inputs"]

@@ -116,7 +220,48 @@ class QwenImageLayeredInputStep(SequentialPipelineBlocks):


# Qwen Image Layered (image2image) core denoise step
# auto_docstring
class QwenImageLayeredCoreDenoiseStep(SequentialPipelineBlocks):
    """
    Core denoising workflow for QwenImage-Layered img2img task.

    Components:
        pachifier (`QwenImageLayeredPachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider
        (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`)

    Inputs:
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        prompt_embeds (`Tensor`):
            text embeddings used to guide the image generation. Can be generated from text_encoder step.
        prompt_embeds_mask (`Tensor`):
            mask for the text embeddings. Can be generated from text_encoder step.
        negative_prompt_embeds (`Tensor`, *optional*):
            negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
        negative_prompt_embeds_mask (`Tensor`, *optional*):
            mask for the negative text embeddings. Can be generated from text_encoder step.
        image_latents (`Tensor`):
            image latents used to guide the image generation. Can be generated from vae_encoder step.
        latents (`Tensor`, *optional*):
            Pre-generated noisy latents for image generation.
        layers (`int`, *optional*, defaults to 4):
            Number of layers to extract from the image
        generator (`Generator`, *optional*):
            Torch generator for deterministic generation.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps.
        sigmas (`List`, *optional*):
            Custom sigmas for the denoising process.
        attention_kwargs (`Dict`, *optional*):
            Additional kwargs for attention processors.
        **denoiser_input_fields (`None`, *optional*):
            conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.

    Outputs:
        latents (`Tensor`):
            Denoised latents.
    """

    model_name = "qwenimage-layered"
    block_classes = [
        QwenImageLayeredInputStep(),
@@ -142,9 +287,7 @@ class QwenImageLayeredCoreDenoiseStep(SequentialPipelineBlocks):
    @property
    def outputs(self):
        return [
            OutputParam(
                name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step"
            ),
            OutputParam.template("latents"),
        ]


@@ -162,7 +305,54 @@ LAYERED_AUTO_BLOCKS = InsertableDict(
)


# auto_docstring
class QwenImageLayeredAutoBlocks(SequentialPipelineBlocks):
    """
    Auto Modular pipeline for layered denoising tasks using QwenImage-Layered.

    Components:
        image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor
        (`Qwen2VLProcessor`) tokenizer (`Qwen2Tokenizer`): The tokenizer to use guider (`ClassifierFreeGuidance`)
        image_processor (`VaeImageProcessor`) vae (`AutoencoderKLQwenImage`) pachifier (`QwenImageLayeredPachifier`)
        scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`QwenImageTransformer2DModel`)

    Inputs:
        image (`Union[Image, List]`):
            Reference image(s) for denoising. Can be a single image or list of images.
        resolution (`int`, *optional*, defaults to 640):
            The target area to resize the image to, can be 1024 or 640
        prompt (`str`, *optional*):
            The prompt or prompts to guide image generation.
        use_en_prompt (`bool`, *optional*, defaults to False):
            Whether to use English prompt template
        negative_prompt (`str`, *optional*):
            The prompt or prompts not to guide the image generation.
        max_sequence_length (`int`, *optional*, defaults to 1024):
            Maximum sequence length for prompt encoding.
        generator (`Generator`, *optional*):
            Torch generator for deterministic generation.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        latents (`Tensor`, *optional*):
            Pre-generated noisy latents for image generation.
        layers (`int`, *optional*, defaults to 4):
            Number of layers to extract from the image
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps.
        sigmas (`List`, *optional*):
            Custom sigmas for the denoising process.
        attention_kwargs (`Dict`, *optional*):
            Additional kwargs for attention processors.
        **denoiser_input_fields (`None`, *optional*):
            conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
        output_type (`str`, *optional*, defaults to pil):
            Output format: 'pil', 'np', 'pt'.

    Outputs:
        images (`List`):
            Generated images.
    """

    model_name = "qwenimage-layered"
    block_classes = LAYERED_AUTO_BLOCKS.values()
    block_names = LAYERED_AUTO_BLOCKS.keys()
@@ -174,5 +364,5 @@ class QwenImageLayeredAutoBlocks(SequentialPipelineBlocks):
    @property
    def outputs(self):
        return [
            OutputParam(name="images", type_hint=List[List[PIL.Image.Image]], description="The generated images"),
            OutputParam.template("images"),
        ]
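
# --- Editor's usage sketch (not part of the diff): the layered preset decomposes one
# input image into `layers` separate outputs; the repo id and loading API are assumptions:
#
#     from diffusers.modular_pipelines import QwenImageLayeredAutoBlocks
#
#     pipe = QwenImageLayeredAutoBlocks().init_pipeline("Qwen/Qwen-Image-Layered")  # assumed repo id
#     pipe.load_components()
#     images = pipe(image=photo, layers=4, resolution=640, output="images")  # 4 extracted layers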

@@ -131,7 +131,7 @@ class ZImageLoopDenoiser(ModularPipelineBlocks):
        ),
        InputParam(
            kwargs_type="denoiser_input_fields",
            description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
            description="The conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
        ),
    ]
    guider_input_names = []

@@ -482,8 +482,6 @@ class ChromaInpaintPipeline(
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        pooled_prompt_embeds=None,
        negative_pooled_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
        padding_mask_crop=None,
        max_sequence_length=None,
@@ -531,15 +529,6 @@ class ChromaInpaintPipeline(
                f" {negative_prompt_embeds.shape}."
            )

        if prompt_embeds is not None and pooled_prompt_embeds is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
            )
        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
            raise ValueError(
                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
            )

        if prompt_embeds is not None and prompt_attention_mask is None:
            raise ValueError("Cannot provide `prompt_embeds` without also providing `prompt_attention_mask`")

@@ -793,13 +782,11 @@ class ChromaInpaintPipeline(
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        negative_ip_adapter_image: Optional[PipelineImageInput] = None,
        negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,

@@ -52,6 +52,15 @@ else:

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

DEFAULT_NEGATIVE_PROMPT = (
    "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
    "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
    "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
    "jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
    "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
    "Overall, the video is of poor quality."
)


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
@@ -359,7 +368,7 @@ class Cosmos2_5_PredictBasePipeline(DiffusionPipeline):
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
@@ -549,6 +558,7 @@ class Cosmos2_5_PredictBasePipeline(DiffusionPipeline):
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
        conditional_frame_timestep: float = 0.1,
        num_latent_conditional_frames: int = 2,
    ):
        r"""
        The call function to the pipeline for generation. Supports three modes:
@@ -614,6 +624,10 @@ class Cosmos2_5_PredictBasePipeline(DiffusionPipeline):
            max_sequence_length (`int`, defaults to `512`):
                The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If
                the prompt is shorter than this length, it will be padded.
            num_latent_conditional_frames (`int`, defaults to `2`):
                Number of latent conditional frames to use for Video2World conditioning. The number of pixel frames
                extracted from the input video is calculated as `4 * (num_latent_conditional_frames - 1) + 1`. Set to 1
                for Image2World-like behavior (single frame conditioning).

        Examples:

@@ -692,19 +706,38 @@ class Cosmos2_5_PredictBasePipeline(DiffusionPipeline):
            video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=torch.uint8)
            num_frames_in = 0
        else:
            num_frames_in = len(video)

            if batch_size != 1:
                raise ValueError(f"batch_size must be 1 for video input (given {batch_size})")

            if num_latent_conditional_frames not in [1, 2]:
                raise ValueError(
                    f"num_latent_conditional_frames must be 1 or 2, but got {num_latent_conditional_frames}"
                )

            frames_to_extract = 4 * (num_latent_conditional_frames - 1) + 1

            total_input_frames = len(video)

            if total_input_frames < frames_to_extract:
                raise ValueError(
                    f"Input video has only {total_input_frames} frames but Video2World requires at least "
                    f"{frames_to_extract} frames for conditioning."
                )

            num_frames_in = frames_to_extract

        assert video is not None
        video = self.video_processor.preprocess_video(video, height, width)

        # pad with last frame (for video2world)
        # For Video2World: extract last frames_to_extract frames from input, then pad
        if image is None and num_frames_in > 0 and num_frames_in < video.shape[2]:
            video = video[:, :, -num_frames_in:, :, :]

        num_frames_out = num_frames

        if video.shape[2] < num_frames_out:
            n_pad_frames = num_frames_out - num_frames_in
            last_frame = video[0, :, -1:, :, :]  # [C, T==1, H, W]
            n_pad_frames = num_frames_out - video.shape[2]
            last_frame = video[:, :, -1:, :, :]  # [B, C, T==1, H, W]
            pad_frames = last_frame.repeat(1, 1, n_pad_frames, 1, 1)  # [B, C, T, H, W]
            video = torch.cat((video, pad_frames), dim=2)

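The replacement lines above fix two bugs in the old padding code: `video[0, ...]` dropped the batch dimension, and the pad count was derived from `num_frames_in` instead of the tensor's actual length. The corrected arithmetic, replayed standalone:

import torch

def pad_conditioning_video(video: torch.Tensor, num_frames_out: int) -> torch.Tensor:
    """video: [B, C, T, H, W]; repeat the last frame until T == num_frames_out."""
    if video.shape[2] < num_frames_out:
        n_pad_frames = num_frames_out - video.shape[2]
        last_frame = video[:, :, -1:, :, :]  # keep the batch dim: [B, C, 1, H, W]
        pad_frames = last_frame.repeat(1, 1, n_pad_frames, 1, 1)
        video = torch.cat((video, pad_frames), dim=2)
    return video

# num_latent_conditional_frames=2 -> 4 * (2 - 1) + 1 = 5 pixel frames extracted;
# num_latent_conditional_frames=1 -> 1 frame (Image2World-like behavior).
frames_to_extract = 4 * (2 - 1) + 1
video = torch.zeros(1, 3, frames_to_extract, 32, 32)
print(pad_conditioning_video(video, num_frames_out=17).shape)  # torch.Size([1, 3, 17, 32, 32])
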
@@ -49,6 +49,14 @@ else:

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

DEFAULT_NEGATIVE_PROMPT = (
    "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
    "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
    "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
    "jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
    "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
    "Overall, the video is of poor quality."
)

EXAMPLE_DOC_STRING = """
Examples:
@@ -300,7 +308,7 @@ class Cosmos2TextToImagePipeline(DiffusionPipeline):
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):

@@ -50,6 +50,14 @@ else:

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

DEFAULT_NEGATIVE_PROMPT = (
    "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
    "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
    "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
    "jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
    "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
    "Overall, the video is of poor quality."
)

EXAMPLE_DOC_STRING = """
Examples:
@@ -319,7 +327,7 @@ class Cosmos2VideoToWorldPipeline(DiffusionPipeline):
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):

@@ -49,6 +49,14 @@ else:

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

DEFAULT_NEGATIVE_PROMPT = (
    "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
    "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
    "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
    "jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
    "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
    "Overall, the video is of poor quality."
)

EXAMPLE_DOC_STRING = """
Examples:
@@ -285,7 +293,7 @@ class CosmosTextToWorldPipeline(DiffusionPipeline):
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):

@@ -50,6 +50,14 @@ else:

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

DEFAULT_NEGATIVE_PROMPT = (
    "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
    "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
    "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
    "jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
    "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
    "Overall, the video is of poor quality."
)

EXAMPLE_DOC_STRING = """
Examples:
@@ -331,7 +339,7 @@ class CosmosVideoToWorldPipeline(DiffusionPipeline):
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):

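All four Cosmos pipelines receive the same one-line change. The before/after semantics, replayed in isolation:

DEFAULT_NEGATIVE_PROMPT = "The video captures a series of frames showing ugly scenes..."  # abbreviated

def resolve_negative_prompt(negative_prompt):
    # Old: `negative_prompt or ""` collapsed both None and "" to "".
    # New: only None falls back to the quality-focused default; an explicit
    # empty string from the caller is respected.
    return negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT

assert resolve_negative_prompt(None) == DEFAULT_NEGATIVE_PROMPT
assert resolve_negative_prompt("") == ""
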
@@ -260,25 +260,115 @@ class GlmImagePipeline(DiffusionPipeline):
        token_ids = token_ids.view(1, -1)
        return token_ids

    @staticmethod
    def _validate_and_normalize_images(
        image: Optional[Union[List[PIL.Image.Image], List[List[PIL.Image.Image]]]],
        batch_size: int,
    ) -> Optional[List[List[PIL.Image.Image]]]:
        """
        Validate and normalize image inputs to List[List[PIL.Image]].

        Rules:
        - batch_size > 1: Only accepts List[List[PIL.Image]], each sublist must have equal length
        - batch_size == 1: Accepts List[PIL.Image] for legacy compatibility (converted to [[img1, img2, ...]])
        - Other formats raise ValueError

        Args:
            image: Input images in various formats
            batch_size: Number of prompts in the batch

        Returns:
            Normalized images as List[List[PIL.Image]], or None if no images provided
        """
        if image is None or len(image) == 0:
            return None

        first_element = image[0]

        if batch_size == 1:
            # Legacy format: List[PIL.Image] -> [[img1, img2, ...]]
            if not isinstance(first_element, (list, tuple)):
                return [list(image)]
            # Already in List[List[PIL.Image]] format
            if len(image) != 1:
                raise ValueError(
                    f"For batch_size=1 with List[List[PIL.Image]] format, expected 1 image list, got {len(image)}."
                )
            return [list(image[0])]

        # batch_size > 1: must be List[List[PIL.Image]]
        if not isinstance(first_element, (list, tuple)):
            raise ValueError(
                f"For batch_size > 1, images must be List[List[PIL.Image]] format. "
                f"Got List[{type(first_element).__name__}] instead. "
                f"Each prompt requires its own list of condition images."
            )

        if len(image) != batch_size:
            raise ValueError(f"Number of image lists ({len(image)}) must match batch size ({batch_size}).")

        # Validate homogeneous: all sublists must have same length
        num_input_images_per_prompt = len(image[0])
        for idx, imgs in enumerate(image):
            if len(imgs) != num_input_images_per_prompt:
                raise ValueError(
                    f"All prompts must have the same number of condition images. "
                    f"Prompt 0 has {num_input_images_per_prompt} images, but prompt {idx} has {len(imgs)} images."
                )

        return [list(imgs) for imgs in image]

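A short usage sketch of the new normalization helper. The import path is an assumption (it presumes GlmImagePipeline is exported at the diffusers top level); the staticmethod itself runs without loading any weights:

from PIL import Image
from diffusers import GlmImagePipeline  # assumption: exported at the top level

img = Image.new("RGB", (64, 64))

# The legacy flat list is still accepted when batch_size == 1 and gets wrapped:
assert GlmImagePipeline._validate_and_normalize_images([img, img], batch_size=1) == [[img, img]]

# batch_size > 1 requires one equally sized sublist of condition images per prompt:
normalized = GlmImagePipeline._validate_and_normalize_images([[img], [img]], batch_size=2)
assert len(normalized) == 2

# A flat list with batch_size > 1, or ragged sublists, raises ValueError.
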
    def generate_prior_tokens(
        self,
        prompt: str,
        prompt: Union[str, List[str]],
        height: int,
        width: int,
        image: Optional[List[PIL.Image.Image]] = None,
        image: Optional[List[List[PIL.Image.Image]]] = None,
        device: Optional[torch.device] = None,
        generator: Optional[torch.Generator] = None,
    ):
        """
        Generate prior tokens for the DiT model using the AR model.

        Args:
            prompt: Single prompt or list of prompts
            height: Target image height
            width: Target image width
            image: Normalized image input as List[List[PIL.Image]]. Should be pre-validated
                using _validate_and_normalize_images() before calling this method.
            device: Target device
            generator: Random generator for reproducibility

        Returns:
            Tuple of:
            - prior_token_ids: Tensor of shape (batch_size, num_tokens) with upsampled prior tokens
            - prior_token_image_ids_per_sample: List of tensors, one per sample. Each tensor contains
              the upsampled prior token ids for all condition images in that sample. None for t2i.
            - source_image_grid_thw_per_sample: List of tensors, one per sample. Each tensor has shape
              (num_condition_images, 3) with upsampled grid info. None for t2i.
        """
        device = device or self._execution_device
        is_text_to_image = image is None or len(image) == 0
        content = []
        if image is not None:
            for img in image:
                content.append({"type": "image", "image": img})
        content.append({"type": "text", "text": prompt})
        messages = [{"role": "user", "content": content}]

        # Normalize prompt to list format
        prompt_list = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt_list)

        # Image is already normalized by _validate_and_normalize_images(): None or List[List[PIL.Image]]
        is_text_to_image = image is None
        # Build messages for each sample in the batch
        all_messages = []
        for idx, p in enumerate(prompt_list):
            content = []
            if not is_text_to_image:
                for img in image[idx]:
                    content.append({"type": "image", "image": img})
            content.append({"type": "text", "text": p})
            all_messages.append([{"role": "user", "content": content}])
        # Process with the processor (supports batch with left padding)
        inputs = self.processor.apply_chat_template(
            messages,
            all_messages,
            tokenize=True,
            padding=True if batch_size > 1 else False,
            target_h=height,
            target_w=width,
            return_dict=True,
@@ -286,44 +376,117 @@ class GlmImagePipeline(DiffusionPipeline):
        ).to(device)

        image_grid_thw = inputs.get("image_grid_thw")
        images_per_sample = inputs.get("images_per_sample")

        # Determine number of condition images and grids per sample
        num_condition_images = 0 if is_text_to_image else len(image[0])
        if images_per_sample is not None:
            num_grids_per_sample = images_per_sample[0].item()
        else:
            # Fallback for batch_size=1: total grids is for single sample
            num_grids_per_sample = image_grid_thw.shape[0]

        # Compute generation params (same for all samples in homogeneous batch)
        first_sample_grids = image_grid_thw[:num_grids_per_sample]
        max_new_tokens, large_image_offset, token_h, token_w = self._compute_generation_params(
            image_grid_thw=image_grid_thw, is_text_to_image=is_text_to_image
            image_grid_thw=first_sample_grids, is_text_to_image=is_text_to_image
        )

        # Generate source image tokens (prior_token_image_ids) for i2i mode
        prior_token_image_ids = None
        if image is not None:
            prior_token_image_embed = self.vision_language_encoder.get_image_features(
                inputs["pixel_values"], image_grid_thw[:-1]
            )
            prior_token_image_embed = torch.cat(prior_token_image_embed, dim=0)
            prior_token_image_ids = self.vision_language_encoder.get_image_tokens(
                prior_token_image_embed, image_grid_thw[:-1]
            )
        source_image_grid_thw = None
        if not is_text_to_image:
            # Extract source grids by selecting condition image indices (skip target grids)
            # Grid order from processor: [s0_cond1, s0_cond2, ..., s0_target, s1_cond1, s1_cond2, ..., s1_target, ...]
            # We need indices: [0, 1, ..., num_condition_images-1, num_grids_per_sample, num_grids_per_sample+1, ...]
            source_indices = []
            for sample_idx in range(batch_size):
                base = sample_idx * num_grids_per_sample
                source_indices.extend(range(base, base + num_condition_images))
            source_grids = image_grid_thw[source_indices]

            # For GLM-Image, greedy decoding is not allowed; it may cause repetitive outputs.
            # max_new_tokens must be exactly grid_h * grid_w + 1 (the +1 is for EOS).
            if len(source_grids) > 0:
                prior_token_image_embed = self.vision_language_encoder.get_image_features(
                    inputs["pixel_values"], source_grids, return_dict=False
                )
                prior_token_image_embed = torch.cat(prior_token_image_embed, dim=0)
                prior_token_image_ids_d32 = self.vision_language_encoder.get_image_tokens(
                    prior_token_image_embed, source_grids
                )
                # Upsample each source image's prior tokens to match VAE/DiT resolution
                split_sizes = source_grids.prod(dim=-1).tolist()
                prior_ids_per_source = torch.split(prior_token_image_ids_d32, split_sizes)
                upsampled_prior_ids = []
                for i, prior_ids in enumerate(prior_ids_per_source):
                    t, h, w = source_grids[i].tolist()
                    upsampled = self._upsample_token_ids(prior_ids, int(h), int(w))
                    upsampled_prior_ids.append(upsampled.squeeze(0))
                prior_token_image_ids = torch.cat(upsampled_prior_ids, dim=0)
                # Upsample grid dimensions for later splitting
                upsampled_grids = source_grids.clone()
                upsampled_grids[:, 1] = upsampled_grids[:, 1] * 2
                upsampled_grids[:, 2] = upsampled_grids[:, 2] * 2
                source_image_grid_thw = upsampled_grids

        # Generate with AR model
        # Set torch random seed from generator for reproducibility
        # (transformers generate() doesn't accept generator parameter)
        if generator is not None:
            seed = generator.initial_seed()
            torch.manual_seed(seed)
            if device is not None and device.type == "cuda":
                torch.cuda.manual_seed(seed)
        outputs = self.vision_language_encoder.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
        )

        prior_token_ids_d32 = self._extract_large_image_tokens(
            outputs, inputs["input_ids"].shape[-1], large_image_offset, token_h * token_w
        )
        prior_token_ids = self._upsample_token_ids(prior_token_ids_d32, token_h, token_w)
        # Extract and upsample prior tokens for each sample
        # For left-padded inputs, generated tokens start after the padded input sequence
        all_prior_token_ids = []
        max_input_length = inputs["input_ids"].shape[-1]
        for idx in range(batch_size):
            # For left-padded sequences, generated tokens start at max_input_length
            # (padding is on the left, so all sequences end at the same position)
            prior_token_ids_d32 = self._extract_large_image_tokens(
                outputs[idx : idx + 1], max_input_length, large_image_offset, token_h * token_w
            )
            prior_token_ids = self._upsample_token_ids(prior_token_ids_d32, token_h, token_w)
            all_prior_token_ids.append(prior_token_ids)
        prior_token_ids = torch.cat(all_prior_token_ids, dim=0)

        return prior_token_ids, prior_token_image_ids
        # Split prior_token_image_ids and source_image_grid_thw into per-sample lists for easier consumption
        prior_token_image_ids_per_sample = None
        source_image_grid_thw_per_sample = None
        if prior_token_image_ids is not None and source_image_grid_thw is not None:
            # Split grids: each sample has num_condition_images grids
            source_image_grid_thw_per_sample = list(torch.split(source_image_grid_thw, num_condition_images))
            # Split prior_token_image_ids: tokens per sample may vary due to different image sizes
            tokens_per_image = source_image_grid_thw.prod(dim=-1).tolist()
            tokens_per_sample = []
            for i in range(batch_size):
                start_idx = i * num_condition_images
                end_idx = start_idx + num_condition_images
                tokens_per_sample.append(sum(tokens_per_image[start_idx:end_idx]))
            prior_token_image_ids_per_sample = list(torch.split(prior_token_image_ids, tokens_per_sample))

        return prior_token_ids, prior_token_image_ids_per_sample, source_image_grid_thw_per_sample

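Because transformers' generate() accepts no torch.Generator, the code above replays the generator's seed into the global RNG state before sampling the AR model. In isolation:

import torch

generator = torch.Generator().manual_seed(42)
seed = generator.initial_seed()       # 42: the seed the caller configured
torch.manual_seed(seed)               # reseed the CPU RNG
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)      # and the CUDA RNG, if relevant
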
    def get_glyph_texts(self, prompt):
        prompt = prompt[0] if isinstance(prompt, list) else prompt
        ocr_texts = (
            re.findall(r"'([^']*)'", prompt)
            + re.findall(r"“([^“”]*)”", prompt)
            + re.findall(r'"([^"]*)"', prompt)
            + re.findall(r"「([^「」]*)」", prompt)
        )
        return ocr_texts
        """Extract glyph texts from prompt(s). Returns a list of lists for batch processing."""
        if isinstance(prompt, str):
            prompt = [prompt]
        all_ocr_texts = []
        for p in prompt:
            ocr_texts = (
                re.findall(r"'([^']*)'", p)
                + re.findall(r"\u201c([^\u201c\u201d]*)\u201d", p)
                + re.findall(r'"([^"]*)"', p)
                + re.findall(r"「([^「」]*)」", p)
            )
            all_ocr_texts.append(ocr_texts)
        return all_ocr_texts

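The rewrite keeps the same four quoting patterns but applies them per prompt instead of only to the first. A standalone replay of the extraction, using the same regexes:

import re

def extract_glyph_texts(prompts):
    # Glyph text is whatever sits in single, double, curly, or CJK corner
    # quotes, collected as one list per prompt.
    results = []
    for p in prompts:
        results.append(
            re.findall(r"'([^']*)'", p)
            + re.findall(r"\u201c([^\u201c\u201d]*)\u201d", p)
            + re.findall(r'"([^"]*)"', p)
            + re.findall(r"「([^「」]*)」", p)
        )
    return results

texts = extract_glyph_texts(['A sign saying "OPEN"', "A mug with 'hello' on it"])
assert texts == [["OPEN"], ["hello"]]
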
    def _get_glyph_embeds(
        self,
@@ -332,29 +495,51 @@ class GlmImagePipeline(DiffusionPipeline):
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Get glyph embeddings for each prompt in the batch."""
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        glyph_texts = self.get_glyph_texts(prompt)
        input_ids = self.tokenizer(
            glyph_texts if len(glyph_texts) > 0 else [""],
            max_length=max_sequence_length,
            truncation=True,
        ).input_ids
        input_ids = [
            [self.tokenizer.pad_token_id] * ((len(input_ids) + 1) % 2) + input_ids_ for input_ids_ in input_ids
        ]
        max_length = max(len(input_ids_) for input_ids_ in input_ids)
        attention_mask = torch.tensor(
            [[1] * len(input_ids_) + [0] * (max_length - len(input_ids_)) for input_ids_ in input_ids], device=device
        )
        input_ids = torch.tensor(
            [input_ids_ + [self.tokenizer.pad_token_id] * (max_length - len(input_ids_)) for input_ids_ in input_ids],
            device=device,
        )
        outputs = self.text_encoder(input_ids, attention_mask=attention_mask)
        glyph_embeds = outputs.last_hidden_state[attention_mask.bool()].unsqueeze(0)
        # get_glyph_texts now returns a list of lists (one per prompt)
        all_glyph_texts = self.get_glyph_texts(prompt)

        all_glyph_embeds = []
        for glyph_texts in all_glyph_texts:
            if len(glyph_texts) == 0:
                glyph_texts = [""]
            input_ids = self.tokenizer(
                glyph_texts,
                max_length=max_sequence_length,
                truncation=True,
            ).input_ids
            input_ids = [
                [self.tokenizer.pad_token_id] * ((len(input_ids) + 1) % 2) + input_ids_ for input_ids_ in input_ids
            ]
            max_length = max(len(input_ids_) for input_ids_ in input_ids)
            attention_mask = torch.tensor(
                [[1] * len(input_ids_) + [0] * (max_length - len(input_ids_)) for input_ids_ in input_ids],
                device=device,
            )
            input_ids = torch.tensor(
                [
                    input_ids_ + [self.tokenizer.pad_token_id] * (max_length - len(input_ids_))
                    for input_ids_ in input_ids
                ],
                device=device,
            )
            outputs = self.text_encoder(input_ids, attention_mask=attention_mask)
            glyph_embeds = outputs.last_hidden_state[attention_mask.bool()].unsqueeze(0)
            all_glyph_embeds.append(glyph_embeds)

        # Pad to same sequence length and stack (use left padding to match transformers)
        max_seq_len = max(emb.size(1) for emb in all_glyph_embeds)
        padded_embeds = []
        for emb in all_glyph_embeds:
            if emb.size(1) < max_seq_len:
                pad = torch.zeros(emb.size(0), max_seq_len - emb.size(1), emb.size(2), device=device, dtype=emb.dtype)
                emb = torch.cat([pad, emb], dim=1)  # left padding
            padded_embeds.append(emb)

        glyph_embeds = torch.cat(padded_embeds, dim=0)
        return glyph_embeds.to(device=device, dtype=dtype)

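The batched glyph embeddings are left-padded with zeros so every sample shares one sequence length, matching transformers' left-padding convention. Replayed standalone:

import torch

all_glyph_embeds = [torch.ones(1, 3, 8), torch.ones(1, 5, 8)]
max_seq_len = max(emb.size(1) for emb in all_glyph_embeds)
padded_embeds = []
for emb in all_glyph_embeds:
    if emb.size(1) < max_seq_len:
        pad = torch.zeros(emb.size(0), max_seq_len - emb.size(1), emb.size(2), dtype=emb.dtype)
        emb = torch.cat([pad, emb], dim=1)  # pad on the left
    padded_embeds.append(emb)
glyph_embeds = torch.cat(padded_embeds, dim=0)
assert glyph_embeds.shape == (2, 5, 8)
assert glyph_embeds[0, :2].abs().sum() == 0  # the padding sits at the front
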
    def encode_prompt(
@@ -399,9 +584,9 @@ class GlmImagePipeline(DiffusionPipeline):
        if prompt_embeds is None:
            prompt_embeds = self._get_glyph_embeds(prompt, max_sequence_length, device, dtype)

        seq_len = prompt_embeds.size(1)
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
        # Repeat embeddings for num_images_per_prompt
        if num_images_per_prompt > 1:
            prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)

        # For GLM-Image, negative_prompt must be "" instead of None
        if do_classifier_free_guidance and negative_prompt_embeds is None:
@@ -409,9 +594,8 @@ class GlmImagePipeline(DiffusionPipeline):
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
            negative_prompt_embeds = self._get_glyph_embeds(negative_prompt, max_sequence_length, device, dtype)

            seq_len = negative_prompt_embeds.size(1)
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
            if num_images_per_prompt > 1:
                negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)

        return prompt_embeds, negative_prompt_embeds

@@ -442,7 +626,9 @@ class GlmImagePipeline(DiffusionPipeline):
        prompt_embeds=None,
        negative_prompt_embeds=None,
        prior_token_ids=None,
        prior_image_token_ids=None,
        prior_token_image_ids=None,
        source_image_grid_thw=None,
        image=None,
    ):
        if (
            height is not None
@@ -488,12 +674,24 @@ class GlmImagePipeline(DiffusionPipeline):
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )
        if (prior_token_ids is None and prior_image_token_ids is not None) or (
            prior_token_ids is not None and prior_image_token_ids is None
        ):
        # Validate prior token inputs: for i2i mode, all three must be provided together
        # For t2i mode, only prior_token_ids is needed (prior_token_image_ids and source_image_grid_thw should be None)
        prior_image_inputs = [prior_token_image_ids, source_image_grid_thw]
        num_prior_image_inputs = sum(x is not None for x in prior_image_inputs)
        if num_prior_image_inputs > 0 and num_prior_image_inputs < len(prior_image_inputs):
            raise ValueError(
                f"Cannot forward only one `prior_token_ids`: {prior_token_ids} or `prior_image_token_ids`:"
                f" {prior_image_token_ids} provided. Please make sure both are provided or neither."
                "`prior_token_image_ids` and `source_image_grid_thw` must be provided together for i2i mode. "
                f"Got prior_token_image_ids={prior_token_image_ids is not None}, "
                f"source_image_grid_thw={source_image_grid_thw is not None}."
            )
        if num_prior_image_inputs > 0 and prior_token_ids is None:
            raise ValueError(
                "`prior_token_ids` must be provided when `prior_token_image_ids` and `source_image_grid_thw` are provided."
            )
        if num_prior_image_inputs > 0 and image is None:
            raise ValueError(
                "`image` must be provided when `prior_token_image_ids` and `source_image_grid_thw` are provided "
                "for i2i mode, as the images are needed for VAE encoding to build the KV cache."
            )

        if prior_token_ids is not None and prompt_embeds is None:
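The encode_prompt hunks swap the repeat-then-view idiom for a single repeat_interleave call. Both produce the same batch layout, which a quick check confirms:

import torch

emb = torch.arange(12.0).reshape(2, 2, 3)     # [batch=2, seq=2, dim=3]
n = 3
old = emb.repeat(1, n, 1).view(2 * n, 2, -1)  # old: tile along seq, then reshape
new = emb.repeat_interleave(n, dim=0)         # new: one call, no reshape
assert torch.equal(old, new)
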
@@ -545,7 +743,8 @@ class GlmImagePipeline(DiffusionPipeline):
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        prior_token_ids: Optional[torch.FloatTensor] = None,
        prior_image_token_ids: Optional[torch.Tensor] = None,
        prior_token_image_ids: Optional[List[torch.Tensor]] = None,
        source_image_grid_thw: Optional[List[torch.Tensor]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        output_type: str = "pil",
        return_dict: bool = True,
@@ -598,7 +797,9 @@ class GlmImagePipeline(DiffusionPipeline):
            prompt_embeds,
            negative_prompt_embeds,
            prior_token_ids,
            prior_image_token_ids,
            prior_token_image_ids,
            source_image_grid_thw,
            image,
        )
        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
@@ -611,34 +812,47 @@ class GlmImagePipeline(DiffusionPipeline):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]
        if batch_size != 1:
            raise ValueError(f"batch_size must be 1 due to AR model limitations, got {batch_size}")

        device = self._execution_device

        # 2. Preprocess image tokens and prompt tokens
        if prior_token_ids is None:
            prior_token_ids, prior_token_image_ids = self.generate_prior_tokens(
                prompt=prompt[0] if isinstance(prompt, list) else prompt,
                image=image,
                height=height,
                width=width,
                device=device,
            )
        # 2. Validate and normalize image format
        normalized_image = self._validate_and_normalize_images(image, batch_size)

        # 3. Preprocess image
        if image is not None:
            preprocessed_condition_images = []
            for img in image:
                image_height, image_width = img.size[::-1] if isinstance(img, PIL.Image.Image) else img.shape[:2]
                multiple_of = self.vae_scale_factor * self.transformer.config.patch_size
                image_height = (image_height // multiple_of) * multiple_of
                image_width = (image_width // multiple_of) * multiple_of
                img = self.image_processor.preprocess(img, height=image_height, width=image_width)
                preprocessed_condition_images.append(img)
            height = height or image_height
            width = width or image_width
            image = preprocessed_condition_images
        # 3. Generate prior tokens (batch mode)
        # Get a single generator for AR model (use first if list provided)
        ar_generator = generator[0] if isinstance(generator, list) else generator
        if prior_token_ids is None:
            prior_token_ids, prior_token_image_ids_per_sample, source_image_grid_thw_per_sample = (
                self.generate_prior_tokens(
                    prompt=prompt,
                    image=normalized_image,
                    height=height,
                    width=width,
                    device=device,
                    generator=ar_generator,
                )
            )
        else:
            # User provided prior_token_ids directly (from generate_prior_tokens)
            prior_token_image_ids_per_sample = prior_token_image_ids
            source_image_grid_thw_per_sample = source_image_grid_thw

        # 4. Preprocess images for VAE encoding
        preprocessed_images = None
        if normalized_image is not None:
            preprocessed_images = []
            for prompt_images in normalized_image:
                prompt_preprocessed = []
                for img in prompt_images:
                    image_height, image_width = img.size[::-1] if isinstance(img, PIL.Image.Image) else img.shape[:2]
                    multiple_of = self.vae_scale_factor * self.transformer.config.patch_size
                    image_height = (image_height // multiple_of) * multiple_of
                    image_width = (image_width // multiple_of) * multiple_of
                    img = self.image_processor.preprocess(img, height=image_height, width=image_width)
                    prompt_preprocessed.append(img)
                height = height or image_height
                width = width or image_width
                preprocessed_images.append(prompt_preprocessed)

        # 5. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
@@ -652,7 +866,7 @@ class GlmImagePipeline(DiffusionPipeline):
            dtype=self.dtype,
        )

        # 4. Prepare latents and (optional) image kv cache
        # 6. Prepare latents and (optional) image kv cache
        latent_channels = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size=batch_size * num_images_per_prompt,
@@ -666,7 +880,7 @@ class GlmImagePipeline(DiffusionPipeline):
        )
        kv_caches = GlmImageKVCache(num_layers=self.transformer.config.num_layers)

        if image is not None:
        if normalized_image is not None:
            kv_caches.set_mode("write")
            latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.latent_channels, 1, 1)
            latents_std = torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.latent_channels, 1, 1)
@@ -674,29 +888,38 @@ class GlmImagePipeline(DiffusionPipeline):
            latents_mean = latents_mean.to(device=device, dtype=prompt_embeds.dtype)
            latents_std = latents_std.to(device=device, dtype=prompt_embeds.dtype)

            for condition_image, condition_image_prior_token_id in zip(image, prior_token_image_ids):
                condition_image = condition_image.to(device=device, dtype=prompt_embeds.dtype)
                condition_latent = retrieve_latents(
                    self.vae.encode(condition_image), generator=generator, sample_mode="argmax"
                )
                condition_latent = (condition_latent - latents_mean) / latents_std
            # Process each sample's condition images
            for prompt_idx in range(batch_size):
                prompt_images = preprocessed_images[prompt_idx]
                prompt_prior_ids = prior_token_image_ids_per_sample[prompt_idx]
                prompt_grid_thw = source_image_grid_thw_per_sample[prompt_idx]

                # Do not remove.
                # It would be used to run the reference image through a
                # forward pass at timestep 0 and keep the KV cache.
                _ = self.transformer(
                    hidden_states=condition_latent,
                    encoder_hidden_states=torch.zeros_like(prompt_embeds)[:1, :0, ...],
                    prior_token_id=condition_image_prior_token_id,
                    prior_token_drop=torch.full_like(condition_image_prior_token_id, False, dtype=torch.bool),
                    timestep=torch.zeros((1,), device=device),
                    target_size=torch.tensor([condition_image.shape[-2:]], device=device),
                    crop_coords=torch.zeros((1, 2), device=device),
                    attention_kwargs=attention_kwargs,
                    kv_caches=kv_caches,
                )
                # Split this sample's prior_token_image_ids by each image's token count
                split_sizes = prompt_grid_thw.prod(dim=-1).tolist()
                prior_ids_per_image = torch.split(prompt_prior_ids, split_sizes)
                # Process each condition image for this sample
                for condition_image, condition_image_prior_token_id in zip(prompt_images, prior_ids_per_image):
                    condition_image = condition_image.to(device=device, dtype=prompt_embeds.dtype)
                    condition_latent = retrieve_latents(
                        self.vae.encode(condition_image), generator=generator, sample_mode="argmax"
                    )
                    condition_latent = (condition_latent - latents_mean) / latents_std

        # 6. Prepare additional timestep conditions
                    _ = self.transformer(
                        hidden_states=condition_latent,
                        encoder_hidden_states=torch.zeros_like(prompt_embeds)[:1, :0, ...],
                        prior_token_id=condition_image_prior_token_id,
                        prior_token_drop=torch.full_like(condition_image_prior_token_id, False, dtype=torch.bool),
                        timestep=torch.zeros((1,), device=device),
                        target_size=torch.tensor([condition_image.shape[-2:]], device=device),
                        crop_coords=torch.zeros((1, 2), device=device),
                        attention_kwargs=attention_kwargs,
                        kv_caches=kv_caches,
                    )
                # Move to next sample's cache slot
                kv_caches.next_sample()

        # 7. Prepare additional timestep conditions
        target_size = (height, width)
        target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device)
        crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device)
@@ -726,10 +949,13 @@ class GlmImagePipeline(DiffusionPipeline):
        )
        self._num_timesteps = len(timesteps)

        # 7. Denoising loop
        # 8. Denoising loop
        transformer_dtype = self.transformer.dtype
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        # Repeat prior_token_ids for num_images_per_prompt
        if num_images_per_prompt > 1:
            prior_token_ids = prior_token_ids.repeat_interleave(num_images_per_prompt, dim=0)
        prior_token_drop_cond = torch.full_like(prior_token_ids, False, dtype=torch.bool)
        prior_token_drop_uncond = torch.full_like(prior_token_ids, True, dtype=torch.bool)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -742,7 +968,7 @@ class GlmImagePipeline(DiffusionPipeline):

                timestep = t.expand(latents.shape[0]) - 1

                if image is not None:
                if prior_token_image_ids_per_sample is not None:
                    kv_caches.set_mode("read")

                noise_pred_cond = self.transformer(
@@ -760,7 +986,7 @@ class GlmImagePipeline(DiffusionPipeline):

                # perform guidance
                if self.do_classifier_free_guidance:
                    if image is not None:
                    if prior_token_image_ids_per_sample is not None:
                        kv_caches.set_mode("skip")
                    noise_pred_uncond = self.transformer(
                        hidden_states=latent_model_input,

|
||||
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
|
||||
|
||||
prompt_embeds = prompt_embeds[:, :max_sequence_length]
|
||||
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask is not None:
|
||||
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask.all():
|
||||
prompt_embeds_mask = None
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
@@ -307,15 +311,6 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and prompt_embeds_mask is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
)
|
||||
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
|
||||
raise ValueError(
|
||||
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
if max_sequence_length is not None and max_sequence_length > 1024:
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
|
||||
|
||||
|
||||
@@ -321,8 +321,13 @@ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask is not None:
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask.all():
|
||||
prompt_embeds_mask = None
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
@@ -369,15 +374,6 @@ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and prompt_embeds_mask is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
)
|
||||
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
|
||||
raise ValueError(
|
||||
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
if max_sequence_length is not None and max_sequence_length > 1024:
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
|
||||
|
||||
|
||||
@@ -305,6 +305,9 @@ class QwenImageControlNetInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderM
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
|
||||
prompt_embeds_mask = None
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
def check_inputs(
|
||||
|
||||
@@ -309,6 +309,9 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
|
||||
prompt_embeds_mask = None
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
def check_inputs(
|
||||
|
||||
@@ -321,6 +321,9 @@ class QwenImageEditInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
|
||||
prompt_embeds_mask = None
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_inpaint.QwenImageInpaintPipeline.check_inputs
|
||||
@@ -375,14 +378,6 @@ class QwenImageEditInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and prompt_embeds_mask is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
)
|
||||
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
|
||||
raise ValueError(
|
||||
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
if padding_mask_crop is not None:
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
|
||||
@@ -323,6 +323,9 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
|
||||
prompt_embeds_mask = None
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.check_inputs
|
||||
|
||||
@@ -265,7 +265,7 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
# Copied fromCopied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
@@ -297,13 +297,17 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
|
||||
|
||||
prompt_embeds = prompt_embeds[:, :max_sequence_length]
|
||||
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask is not None:
|
||||
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask.all():
|
||||
prompt_embeds_mask = None
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
@@ -354,15 +358,6 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and prompt_embeds_mask is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
)
|
||||
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
|
||||
raise ValueError(
|
||||
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
if max_sequence_length is not None and max_sequence_length > 1024:
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
|
||||
|
||||
|
||||
@@ -276,7 +276,7 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
# Copied fromCopied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
@@ -308,13 +308,17 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
|
||||
|
||||
prompt_embeds = prompt_embeds[:, :max_sequence_length]
|
||||
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask is not None:
|
||||
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask.all():
|
||||
prompt_embeds_mask = None
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
@@ -369,14 +373,6 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and prompt_embeds_mask is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
)
|
||||
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
|
||||
raise ValueError(
|
||||
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
if padding_mask_crop is not None:
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
|
||||
@@ -320,13 +320,17 @@ the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>as
|
||||
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
|
||||
|
||||
prompt_embeds = prompt_embeds[:, :max_sequence_length]
|
||||
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask is not None:
|
||||
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask.all():
|
||||
prompt_embeds_mask = None
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
|
||||
@@ -623,7 +623,7 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
"""
|
||||
|
||||
if is_torchao_available():
|
||||
# TODO(aryan): Support autoquant and sparsify
|
||||
# TODO(aryan): Support sparsify
|
||||
from torchao.quantization import (
|
||||
float8_dynamic_activation_float8_weight,
|
||||
float8_static_activation_float8_weight,
|
||||
|
||||
@@ -344,7 +344,6 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
|
||||
from torchao.core.config import AOBaseConfig
|
||||
|
||||
quant_type = self.quantization_config.quant_type
|
||||
# For autoquant case, it will be treated in the string implementation below in map_to_target_dtype
|
||||
if isinstance(quant_type, AOBaseConfig):
|
||||
# Extract size digit using fuzzy match on the class name
|
||||
config_name = quant_type.__class__.__name__
|
||||
|
||||
@@ -281,7 +281,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
Args:
|
||||
num_inference_steps (`int`):
|
||||
num_inference_steps (`int`, *optional*):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used,
|
||||
`timesteps` must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
@@ -646,7 +646,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
def __len__(self) -> int:
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
def previous_timestep(self, timestep: int) -> int:
|
||||
def previous_timestep(self, timestep: int) -> Union[int, torch.Tensor]:
|
||||
"""
|
||||
Compute the previous timestep in the diffusion chain.
|
||||
|
||||
@@ -655,7 +655,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
The current timestep.
|
||||
|
||||
Returns:
|
||||
`int`:
|
||||
`int` or `torch.Tensor`:
|
||||
The previous timestep.
|
||||
"""
|
||||
if self.custom_timesteps or self.num_inference_steps:
|
||||
|
||||
@@ -149,38 +149,41 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
For more details, see the original paper: https://huggingface.co/papers/2006.11239
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`): number of diffusion steps used to train the model.
|
||||
beta_start (`float`): the starting `beta` value of inference.
|
||||
beta_end (`float`): the final `beta` value.
|
||||
beta_schedule (`str`):
|
||||
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
||||
num_train_timesteps (`int`, defaults to 1000):
|
||||
The number of diffusion steps to train the model.
|
||||
beta_start (`float`, defaults to 0.0001):
|
||||
The starting `beta` value of inference.
|
||||
beta_end (`float`, defaults to 0.02):
|
||||
The final `beta` value.
|
||||
beta_schedule (`str`, defaults to `"linear"`):
|
||||
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
||||
`linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`.
|
||||
trained_betas (`np.ndarray`, optional):
|
||||
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
||||
variance_type (`str`):
|
||||
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
|
||||
trained_betas (`np.ndarray`, *optional*):
|
||||
Option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
||||
variance_type (`str`, defaults to `"fixed_small"`):
|
||||
Options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
|
||||
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
|
||||
clip_sample (`bool`, default `True`):
|
||||
option to clip predicted sample for numerical stability.
|
||||
clip_sample_range (`float`, default `1.0`):
|
||||
the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
|
||||
prediction_type (`str`, default `epsilon`, optional):
|
||||
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
|
||||
clip_sample (`bool`, defaults to `True`):
|
||||
Option to clip predicted sample for numerical stability.
|
||||
prediction_type (`str`, defaults to `"epsilon"`):
|
||||
Prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
|
||||
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
|
||||
https://huggingface.co/papers/2210.02303)
|
||||
thresholding (`bool`, default `False`):
|
||||
whether to use the "dynamic thresholding" method (introduced by Imagen,
|
||||
thresholding (`bool`, defaults to `False`):
|
||||
Whether to use the "dynamic thresholding" method (introduced by Imagen,
|
||||
https://huggingface.co/papers/2205.11487). Note that the thresholding method is unsuitable for latent-space
|
||||
diffusion models (such as stable-diffusion).
|
||||
dynamic_thresholding_ratio (`float`, default `0.995`):
|
||||
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
|
||||
dynamic_thresholding_ratio (`float`, defaults to 0.995):
|
||||
The ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
|
||||
(https://huggingface.co/papers/2205.11487). Valid only when `thresholding=True`.
|
||||
sample_max_value (`float`, default `1.0`):
|
||||
the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
|
||||
timestep_spacing (`str`, default `"leading"`):
|
||||
clip_sample_range (`float`, defaults to 1.0):
|
||||
The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
|
||||
sample_max_value (`float`, defaults to 1.0):
|
||||
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
|
||||
timestep_spacing (`str`, defaults to `"leading"`):
|
||||
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
|
||||
Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
||||
steps_offset (`int`, default `0`):
|
||||
steps_offset (`int`, defaults to 0):
|
||||
An offset added to the inference steps, as required by some model families.
|
||||
rescale_betas_zero_snr (`bool`, defaults to `False`):
|
||||
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
||||
@@ -293,7 +296,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
Sets the discrete timesteps used for the diffusion chain (to be run before inference).

Args:
num_inference_steps (`int`):
num_inference_steps (`int`, *optional*):
The number of diffusion steps used when generating samples with a pre-trained model. If used,
`timesteps` must be `None`.
device (`str` or `torch.device`, *optional*):

@@ -478,7 +481,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
model_output: torch.Tensor,
timestep: int,
sample: torch.Tensor,
generator=None,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[DDPMParallelSchedulerOutput, Tuple]:
"""

@@ -490,7 +493,8 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
current instance of sample being created by diffusion process.
generator: random number generator.
generator (`torch.Generator`, *optional*):
Random number generator.
return_dict (`bool`): option for returning tuple rather than DDPMParallelSchedulerOutput class

Returns:
@@ -503,7 +507,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):

prev_t = self.previous_timestep(t)

if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [
"learned",
"learned_range",
]:
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
else:
predicted_variance = None

@@ -552,7 +559,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
if t > 0:
device = model_output.device
variance_noise = randn_tensor(
model_output.shape, generator=generator, device=device, dtype=model_output.dtype
model_output.shape,
generator=generator,
device=device,
dtype=model_output.dtype,
)
if self.variance_type == "fixed_small_log":
variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise

@@ -575,7 +585,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
def batch_step_no_noise(
self,
model_output: torch.Tensor,
timesteps: List[int],
timesteps: torch.Tensor,
sample: torch.Tensor,
) -> torch.Tensor:
"""

@@ -588,8 +598,8 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):

Args:
model_output (`torch.Tensor`): direct output from learned diffusion model.
timesteps (`List[int]`):
current discrete timesteps in the diffusion chain. This is now a list of integers.
timesteps (`torch.Tensor`):
Current discrete timesteps in the diffusion chain. This is a tensor of integers.
sample (`torch.Tensor`):
current instance of sample being created by diffusion process.

@@ -603,7 +613,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
t = t.view(-1, *([1] * (model_output.ndim - 1)))
prev_t = prev_t.view(-1, *([1] * (model_output.ndim - 1)))

if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [
"learned",
"learned_range",
]:
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
else:
pass
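The `view(-1, *([1] * (model_output.ndim - 1)))` reshape above turns the batch of timesteps into a broadcast-friendly shape so per-timestep scalars line up against `(B, C, H, W)` tensors. A quick illustration:

import torch

t = torch.tensor([3, 7])                # one timestep per batch element
model_output = torch.randn(2, 4, 8, 8)  # hypothetical (B, C, H, W) output
t = t.view(-1, *([1] * (model_output.ndim - 1)))
print(t.shape)  # torch.Size([2, 1, 1, 1]) -- broadcasts cleanly against model_output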
@@ -734,7 +747,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
return self.config.num_train_timesteps

# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
def previous_timestep(self, timestep):
def previous_timestep(self, timestep: int) -> Union[int, torch.Tensor]:
"""
Compute the previous timestep in the diffusion chain.

@@ -743,7 +756,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
The current timestep.

Returns:
`int`:
`int` or `torch.Tensor`:
The previous timestep.
"""
if self.custom_timesteps or self.num_inference_steps:

@@ -105,7 +105,6 @@ def rescale_zero_terminal_snr(alphas_cumprod):
"""
Rescales betas to have zero terminal SNR based on https://huggingface.co/papers/2305.08891 (Algorithm 1)

Args:
betas (`torch.Tensor`):
the betas that the scheduler is being initialized with.
@@ -175,11 +174,14 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
timestep_spacing (`str`, defaults to `"leading"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. Choose from
`leading`, `linspace` or `trailing`.
rescale_betas_zero_snr (`bool`, defaults to `False`):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
snr_shift_scale (`float`, defaults to 3.0):
Shift scale for SNR.
"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]

@@ -191,15 +193,15 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
num_train_timesteps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.0120,
beta_schedule: str = "scaled_linear",
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "scaled_linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
prediction_type: str = "epsilon",
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
clip_sample_range: float = 1.0,
sample_max_value: float = 1.0,
timestep_spacing: str = "leading",
timestep_spacing: Literal["leading", "linspace", "trailing"] = "leading",
rescale_betas_zero_snr: bool = False,
snr_shift_scale: float = 3.0,
):

@@ -209,7 +211,15 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2
self.betas = (
torch.linspace(
beta_start**0.5,
beta_end**0.5,
num_train_timesteps,
dtype=torch.float64,
)
** 2
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
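The reformatted `scaled_linear` branch interpolates linearly in sqrt(beta) space and squares the result, which keeps the endpoints exact while bending the schedule toward smaller betas early on. A quick numeric check:

import torch

beta_start, beta_end, n = 0.00085, 0.0120, 1000
betas = torch.linspace(beta_start**0.5, beta_end**0.5, n, dtype=torch.float64) ** 2
assert torch.isclose(betas[0], torch.tensor(beta_start, dtype=torch.float64))
assert torch.isclose(betas[-1], torch.tensor(beta_end, dtype=torch.float64))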
@@ -266,13 +276,20 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
"""
return sample

def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
def set_timesteps(
self,
num_inference_steps: int,
device: Optional[Union[str, torch.device]] = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).

Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved. If `None` (the default), the timesteps are not
moved.
"""

if num_inference_steps > self.config.num_train_timesteps:
@@ -311,7 +328,27 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):

self.timesteps = torch.from_numpy(timesteps).to(device)

def get_variables(self, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back=None):
def get_variables(
self,
alpha_prod_t: torch.Tensor,
alpha_prod_t_prev: torch.Tensor,
alpha_prod_t_back: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
"""
Compute the variables used for DPM-Solver++ (2M) referencing the original implementation.

Args:
alpha_prod_t (`torch.Tensor`):
The cumulative product of alphas at the current timestep.
alpha_prod_t_prev (`torch.Tensor`):
The cumulative product of alphas at the previous timestep.
alpha_prod_t_back (`torch.Tensor`, *optional*):
The cumulative product of alphas at the timestep before the previous timestep.

Returns:
`tuple`:
A tuple containing the variables `h`, `r`, `lamb`, `lamb_next`.
"""
lamb = ((alpha_prod_t / (1 - alpha_prod_t)) ** 0.5).log()
lamb_next = ((alpha_prod_t_prev / (1 - alpha_prod_t_prev)) ** 0.5).log()
h = lamb_next - lamb
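`lamb` here is the half log-SNR: with alpha_prod_t the cumulative alpha product, lamb = log(sqrt(alpha_prod_t / (1 - alpha_prod_t))) = 0.5 * log(SNR_t), and `h` is the step size between adjacent timesteps in that lambda space, which is the coordinate DPM-Solver++ integrates in. A small sanity check of the identity:

import torch

alpha_prod_t = torch.tensor(0.9)
snr = alpha_prod_t / (1 - alpha_prod_t)
lamb = (snr**0.5).log()  # as computed in get_variables
assert torch.isclose(lamb, 0.5 * snr.log())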
@@ -324,7 +361,36 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
else:
return h, None, lamb, lamb_next

def get_mult(self, h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back):
def get_mult(
self,
h: torch.Tensor,
r: Optional[torch.Tensor],
alpha_prod_t: torch.Tensor,
alpha_prod_t_prev: torch.Tensor,
alpha_prod_t_back: Optional[torch.Tensor] = None,
) -> Union[
Tuple[torch.Tensor, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
]:
"""
Compute the multipliers for the previous sample and the predicted original sample.

Args:
h (`torch.Tensor`):
The log-SNR difference.
r (`torch.Tensor`, *optional*):
The ratio of log-SNR differences.
alpha_prod_t (`torch.Tensor`):
The cumulative product of alphas at the current timestep.
alpha_prod_t_prev (`torch.Tensor`):
The cumulative product of alphas at the previous timestep.
alpha_prod_t_back (`torch.Tensor`, *optional*):
The cumulative product of alphas at the timestep before the previous timestep.

Returns:
`tuple`:
A tuple containing the multipliers.
"""
mult1 = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 * (-h).exp()
mult2 = (-2 * h).expm1() * alpha_prod_t_prev**0.5
@@ -338,13 +404,13 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
old_pred_original_sample: torch.Tensor,
old_pred_original_sample: Optional[torch.Tensor],
timestep: int,
timestep_back: int,
sample: torch.Tensor,
eta: float = 0.0,
use_clipped_model_output: bool = False,
generator=None,
generator: Optional[torch.Generator] = None,
variance_noise: Optional[torch.Tensor] = None,
return_dict: bool = False,
) -> Union[DDIMSchedulerOutput, Tuple]:

@@ -355,8 +421,12 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float`):
old_pred_original_sample (`torch.Tensor`):
The predicted original sample from the previous timestep.
timestep (`int`):
The current discrete timestep in the diffusion chain.
timestep_back (`int`):
The timestep to look back to.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
eta (`float`):
@@ -436,7 +506,12 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
return prev_sample, pred_original_sample
else:
denoised_d = mult[2] * pred_original_sample - mult[3] * old_pred_original_sample
noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
noise = randn_tensor(
sample.shape,
generator=generator,
device=sample.device,
dtype=sample.dtype,
)
x_advanced = mult[0] * sample - mult[1] * denoised_d + mult_noise * noise

prev_sample = x_advanced

@@ -524,5 +599,5 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity

def __len__(self):
def __len__(self) -> int:
return self.config.num_train_timesteps
@@ -22,6 +22,7 @@ import jax
import jax.numpy as jnp

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,

@@ -31,6 +32,9 @@ from .scheduling_utils_flax import (
)

logger = logging.get_logger(__name__)

@flax.struct.dataclass
class DPMSolverMultistepSchedulerState:
common: CommonSchedulerState

@@ -171,6 +175,10 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
timestep_spacing: str = "linspace",
dtype: jnp.dtype = jnp.float32,
):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.dtype = dtype

def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolverMultistepSchedulerState:
@@ -203,7 +211,10 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
)

def set_timesteps(
self, state: DPMSolverMultistepSchedulerState, num_inference_steps: int, shape: Tuple
self,
state: DPMSolverMultistepSchedulerState,
num_inference_steps: int,
shape: Tuple,
) -> DPMSolverMultistepSchedulerState:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -301,10 +312,13 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
if self.config.thresholding:
# Dynamic thresholding in https://huggingface.co/papers/2205.11487
dynamic_max_val = jnp.percentile(
jnp.abs(x0_pred), self.config.dynamic_thresholding_ratio, axis=tuple(range(1, x0_pred.ndim))
jnp.abs(x0_pred),
self.config.dynamic_thresholding_ratio,
axis=tuple(range(1, x0_pred.ndim)),
)
dynamic_max_val = jnp.maximum(
dynamic_max_val, self.config.sample_max_value * jnp.ones_like(dynamic_max_val)
dynamic_max_val,
self.config.sample_max_value * jnp.ones_like(dynamic_max_val),
)
x0_pred = jnp.clip(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
return x0_pred
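Dynamic thresholding clamps the predicted x0 to a per-sample percentile of its absolute values and renormalizes, suppressing over-saturated pixels while leaving well-behaved samples on the plain clip. A minimal NumPy sketch of the idea (note that `np.percentile` takes 0-100, so a 0.995 ratio becomes 99.5; the Flax snippet above passes the configured ratio straight through):

import numpy as np

def dynamic_threshold(x0_pred, ratio=0.995, max_value=1.0):
    s = np.percentile(np.abs(x0_pred), ratio * 100, axis=tuple(range(1, x0_pred.ndim)), keepdims=True)
    s = np.maximum(s, max_value)        # floor at max_value so mild samples keep the usual clip
    return np.clip(x0_pred, -s, s) / s  # clip to [-s, s], then rescale back into [-1, 1]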
@@ -385,7 +399,11 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
m0, m1 = model_output_list[-1], model_output_list[-2]
lambda_t, lambda_s0, lambda_s1 = state.lambda_t[t], state.lambda_t[s0], state.lambda_t[s1]
lambda_t, lambda_s0, lambda_s1 = (
state.lambda_t[t],
state.lambda_t[s0],
state.lambda_t[s1],
)
alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0]
sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0]
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1

@@ -443,7 +461,12 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
Returns:
`jnp.ndarray`: the sample tensor at the previous timestep.
"""
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
t, s0, s1, s2 = (
prev_timestep,
timestep_list[-1],
timestep_list[-2],
timestep_list[-3],
)
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
state.lambda_t[t],
@@ -615,7 +638,10 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
return FlaxDPMSolverMultistepSchedulerOutput(prev_sample=prev_sample, state=state)

def scale_model_input(
self, state: DPMSolverMultistepSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
self,
state: DPMSolverMultistepSchedulerState,
sample: jnp.ndarray,
timestep: Optional[int] = None,
) -> jnp.ndarray:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the

@@ -19,6 +19,7 @@ import flax
import jax.numpy as jnp

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,

@@ -28,6 +29,9 @@ from .scheduling_utils_flax import (
)

logger = logging.get_logger(__name__)

@flax.struct.dataclass
class EulerDiscreteSchedulerState:
common: CommonSchedulerState

@@ -40,9 +44,18 @@ class EulerDiscreteSchedulerState:

@classmethod
def create(
cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray
cls,
common: CommonSchedulerState,
init_noise_sigma: jnp.ndarray,
timesteps: jnp.ndarray,
sigmas: jnp.ndarray,
):
return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas)
return cls(
common=common,
init_noise_sigma=init_noise_sigma,
timesteps=timesteps,
sigmas=sigmas,
)
@dataclass

@@ -99,6 +112,10 @@ class FlaxEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
timestep_spacing: str = "linspace",
dtype: jnp.dtype = jnp.float32,
):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.dtype = dtype

def create_state(self, common: Optional[CommonSchedulerState] = None) -> EulerDiscreteSchedulerState:

@@ -146,7 +163,10 @@ class FlaxEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
return sample

def set_timesteps(
self, state: EulerDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = ()
self,
state: EulerDiscreteSchedulerState,
num_inference_steps: int,
shape: Tuple = (),
) -> EulerDiscreteSchedulerState:
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

@@ -159,7 +179,12 @@ class FlaxEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
"""

if self.config.timestep_spacing == "linspace":
timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype)
timesteps = jnp.linspace(
self.config.num_train_timesteps - 1,
0,
num_inference_steps,
dtype=self.dtype,
)
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // num_inference_steps
timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
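The two spacings place the grid differently: `linspace` spreads steps evenly across [0, T-1] including both endpoints, while `leading` strides up from 0 in fixed increments and so never reaches the final training step T-1. A quick comparison:

import numpy as np

T, n = 1000, 10
linspace_ts = np.linspace(T - 1, 0, n)                   # 999.0, 888.0, ..., 0.0
leading_ts = (np.arange(0, n) * (T // n)).round()[::-1]  # 900, 800, ..., 0
print(linspace_ts[0], leading_ts[0])  # 999.0 vs 900 -- 'leading' skips the last training step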
@@ -22,10 +22,13 @@ import jax.numpy as jnp
from jax import random

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils import BaseOutput, logging
from .scheduling_utils_flax import FlaxSchedulerMixin

logger = logging.get_logger(__name__)

@flax.struct.dataclass
class KarrasVeSchedulerState:
# setable values

@@ -102,7 +105,10 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
s_min: float = 0.05,
s_max: float = 50,
):
pass
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)

def create_state(self):
return KarrasVeSchedulerState.create()
@@ -722,7 +722,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
The current timestep.

Returns:
`int`:
`int` or `torch.Tensor`:
The previous timestep.
"""
if self.custom_timesteps or self.num_inference_steps:
@@ -20,6 +20,7 @@ import jax.numpy as jnp
from scipy import integrate

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,

@@ -29,6 +30,9 @@ from .scheduling_utils_flax import (
)

logger = logging.get_logger(__name__)

@flax.struct.dataclass
class LMSDiscreteSchedulerState:
common: CommonSchedulerState

@@ -44,9 +48,18 @@ class LMSDiscreteSchedulerState:

@classmethod
def create(
cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray
cls,
common: CommonSchedulerState,
init_noise_sigma: jnp.ndarray,
timesteps: jnp.ndarray,
sigmas: jnp.ndarray,
):
return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas)
return cls(
common=common,
init_noise_sigma=init_noise_sigma,
timesteps=timesteps,
sigmas=sigmas,
)

@dataclass

@@ -101,6 +114,10 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
prediction_type: str = "epsilon",
dtype: jnp.dtype = jnp.float32,
):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.dtype = dtype

def create_state(self, common: Optional[CommonSchedulerState] = None) -> LMSDiscreteSchedulerState:
@@ -165,7 +182,10 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
return integrated_coeff

def set_timesteps(
self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = ()
self,
state: LMSDiscreteSchedulerState,
num_inference_steps: int,
shape: Tuple = (),
) -> LMSDiscreteSchedulerState:
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

@@ -177,7 +197,12 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
the number of diffusion steps used when generating samples with a pre-trained model.
"""

timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype)
timesteps = jnp.linspace(
self.config.num_train_timesteps - 1,
0,
num_inference_steps,
dtype=self.dtype,
)

low_idx = jnp.floor(timesteps).astype(jnp.int32)
high_idx = jnp.ceil(timesteps).astype(jnp.int32)
@@ -22,6 +22,7 @@ import jax
import jax.numpy as jnp

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,

@@ -31,6 +32,9 @@ from .scheduling_utils_flax import (
)

logger = logging.get_logger(__name__)

@flax.struct.dataclass
class PNDMSchedulerState:
common: CommonSchedulerState

@@ -131,6 +135,10 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
prediction_type: str = "epsilon",
dtype: jnp.dtype = jnp.float32,
):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.dtype = dtype

# For now we only support F-PNDM, i.e. the runge-kutta method

@@ -190,7 +198,10 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):

else:
prk_timesteps = _timesteps[-self.pndm_order :].repeat(2) + jnp.tile(
jnp.array([0, self.config.num_train_timesteps // num_inference_steps // 2], dtype=jnp.int32),
jnp.array(
[0, self.config.num_train_timesteps // num_inference_steps // 2],
dtype=jnp.int32,
),
self.pndm_order,
)
@@ -218,7 +229,10 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
)

def scale_model_input(
self, state: PNDMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
self,
state: PNDMSchedulerState,
sample: jnp.ndarray,
timestep: Optional[int] = None,
) -> jnp.ndarray:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the

@@ -320,7 +334,9 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
)

diff_to_prev = jnp.where(
state.counter % 2, 0, self.config.num_train_timesteps // state.num_inference_steps // 2
state.counter % 2,
0,
self.config.num_train_timesteps // state.num_inference_steps // 2,
)
prev_timestep = timestep - diff_to_prev
timestep = state.prk_timesteps[state.counter // 4 * 4]

@@ -401,7 +417,9 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):

prev_timestep = jnp.where(state.counter == 1, timestep, prev_timestep)
timestep = jnp.where(
state.counter == 1, timestep + self.config.num_train_timesteps // state.num_inference_steps, timestep
state.counter == 1,
timestep + self.config.num_train_timesteps // state.num_inference_steps,
timestep,
)

# Reference:

@@ -466,7 +484,9 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
# prev_sample -> x_(t−δ)
alpha_prod_t = state.common.alphas_cumprod[timestep]
alpha_prod_t_prev = jnp.where(
prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod
prev_timestep >= 0,
state.common.alphas_cumprod[prev_timestep],
state.final_alpha_cumprod,
)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
@@ -23,7 +23,15 @@ import jax.numpy as jnp
from jax import random

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
from ..utils import logging
from .scheduling_utils_flax import (
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
)

logger = logging.get_logger(__name__)

@flax.struct.dataclass

@@ -95,7 +103,10 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
sampling_eps: float = 1e-5,
correct_steps: int = 1,
):
pass
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)

def create_state(self):
state = ScoreSdeVeSchedulerState.create()

@@ -108,7 +119,11 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
)

def set_timesteps(
self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple = (), sampling_eps: float = None
self,
state: ScoreSdeVeSchedulerState,
num_inference_steps: int,
shape: Tuple = (),
sampling_eps: float = None,
) -> ScoreSdeVeSchedulerState:
"""
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -777,7 +777,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
The current timestep.

Returns:
`int`:
`int` or `torch.Tensor`:
The previous timestep.
"""
if self.custom_timesteps or self.num_inference_steps:
@@ -17,6 +17,51 @@ class Flux2AutoBlocks(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])

class Flux2KleinAutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])

class Flux2KleinBaseAutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])

class Flux2KleinModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])

class Flux2ModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
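These `DummyObject` stubs keep `import diffusers` working when the optional backends are absent; the failure is deferred to first use, where `requires_backends` raises with an installation hint. A hedged sketch of the resulting behavior:

from diffusers.modular_pipelines import Flux2KleinAutoBlocks

# With torch/transformers missing, the import above still succeeds, but
# instantiation fails fast with install instructions instead of an opaque error.
try:
    blocks = Flux2KleinAutoBlocks()
except ImportError as exc:  # requires_backends raises ImportError with a pip hint
    print(exc)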
@@ -276,3 +276,74 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCas

def test_torch_compile_recompilation_and_graph_break(self):
super().test_torch_compile_recompilation_and_graph_break()

def test_torch_compile_with_and_without_mask(self):
"""Test that torch.compile works with both None mask and padding mask."""
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model.compile(mode="default", fullgraph=True)

# Test 1: Run with None mask (no padding, all tokens are valid)
inputs_no_mask = inputs.copy()
inputs_no_mask["encoder_hidden_states_mask"] = None

# First run to allow compilation
with torch.no_grad():
output_no_mask = model(**inputs_no_mask)

# Second run to verify no recompilation
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
torch.no_grad(),
):
output_no_mask_2 = model(**inputs_no_mask)

self.assertEqual(output_no_mask.sample.shape[1], inputs["hidden_states"].shape[1])
self.assertEqual(output_no_mask_2.sample.shape[1], inputs["hidden_states"].shape[1])

# Test 2: Run with all-ones mask (should behave like None)
inputs_all_ones = inputs.copy()
# Keep the all-ones mask
self.assertTrue(inputs_all_ones["encoder_hidden_states_mask"].all().item())

# First run to allow compilation
with torch.no_grad():
output_all_ones = model(**inputs_all_ones)

# Second run to verify no recompilation
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
torch.no_grad(),
):
output_all_ones_2 = model(**inputs_all_ones)

self.assertEqual(output_all_ones.sample.shape[1], inputs["hidden_states"].shape[1])
self.assertEqual(output_all_ones_2.sample.shape[1], inputs["hidden_states"].shape[1])

# Test 3: Run with actual padding mask (has zeros)
inputs_with_padding = inputs.copy()
mask_with_padding = inputs["encoder_hidden_states_mask"].clone()
mask_with_padding[:, 4:] = 0  # Last 3 tokens are padding

inputs_with_padding["encoder_hidden_states_mask"] = mask_with_padding

# First run to allow compilation
with torch.no_grad():
output_with_padding = model(**inputs_with_padding)

# Second run to verify no recompilation
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
torch.no_grad(),
):
output_with_padding_2 = model(**inputs_with_padding)

self.assertEqual(output_with_padding.sample.shape[1], inputs["hidden_states"].shape[1])
self.assertEqual(output_with_padding_2.sample.shape[1], inputs["hidden_states"].shape[1])

# Verify that outputs are different (mask should affect results)
self.assertFalse(torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3))
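The recompile guard used three times above is a reusable idiom: compile once, then call again under `error_on_recompile=True` so any shape- or guard-driven recompilation fails the test instead of silently slowing it down. A standalone sketch:

import torch

@torch.compile(fullgraph=True)
def f(x):
    return x * 2 + 1

f(torch.randn(4))  # first call triggers compilation
with torch._dynamo.config.patch(error_on_recompile=True):
    f(torch.randn(4))  # identical shapes/dtypes: must reuse the compiled graph or raise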
@@ -0,0 +1,91 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

import numpy as np
import PIL
import pytest

from diffusers.modular_pipelines import (
Flux2KleinAutoBlocks,
Flux2KleinModularPipeline,
)

from ...testing_utils import floats_tensor, torch_device
from ..test_modular_pipelines_common import ModularPipelineTesterMixin

class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = Flux2KleinModularPipeline
pipeline_blocks_class = Flux2KleinAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-modular"

params = frozenset(["prompt", "height", "width"])
batch_params = frozenset(["prompt"])

def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
# TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
"max_sequence_length": 8,  # bit of a hack to workaround vocab size mismatch
"text_encoder_out_layers": (1,),
"generator": generator,
"num_inference_steps": 2,
"height": 32,
"width": 32,
"output_type": "pt",
}
return inputs

def test_float16_inference(self):
super().test_float16_inference(9e-2)

class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = Flux2KleinModularPipeline
pipeline_blocks_class = Flux2KleinAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-modular"

params = frozenset(["prompt", "height", "width", "image"])
batch_params = frozenset(["prompt", "image"])

def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
# TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
"max_sequence_length": 8,  # bit of a hack to workaround vocab size mismatch
"text_encoder_out_layers": (1,),
"generator": generator,
"num_inference_steps": 2,
"height": 32,
"width": 32,
"output_type": "pt",
}
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(torch_device)
image = image.cpu().permute(0, 2, 3, 1)[0]
init_image = PIL.Image.fromarray(np.uint8(image * 255)).convert("RGB")
inputs["image"] = init_image

return inputs

def test_float16_inference(self):
super().test_float16_inference(9e-2)

@pytest.mark.skip(reason="batched inference is currently not supported")
def test_inference_batch_single_identical(self, batch_size=2, expected_max_diff=0.0001):
return
@@ -0,0 +1,91 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

import numpy as np
import PIL
import pytest

from diffusers.modular_pipelines import (
Flux2KleinBaseAutoBlocks,
Flux2KleinModularPipeline,
)

from ...testing_utils import floats_tensor, torch_device
from ..test_modular_pipelines_common import ModularPipelineTesterMixin

class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = Flux2KleinModularPipeline
pipeline_blocks_class = Flux2KleinBaseAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-base-modular"

params = frozenset(["prompt", "height", "width"])
batch_params = frozenset(["prompt"])

def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
# TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
"max_sequence_length": 8,  # bit of a hack to workaround vocab size mismatch
"text_encoder_out_layers": (1,),
"generator": generator,
"num_inference_steps": 2,
"height": 32,
"width": 32,
"output_type": "pt",
}
return inputs

def test_float16_inference(self):
super().test_float16_inference(9e-2)

class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = Flux2KleinModularPipeline
pipeline_blocks_class = Flux2KleinBaseAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-base-modular"

params = frozenset(["prompt", "height", "width", "image"])
batch_params = frozenset(["prompt", "image"])

def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
# TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
"max_sequence_length": 8,  # bit of a hack to workaround vocab size mismatch
"text_encoder_out_layers": (1,),
"generator": generator,
"num_inference_steps": 2,
"height": 32,
"width": 32,
"output_type": "pt",
}
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(torch_device)
image = image.cpu().permute(0, 2, 3, 1)[0]
init_image = PIL.Image.fromarray(np.uint8(image * 255)).convert("RGB")
inputs["image"] = init_image

return inputs

def test_float16_inference(self):
super().test_float16_inference(9e-2)

@pytest.mark.skip(reason="batched inference is currently not supported")
def test_inference_batch_single_identical(self, batch_size=2, expected_max_diff=0.0001):
return
@@ -169,7 +169,7 @@ class GlmImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
# fmt: off
expected_slice = np.array(
[
0.5796329, 0.5005878, 0.45881274, 0.45331675, 0.43688118, 0.4899527, 0.54017603, 0.50983673, 0.3387968, 0.38074082, 0.29942477, 0.33733928, 0.3672544, 0.38462338, 0.40991822, 0.46641728
0.5849247, 0.50278825, 0.45747858, 0.45895284, 0.43804976, 0.47044256, 0.5239665, 0.47904694, 0.3323419, 0.38725388, 0.28505728, 0.3161863, 0.35026982, 0.37546024, 0.4090118, 0.46629113
]
)
# fmt: on

@@ -177,20 +177,109 @@ class GlmImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
self.assertEqual(image.shape, (3, 32, 32))
self.assertTrue(np.allclose(expected_slice, generated_slice, atol=1e-4, rtol=1e-4))

@unittest.skip("Not supported.")
def test_inference_batch_single_identical(self):
# GLM-Image has batch_size=1 constraint due to AR model
pass
"""Test that batch=1 produces consistent results with the same seed."""
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)

@unittest.skip("Not supported.")
def test_inference_batch_consistent(self):
# GLM-Image has batch_size=1 constraint due to AR model
pass
# Run twice with same seed
inputs1 = self.get_dummy_inputs(device, seed=42)
inputs2 = self.get_dummy_inputs(device, seed=42)

image1 = pipe(**inputs1).images[0]
image2 = pipe(**inputs2).images[0]

self.assertTrue(torch.allclose(image1, image2, atol=1e-4))

def test_inference_batch_multiple_prompts(self):
"""Test batch processing with multiple prompts."""
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)

generator = torch.Generator(device=device).manual_seed(42)
height, width = 32, 32

inputs = {
"prompt": ["A photo of a cat", "A photo of a dog"],
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 1.5,
"height": height,
"width": width,
"max_sequence_length": 16,
"output_type": "pt",
}

images = pipe(**inputs).images

# Should return 2 images
self.assertEqual(len(images), 2)
self.assertEqual(images[0].shape, (3, 32, 32))
self.assertEqual(images[1].shape, (3, 32, 32))

@unittest.skip("Not supported.")
def test_num_images_per_prompt(self):
# GLM-Image has batch_size=1 constraint due to AR model
pass
"""Test generating multiple images per prompt."""
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)

generator = torch.Generator(device=device).manual_seed(42)
height, width = 32, 32

inputs = {
"prompt": "A photo of a cat",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 1.5,
"height": height,
"width": width,
"max_sequence_length": 16,
"output_type": "pt",
"num_images_per_prompt": 2,
}

images = pipe(**inputs).images

# Should return 2 images for single prompt
self.assertEqual(len(images), 2)
self.assertEqual(images[0].shape, (3, 32, 32))
self.assertEqual(images[1].shape, (3, 32, 32))

def test_batch_with_num_images_per_prompt(self):
"""Test batch prompts with num_images_per_prompt > 1."""
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)

generator = torch.Generator(device=device).manual_seed(42)
height, width = 32, 32

inputs = {
"prompt": ["A photo of a cat", "A photo of a dog"],
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 1.5,
"height": height,
"width": width,
"max_sequence_length": 16,
"output_type": "pt",
"num_images_per_prompt": 2,
}

images = pipe(**inputs).images

# Should return 4 images (2 prompts × 2 images per prompt)
self.assertEqual(len(images), 4)

@unittest.skip("Needs to be revisited.")
def test_encode_prompt_works_in_isolation(self):
352
utils/modular_auto_docstring.py
Normal file
@@ -0,0 +1,352 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Auto Docstring Generator for Modular Pipeline Blocks
|
||||
|
||||
This script scans Python files for classes that have `# auto_docstring` comment above them
|
||||
and inserts/updates the docstring from the class's `doc` property.
|
||||
|
||||
Run from the root of the repo:
|
||||
python utils/modular_auto_docstring.py [path] [--fix_and_overwrite]
|
||||
|
||||
Examples:
|
||||
# Check for auto_docstring markers (will error if found without proper docstring)
|
||||
python utils/modular_auto_docstring.py
|
||||
|
||||
# Check specific directory
|
||||
python utils/modular_auto_docstring.py src/diffusers/modular_pipelines/
|
||||
|
||||
# Fix and overwrite the docstrings
|
||||
python utils/modular_auto_docstring.py --fix_and_overwrite
|
||||
|
||||
Usage in code:
|
||||
# auto_docstring
|
||||
class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
# docstring will be automatically inserted here
|
||||
|
||||
@property
|
||||
def doc(self):
|
||||
return "Your docstring content..."
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import glob
|
||||
import importlib
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
# All paths are set with the intent you should run this script from the root of the repo
|
||||
DIFFUSERS_PATH = "src/diffusers"
|
||||
REPO_PATH = "."
|
||||
|
||||
# Pattern to match the auto_docstring comment
|
||||
AUTO_DOCSTRING_PATTERN = re.compile(r"^\s*#\s*auto_docstring\s*$")
|
||||
|
||||
|
||||
def setup_diffusers_import():
|
||||
"""Setup import path to use the local diffusers module."""
|
||||
src_path = os.path.join(REPO_PATH, "src")
|
||||
if src_path not in sys.path:
|
||||
sys.path.insert(0, src_path)
|
||||
|
||||
|
||||
def get_module_from_filepath(filepath: str) -> str:
|
||||
"""Convert a filepath to a module name."""
|
||||
filepath = os.path.normpath(filepath)
|
||||
|
||||
if filepath.startswith("src" + os.sep):
|
||||
filepath = filepath[4:]
|
||||
|
||||
if filepath.endswith(".py"):
|
||||
filepath = filepath[:-3]
|
||||
|
||||
module_name = filepath.replace(os.sep, ".")
|
||||
return module_name
|
||||
|
||||
|
||||
def load_module(filepath: str):
|
||||
"""Load a module from filepath."""
|
||||
setup_diffusers_import()
|
||||
module_name = get_module_from_filepath(filepath)
|
||||
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
return module
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not import module {module_name}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_doc_from_class(module, class_name: str) -> str:
|
||||
"""Get the doc property from an instantiated class."""
|
||||
if module is None:
|
||||
return None
|
||||
|
||||
cls = getattr(module, class_name, None)
|
||||
if cls is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
instance = cls()
|
||||
if hasattr(instance, "doc"):
|
||||
return instance.doc
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not instantiate {class_name}: {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||


def find_auto_docstring_classes(filepath: str) -> list:
    """
    Find all classes in a file that have a `# auto_docstring` comment above them.

    Returns a list of (class_name, class_line_number, has_existing_docstring, docstring_end_line) tuples.
    """
    with open(filepath, "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()

    # Parse AST to find class locations and their docstrings
    content = "".join(lines)
    try:
        tree = ast.parse(content)
    except SyntaxError as e:
        print(f"Syntax error in {filepath}: {e}")
        return []

    # Build a map of class_name -> (class_line, has_docstring, docstring_end_line)
    class_info = {}
    for node in ast.walk(tree):
        if isinstance(node, ast.ClassDef):
            has_docstring = False
            docstring_end_line = node.lineno  # default to class line

            if node.body and isinstance(node.body[0], ast.Expr):
                first_stmt = node.body[0]
                if isinstance(first_stmt.value, ast.Constant) and isinstance(first_stmt.value.value, str):
                    has_docstring = True
                    docstring_end_line = first_stmt.end_lineno or first_stmt.lineno

            class_info[node.name] = (node.lineno, has_docstring, docstring_end_line)

    # Now scan for # auto_docstring comments
    classes_to_update = []

    for i, line in enumerate(lines):
        if AUTO_DOCSTRING_PATTERN.match(line):
            # Found the marker, look for a class definition on the next non-empty, non-comment line
            j = i + 1
            while j < len(lines):
                next_line = lines[j].strip()
                if next_line and not next_line.startswith("#"):
                    break
                j += 1

            if j < len(lines) and lines[j].strip().startswith("class "):
                # Extract the class name
                match = re.match(r"class\s+(\w+)", lines[j].strip())
                if match:
                    class_name = match.group(1)
                    if class_name in class_info:
                        class_line, has_docstring, docstring_end_line = class_info[class_name]
                        classes_to_update.append((class_name, class_line, has_docstring, docstring_end_line))

    return classes_to_update
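

# A marker above a class defined at line 120 whose docstring ends at line 128 would
# yield an entry like ("MyBlock", 120, True, 128); the name and numbers here are
# hypothetical.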


def strip_class_name_line(doc: str, class_name: str) -> str:
    """Remove the 'class ClassName' line from the doc if present."""
    lines = doc.strip().split("\n")
    if lines and lines[0].strip() == f"class {class_name}":
        # Remove the class line and any blank line following it
        lines = lines[1:]
        while lines and not lines[0].strip():
            lines = lines[1:]
    return "\n".join(lines)


def format_docstring(doc: str, indent: str = "    ") -> str:
    """Format a doc string as a properly indented docstring."""
    lines = doc.strip().split("\n")

    if len(lines) == 1:
        return f'{indent}"""{lines[0]}"""\n'
    else:
        result = [f'{indent}"""\n']
        for line in lines:
            if line.strip():
                result.append(f"{indent}{line}\n")
            else:
                result.append("\n")
        result.append(f'{indent}"""\n')
        return "".join(result)
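

# A single-line doc such as "Decode latents into images." (hypothetical text) comes
# back as '    """Decode latents into images."""\n'; multi-line docs get the opening
# and closing triple quotes on their own indented lines.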


def run_ruff_format(filepath: str):
    """Run ruff check --fix, ruff format, and doc-builder style on a file to ensure consistent formatting."""
    try:
        # First run ruff check --fix to fix any linting issues (including line length)
        subprocess.run(
            ["ruff", "check", "--fix", filepath],
            check=False,  # Don't fail if there are unfixable issues
            capture_output=True,
            text=True,
        )
        # Then run ruff format for code formatting
        subprocess.run(
            ["ruff", "format", filepath],
            check=True,
            capture_output=True,
            text=True,
        )
        # Finally run doc-builder style for docstring formatting
        subprocess.run(
            ["doc-builder", "style", filepath, "--max_len", "119"],
            check=False,  # Don't fail if doc-builder has issues
            capture_output=True,
            text=True,
        )
        print(f"Formatted {filepath}")
    except subprocess.CalledProcessError as e:
        print(f"Warning: formatting failed for {filepath}: {e.stderr}")
    except FileNotFoundError as e:
        print(f"Warning: tool not found ({e}). Skipping formatting.")
    except Exception as e:
        print(f"Warning: unexpected error formatting {filepath}: {e}")
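

# This assumes `ruff` and `doc-builder` are available on PATH (e.g. installed via
# `pip install -e ".[quality]"`); if either is missing, the FileNotFoundError branch
# above skips formatting rather than failing.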


def get_existing_docstring(lines: list, class_line: int, docstring_end_line: int) -> str:
    """Extract the existing docstring content from lines."""
    # class_line is 1-indexed, so list index class_line is the line right after the
    # class statement; the docstring occupies indices class_line through
    # docstring_end_line - 1 (docstring_end_line is 1-indexed, inclusive)
    docstring_lines = lines[class_line:docstring_end_line]
    return "".join(docstring_lines)


def process_file(filepath: str, overwrite: bool = False) -> list:
    """
    Process a file and find/insert docstrings for # auto_docstring marked classes.

    Returns a list of classes that need updating.
    """
    classes_to_update = find_auto_docstring_classes(filepath)

    if not classes_to_update:
        return []

    if not overwrite:
        # Check mode: only verify that docstrings exist
        # Content comparison is not reliable due to formatting differences
        classes_needing_update = []
        for class_name, class_line, has_docstring, docstring_end_line in classes_to_update:
            if not has_docstring:
                # No docstring exists, needs update
                classes_needing_update.append((filepath, class_name, class_line))
        return classes_needing_update

    # Load the module to get doc properties
    module = load_module(filepath)

    with open(filepath, "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()

    # Process in reverse order to maintain line numbers
    updated = False
    for class_name, class_line, has_docstring, docstring_end_line in reversed(classes_to_update):
        doc = get_doc_from_class(module, class_name)

        if doc is None:
            print(f"Warning: Could not get doc for {class_name} in {filepath}")
            continue

        # Remove the "class ClassName" line since it's redundant in a docstring
        doc = strip_class_name_line(doc, class_name)

        # Format the new docstring with 4-space indent
        new_docstring = format_docstring(doc, "    ")

        if has_docstring:
            # Replace existing docstring (line after class definition to docstring_end_line)
            # class_line is 1-indexed, we want to replace from class_line+1 to docstring_end_line
            lines = lines[:class_line] + [new_docstring] + lines[docstring_end_line:]
        else:
            # Insert new docstring right after class definition line
            # class_line is 1-indexed, so lines[class_line-1] is the class line
            # Insert at position class_line (which is right after the class line)
            lines = lines[:class_line] + [new_docstring] + lines[class_line:]

        updated = True
        print(f"Updated docstring for {class_name} in {filepath}")

    if updated:
        with open(filepath, "w", encoding="utf-8", newline="\n") as f:
            f.writelines(lines)
        # Run ruff format to ensure consistent line wrapping
        run_ruff_format(filepath)

    return [(filepath, cls_name, line) for cls_name, line, _, _ in classes_to_update]
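

# Editing bottom-up keeps the 1-indexed positions collected earlier valid: replacing a
# docstring near the end of the file never shifts the class or docstring line numbers
# of entries still waiting to be processed above it.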


def check_auto_docstrings(path: str = None, overwrite: bool = False):
    """
    Check all files for # auto_docstring markers and optionally fix them.
    """
    if path is None:
        path = DIFFUSERS_PATH

    if os.path.isfile(path):
        all_files = [path]
    else:
        all_files = glob.glob(os.path.join(path, "**/*.py"), recursive=True)

    all_markers = []

    for filepath in all_files:
        markers = process_file(filepath, overwrite)
        all_markers.extend(markers)

    if not overwrite and len(all_markers) > 0:
        message = "\n".join([f"- {f}: {cls} at line {line}" for f, cls, line in all_markers])
        raise ValueError(
            f"Found the following # auto_docstring markers that need docstrings:\n{message}\n\n"
            f"Run `python utils/modular_auto_docstring.py --fix_and_overwrite` to fix them."
        )

    if overwrite and len(all_markers) > 0:
        print(f"\nProcessed {len(all_markers)} docstring(s).")
    elif not overwrite and len(all_markers) == 0:
        print("All # auto_docstring markers have valid docstrings.")
    elif len(all_markers) == 0:
        print("No # auto_docstring markers found.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Check and fix # auto_docstring markers in modular pipeline blocks",
    )
    parser.add_argument("path", nargs="?", default=None, help="File or directory to process (default: src/diffusers)")
    parser.add_argument(
        "--fix_and_overwrite",
        action="store_true",
        help="Whether to fix the docstrings by inserting them from the doc property.",
    )

    args = parser.parse_args()

    check_auto_docstrings(args.path, args.fix_and_overwrite)
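

# Typical invocations (run from the repo root):
#
#     python utils/modular_auto_docstring.py                       # check-only mode
#     python utils/modular_auto_docstring.py --fix_and_overwrite   # regenerate docstrings
#     python utils/modular_auto_docstring.py path/to/file.py --fix_and_overwrite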