mirror of https://github.com/huggingface/diffusers.git
synced 2026-01-13 07:05:48 +08:00

Compare commits: improve-lo...modular-do (11 commits)
Commits:

- 1c90ce33f2
- 507953f415
- f0555af1c6
- 2a81f2ec54
- d20f413f78
- ff09bf1a63
- 34a743e2dc
- 43ab14845d
- fbfe5c8d6b
- b29873dee7
- 7b499de6d0
@@ -29,7 +29,7 @@ Cache methods speedup diffusion transformers by storing and reusing intermediate
 [[autodoc]] apply_faster_cache

-## FirstBlockCacheConfig
+### FirstBlockCacheConfig

 [[autodoc]] FirstBlockCacheConfig
@@ -21,7 +21,7 @@ The abstract from the paper is:

 *Image generation has recently seen tremendous advances, with diffusion models allowing to synthesize convincing images for a large variety of text prompts. In this article, we propose DiffEdit, a method to take advantage of text-conditioned diffusion models for the task of semantic image editing, where the goal is to edit an image based on a text query. Semantic image editing is an extension of image generation, with the additional constraint that the generated image should be as similar as possible to a given input image. Current editing methods based on diffusion models usually require to provide a mask, making the task much easier by treating it as a conditional inpainting task. In contrast, our main contribution is able to automatically generate a mask highlighting regions of the input image that need to be edited, by contrasting predictions of a diffusion model conditioned on different text prompts. Moreover, we rely on latent inference to preserve content in those regions of interest and show excellent synergies with mask-based diffusion. DiffEdit achieves state-of-the-art editing performance on ImageNet. In addition, we evaluate semantic image editing in more challenging settings, using images from the COCO dataset as well as text-based generated images.*

-The original codebase can be found at [Xiang-cd/DiffEdit-stable-diffusion](https://github.com/Xiang-cd/DiffEdit-stable-diffusion), and you can try it out in this [demo](https://blog.problemsolversguild.com/posts/2022-11-02-diffedit-implementation.html).
+The original codebase can be found at [Xiang-cd/DiffEdit-stable-diffusion](https://github.com/Xiang-cd/DiffEdit-stable-diffusion), and you can try it out in this [demo](https://blog.problemsolversguild.com/technical/research/2022/11/02/DiffEdit-Implementation.html).

 This pipeline was contributed by [clarencechen](https://github.com/clarencechen). ❤️
@@ -140,7 +140,7 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
                 type_hint=str,
                 required=True,
                 default="mask_image",
-                description="""Output type from annotation predictions. Available options are
+                description="""Output type from annotation predictions. Availabe options are
                     mask_image:
                         -black and white mask image for the given image based on the task type
                     mask_overlay:
@@ -256,7 +256,7 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
                 type_hint=str,
                 required=True,
                 default="mask_image",
-                description="""Output type from annotation predictions. Available options are
+                description="""Output type from annotation predictions. Availabe options are
                     mask_image:
                         -black and white mask image for the given image based on the task type
                     mask_overlay:
@@ -53,7 +53,7 @@ The loop wrapper can pass additional arguments, like current iteration index, to

 A loop block is a [`~modular_pipelines.ModularPipelineBlocks`], but the `__call__` method behaves differently.

-- It receives the iteration variable from the loop wrapper.
+- It recieves the iteration variable from the loop wrapper.
 - It works directly with the [`~modular_pipelines.BlockState`] instead of the [`~modular_pipelines.PipelineState`].
 - It doesn't require retrieving or updating the [`~modular_pipelines.BlockState`].
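In code, the contract those bullets describe looks roughly like this (a minimal sketch, assuming the `diffusers.modular_pipelines` API; the block name and the denoising update are hypothetical):

```py
from diffusers.modular_pipelines import ModularPipelineBlocks

class MyDenoiseLoopStep(ModularPipelineBlocks):
    # Hypothetical loop block: the loop wrapper invokes it once per iteration.
    def __call__(self, components, block_state, i):
        # `i` is the iteration variable handed in by the loop wrapper, and
        # `block_state` is a BlockState, not a PipelineState, so no
        # get_block_state/set_block_state calls are needed here.
        block_state.latents = block_state.latents - 0.1 * i  # stand-in for one denoising update
        return components, block_state
```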
@@ -68,20 +68,6 @@ config = FasterCacheConfig(
 pipeline.transformer.enable_cache(config)
 ```

-## FirstBlockCache
-
-[FirstBlock Cache](https://huggingface.co/docs/diffusers/main/en/api/cache#diffusers.FirstBlockCacheConfig) checks how much the early layers of the denoiser changes from one timestep to the next. If the change is small, the model skips the expensive later layers and reuses the previous output.
-
-```py
-import torch
-from diffusers import DiffusionPipeline
-from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig
-
-pipeline = DiffusionPipeline.from_pretrained(
-    "Qwen/Qwen-Image", torch_dtype=torch.bfloat16
-)
-apply_first_block_cache(pipeline.transformer, FirstBlockCacheConfig(threshold=0.2))
-```
 ## TaylorSeer Cache

 [TaylorSeer Cache](https://huggingface.co/papers/2403.06923) accelerates diffusion inference by using Taylor series expansions to approximate and cache intermediate activations across denoising steps. The method predicts future outputs based on past computations, reusing them at specified intervals to reduce redundant calculations.
@@ -101,7 +87,8 @@ from diffusers import FluxPipeline, TaylorSeerCacheConfig
 pipe = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev",
     torch_dtype=torch.bfloat16,
-).to("cuda")
+)
+pipe.to("cuda")

 config = TaylorSeerCacheConfig(
     cache_interval=5,

@@ -110,4 +97,4 @@ config = TaylorSeerCacheConfig(
     taylor_factors_dtype=torch.bfloat16,
 )
 pipe.transformer.enable_cache(config)
 ```
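The numerical idea behind the config above can be sketched in a few lines: cache past activations, estimate a derivative by finite differences, and extrapolate instead of recomputing. This is a toy illustration only, not the library's implementation:

```py
import torch

f_prev = torch.randn(8)      # activation cached at step t-1
f_curr = torch.randn(8)      # activation computed at step t
df = f_curr - f_prev         # first-order (finite-difference) Taylor factor

k = 3                        # steps until the next full recompute (cf. cache_interval)
f_approx = f_curr + k * df   # predicted activation at step t+k, reused instead of a forward pass
```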
@@ -149,13 +149,13 @@ def get_args():
         "--validation_prompt",
         type=str,
         default=None,
-        help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_separator' string.",
+        help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.",
     )
     parser.add_argument(
         "--validation_images",
         type=str,
         default=None,
-        help="One or more image path(s) that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_separator' string. These should correspond to the order of the validation prompts.",
+        help="One or more image path(s) that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.",
     )
     parser.add_argument(
         "--validation_prompt_separator",
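For context, the separator named in these help strings is used downstream to split a single CLI string into a list (a hedged sketch; the attribute names mirror the arguments above):

```py
# e.g. --validation_prompt "a cat:::a dog" --validation_prompt_separator ":::"
validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
validation_images = args.validation_images.split(args.validation_prompt_separator)
```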
@@ -140,7 +140,7 @@ def get_args():
         "--validation_prompt",
         type=str,
         default=None,
-        help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_separator' string.",
+        help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.",
     )
     parser.add_argument(
         "--validation_prompt_separator",
@@ -4,7 +4,7 @@ The `train_text_to_image.py` script shows how to fine-tune stable diffusion mode
 ___Note___:

-___This script is experimental. The script fine-tunes the whole model and often times the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best result on your dataset.___
+___This script is experimental. The script fine-tunes the whole model and often times the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparamters to get the best result on your dataset.___

 ## Running locally with PyTorch
@@ -18,7 +18,7 @@ cc.initialize_cache("/tmp/sdxl_cache")
 NUM_DEVICES = jax.device_count()

 # 1. Let's start by downloading the model and loading it into our pipeline class
-# Adhering to JAX's functional approach, the model's parameters are returned separately and
+# Adhering to JAX's functional approach, the model's parameters are returned seperatetely and
 # will have to be passed to the pipeline during inference
 pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
     "stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True
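The comment above is the key JAX-ism in this script: `from_pretrained` returns the weights as a separate pytree that must be passed back at call time. A hedged sketch of such a call (argument names assumed; the real script shards `params` across devices and prepares prompt ids first):

```py
# Hypothetical invocation: `params` is the pytree returned above and `rng` a jax.random key.
images = pipeline(prompt_ids, params, rng, jit=True).images
```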
@@ -214,7 +214,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
+            raise ValueError("Invalid LoRA checkpoint.")

         self.load_lora_into_unet(
             state_dict,
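The check being toggled here is simple enough to state standalone (hedged sketch): every key in the state dict must contain the substring "lora", otherwise the loader rejects the checkpoint.

```py
state_dict = {
    "unet.down_blocks.0.attentions.0.lora_A.weight": ...,  # valid: contains "lora"
    "unet.down_blocks.0.attentions.0.bias": ...,           # would fail the check below
}
if not all("lora" in key for key in state_dict):
    raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
```

The same one-line change repeats across every loader mixin in the hunks that follow.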
@@ -641,7 +641,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
+            raise ValueError("Invalid LoRA checkpoint.")

         self.load_lora_into_unet(
             state_dict,

@@ -1081,7 +1081,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
+            raise ValueError("Invalid LoRA checkpoint.")

         self.load_lora_into_transformer(
             state_dict,

@@ -1377,7 +1377,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
+            raise ValueError("Invalid LoRA checkpoint.")

         self.load_lora_into_transformer(
             state_dict,

@@ -1659,7 +1659,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
         )

         if not (has_lora_keys or has_norm_keys):
-            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
+            raise ValueError("Invalid LoRA checkpoint.")

         transformer_lora_state_dict = {
             k: state_dict.get(k)

@@ -2506,7 +2506,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
+            raise ValueError("Invalid LoRA checkpoint.")

         self.load_lora_into_transformer(
             state_dict,

@@ -2703,7 +2703,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
+            raise ValueError("Invalid LoRA checkpoint.")

         self.load_lora_into_transformer(
             state_dict,

@@ -2906,7 +2906,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
+            raise ValueError("Invalid LoRA checkpoint.")

         self.load_lora_into_transformer(
             state_dict,

@@ -3115,7 +3115,7 @@ class LTX2LoraLoaderMixin(LoraBaseMixin):
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
+            raise ValueError("Invalid LoRA checkpoint.")

         transformer_peft_state_dict = {
             k: v for k, v in state_dict.items() if k.startswith(f"{self.transformer_name}.")

@@ -3333,7 +3333,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
+            raise ValueError("Invalid LoRA checkpoint.")

         self.load_lora_into_transformer(
             state_dict,

@@ -3536,7 +3536,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
+            raise ValueError("Invalid LoRA checkpoint.")

         self.load_lora_into_transformer(
             state_dict,

@@ -3740,7 +3740,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
+            raise ValueError("Invalid LoRA checkpoint.")

         self.load_lora_into_transformer(
             state_dict,

@@ -3940,7 +3940,7 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
+            raise ValueError("Invalid LoRA checkpoint.")

         self.load_lora_into_transformer(
             state_dict,

@@ -4194,7 +4194,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
         )
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
+            raise ValueError("Invalid LoRA checkpoint.")

         load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
         if load_into_transformer_2:

@@ -4471,7 +4471,7 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
         )
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
+            raise ValueError("Invalid LoRA checkpoint.")

         load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
         if load_into_transformer_2:

@@ -4691,7 +4691,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
+            raise ValueError("Invalid LoRA checkpoint.")

         self.load_lora_into_transformer(
             state_dict,

@@ -4894,7 +4894,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
+            raise ValueError("Invalid LoRA checkpoint.")

         self.load_lora_into_transformer(
             state_dict,

@@ -5100,7 +5100,7 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
+            raise ValueError("Invalid LoRA checkpoint.")

         self.load_lora_into_transformer(
             state_dict,

@@ -5306,7 +5306,7 @@ class ZImageLoraLoaderMixin(LoraBaseMixin):
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
+            raise ValueError("Invalid LoRA checkpoint.")

         self.load_lora_into_transformer(
             state_dict,

@@ -5509,7 +5509,7 @@ class Flux2LoraLoaderMixin(LoraBaseMixin):
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
+            raise ValueError("Invalid LoRA checkpoint.")

         self.load_lora_into_transformer(
             state_dict,
@@ -41,11 +41,9 @@ class CacheMixin:
         Enable caching techniques on the model.

         Args:
-            config (`Union[PyramidAttentionBroadcastConfig, FasterCacheConfig, FirstBlockCacheConfig]`):
+            config (`Union[PyramidAttentionBroadcastConfig]`):
                 The configuration for applying the caching technique. Currently supported caching techniques are:
                 - [`~hooks.PyramidAttentionBroadcastConfig`]
-                - [`~hooks.FasterCacheConfig`]
-                - [`~hooks.FirstBlockCacheConfig`]

         Example:
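For reference, enabling any of the listed configs goes through the same entry point (a hedged sketch; the model id and the skip range value are placeholders, and additional config fields may be required in practice):

```py
import torch
from diffusers import DiffusionPipeline
from diffusers.hooks import PyramidAttentionBroadcastConfig

pipe = DiffusionPipeline.from_pretrained("<model-id>", torch_dtype=torch.bfloat16)
config = PyramidAttentionBroadcastConfig(spatial_attention_block_skip_range=2)
pipe.transformer.enable_cache(config)  # CacheMixin installs the corresponding hooks
```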
@@ -18,6 +18,7 @@ from collections import OrderedDict
 from dataclasses import dataclass, field, fields
 from typing import Any, Dict, List, Literal, Optional, Type, Union

+import PIL.Image
 import torch

 from ..configuration_utils import ConfigMixin, FrozenDict

@@ -342,6 +343,185 @@ class InputParam:
     def __repr__(self):
         return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"

+    @classmethod
+    def template(cls, name: str) -> Optional["InputParam"]:
+        """Get template for name if exists, otherwise None."""
+        if hasattr(cls, name) and callable(getattr(cls, name)):
+            return getattr(cls, name)()
+        return None
+
+    # ======================================================
+    # InputParam templates
+    # ======================================================
+
+    @classmethod
+    def prompt(cls) -> "InputParam":
+        return cls(
+            name="prompt", type_hint=str, required=True, description="The prompt or prompts to guide image generation."
+        )
+
+    @classmethod
+    def negative_prompt(cls) -> "InputParam":
+        return cls(
+            name="negative_prompt",
+            type_hint=str,
+            default=None,
+            description="The prompt or prompts not to guide the image generation.",
+        )
+
+    @classmethod
+    def max_sequence_length(cls, default: int = 512) -> "InputParam":
+        return cls(
+            name="max_sequence_length",
+            type_hint=int,
+            default=default,
+            description="Maximum sequence length for prompt encoding.",
+        )
+
+    @classmethod
+    def height(cls, default: Optional[int] = None) -> "InputParam":
+        return cls(
+            name="height", type_hint=int, default=default, description="The height in pixels of the generated image."
+        )
+
+    @classmethod
+    def width(cls, default: Optional[int] = None) -> "InputParam":
+        return cls(
+            name="width", type_hint=int, default=default, description="The width in pixels of the generated image."
+        )
+
+    @classmethod
+    def num_inference_steps(cls, default: int = 50) -> "InputParam":
+        return cls(
+            name="num_inference_steps", type_hint=int, default=default, description="The number of denoising steps."
+        )
+
+    @classmethod
+    def num_images_per_prompt(cls, default: int = 1) -> "InputParam":
+        return cls(
+            name="num_images_per_prompt",
+            type_hint=int,
+            default=default,
+            description="The number of images to generate per prompt.",
+        )
+
+    @classmethod
+    def generator(cls) -> "InputParam":
+        return cls(
+            name="generator",
+            type_hint=torch.Generator,
+            default=None,
+            description="Torch generator for deterministic generation.",
+        )
+
+    @classmethod
+    def sigmas(cls) -> "InputParam":
+        return cls(
+            name="sigmas", type_hint=List[float], default=None, description="Custom sigmas for the denoising process."
+        )
+
+    @classmethod
+    def strength(cls, default: float = 0.9) -> "InputParam":
+        return cls(name="strength", type_hint=float, default=default, description="Strength for img2img/inpainting.")
+
+    # images
+    @classmethod
+    def image(cls) -> "InputParam":
+        return cls(
+            name="image",
+            type_hint=PIL.Image.Image,
+            required=True,
+            description="Input image for img2img, editing, or conditioning.",
+        )
+
+    @classmethod
+    def mask_image(cls) -> "InputParam":
+        return cls(
+            name="mask_image", type_hint=PIL.Image.Image, required=True, description="Mask image for inpainting."
+        )
+
+    @classmethod
+    def control_image(cls) -> "InputParam":
+        return cls(
+            name="control_image",
+            type_hint=PIL.Image.Image,
+            required=True,
+            description="Control image for ControlNet conditioning.",
+        )
+
+    @classmethod
+    def padding_mask_crop(cls) -> "InputParam":
+        return cls(
+            name="padding_mask_crop",
+            type_hint=int,
+            default=None,
+            description="Padding for mask cropping in inpainting.",
+        )
+
+    @classmethod
+    def latents(cls) -> "InputParam":
+        return cls(
+            name="latents",
+            type_hint=torch.Tensor,
+            default=None,
+            description="Pre-generated noisy latents for image generation.",
+        )
+
+    @classmethod
+    def timesteps(cls) -> "InputParam":
+        return cls(
+            name="timesteps", type_hint=torch.Tensor, default=None, description="Timesteps for the denoising process."
+        )
+
+    @classmethod
+    def output_type(cls) -> "InputParam":
+        return cls(name="output_type", type_hint=str, default="pil", description="Output format: 'pil', 'np', 'pt'.")
+
+    @classmethod
+    def attention_kwargs(cls) -> "InputParam":
+        return cls(
+            name="attention_kwargs",
+            type_hint=Dict[str, Any],
+            default=None,
+            description="Additional kwargs for attention processors.",
+        )
+
+    @classmethod
+    def denoiser_input_fields(cls) -> "InputParam":
+        return cls(
+            kwargs_type="denoiser_input_fields",
+            type_hint=torch.Tensor,
+            description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
+        )
+
+    # ControlNet
+    @classmethod
+    def control_guidance_start(cls, default: float = 0.0) -> "InputParam":
+        return cls(
+            name="control_guidance_start",
+            type_hint=float,
+            default=default,
+            description="When to start applying ControlNet.",
+        )
+
+    @classmethod
+    def control_guidance_end(cls, default: float = 1.0) -> "InputParam":
+        return cls(
+            name="control_guidance_end",
+            type_hint=float,
+            default=default,
+            description="When to stop applying ControlNet.",
+        )
+
+    @classmethod
+    def controlnet_conditioning_scale(cls, default: float = 1.0) -> "InputParam":
+        return cls(
+            name="controlnet_conditioning_scale",
+            type_hint=float,
+            default=default,
+            description="Scale for ControlNet conditioning.",
+        )
+

 @dataclass
 class OutputParam:

@@ -357,6 +537,25 @@ class OutputParam:
             f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>"
         )

+    @classmethod
+    def template(cls, name: str) -> Optional["OutputParam"]:
+        """Get template for name if exists, otherwise None."""
+        if hasattr(cls, name) and callable(getattr(cls, name)):
+            return getattr(cls, name)()
+        return None
+
+    # ======================================================
+    # OutputParam templates
+    # ======================================================
+
+    @classmethod
+    def images(cls) -> "OutputParam":
+        return cls(name="images", type_hint=List[PIL.Image.Image], description="Generated images.")
+
+    @classmethod
+    def latents(cls) -> "OutputParam":
+        return cls(name="latents", type_hint=torch.Tensor, description="Denoised latents.")
+

 def format_inputs_short(inputs):
     """
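Taken together, the additions above give blocks a lookup-or-construct idiom. A hedged usage sketch, derived directly from the code shown:

```py
InputParam.template("prompt")           # -> the canonical required `prompt` param
InputParam.template("image_latents")    # -> None; no classmethod with that name exists
InputParam.num_inference_steps(28)      # templates with arguments allow per-pipeline defaults
OutputParam.template("images")          # -> OutputParam(name="images", ...)
```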
@@ -134,11 +134,11 @@ class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
     @property
     def inputs(self) -> List[InputParam]:
         return [
-            InputParam("latents"),
-            InputParam(name="height"),
-            InputParam(name="width"),
-            InputParam(name="num_images_per_prompt", default=1),
-            InputParam(name="generator"),
+            InputParam.latents(),
+            InputParam.height(),
+            InputParam.width(),
+            InputParam.num_images_per_prompt(),
+            InputParam.generator(),
             InputParam(
                 name="batch_size",
                 required=True,

@@ -225,12 +225,14 @@ class QwenImageLayeredPrepareLatentsStep(ModularPipelineBlocks):
     @property
     def inputs(self) -> List[InputParam]:
         return [
-            InputParam("latents"),
-            InputParam(name="height"),
-            InputParam(name="width"),
-            InputParam(name="layers", default=4),
-            InputParam(name="num_images_per_prompt", default=1),
-            InputParam(name="generator"),
+            InputParam.latents(),
+            InputParam.height(),
+            InputParam.width(),
+            InputParam(
+                name="layers", type_hint=int, default=4, description="Number of layers to extract from the image"
+            ),
+            InputParam.num_images_per_prompt(),
+            InputParam.generator(),
             InputParam(
                 name="batch_size",
                 required=True,

@@ -455,7 +457,7 @@ class QwenImageSetTimestepsStep(ModularPipelineBlocks):

     @property
     def description(self) -> str:
-        return "Step that sets the scheduler's timesteps for text-to-image generation. Should be run after prepare latents step."
+        return "Step that sets the the scheduler's timesteps for text-to-image generation. Should be run after prepare latents step."

     @property
     def expected_components(self) -> List[ComponentSpec]:

@@ -466,8 +468,8 @@ class QwenImageSetTimestepsStep(ModularPipelineBlocks):
     @property
     def inputs(self) -> List[InputParam]:
         return [
-            InputParam(name="num_inference_steps", default=50),
-            InputParam(name="sigmas"),
+            InputParam.num_inference_steps(),
+            InputParam.sigmas(),
             InputParam(
                 name="latents",
                 required=True,

@@ -532,8 +534,8 @@ class QwenImageLayeredSetTimestepsStep(ModularPipelineBlocks):
     @property
     def inputs(self) -> List[InputParam]:
         return [
-            InputParam("num_inference_steps", default=50, type_hint=int),
-            InputParam("sigmas", type_hint=List[float]),
+            InputParam.num_inference_steps(),
+            InputParam.sigmas(),
             InputParam("image_latents", required=True, type_hint=torch.Tensor),
         ]

@@ -579,7 +581,7 @@ class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):

     @property
     def description(self) -> str:
-        return "Step that sets the scheduler's timesteps for image-to-image generation, and inpainting. Should be run after prepare latents step."
+        return "Step that sets the the scheduler's timesteps for image-to-image generation, and inpainting. Should be run after prepare latents step."

     @property
     def expected_components(self) -> List[ComponentSpec]:

@@ -590,15 +592,15 @@ class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
     @property
     def inputs(self) -> List[InputParam]:
         return [
-            InputParam(name="num_inference_steps", default=50),
-            InputParam(name="sigmas"),
+            InputParam.num_inference_steps(),
+            InputParam.sigmas(),
             InputParam(
                 name="latents",
                 required=True,
                 type_hint=torch.Tensor,
                 description="The latents to use for the denoising process, used to calculate the image sequence length.",
             ),
-            InputParam(name="strength", default=0.9),
+            InputParam.strength(0.9),
         ]

     @property

@@ -886,7 +888,7 @@ class QwenImageLayeredRoPEInputsStep(ModularPipelineBlocks):
     def inputs(self) -> List[InputParam]:
         return [
             InputParam(name="batch_size", required=True),
-            InputParam(name="layers", required=True),
+            InputParam(name="layers", default=4, description="Number of layers to extract from the image"),
             InputParam(name="height", required=True),
             InputParam(name="width", required=True),
             InputParam(name="prompt_embeds_mask"),

@@ -971,9 +973,9 @@ class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
     @property
     def inputs(self) -> List[InputParam]:
         return [
-            InputParam("control_guidance_start", default=0.0),
-            InputParam("control_guidance_end", default=1.0),
-            InputParam("controlnet_conditioning_scale", default=1.0),
+            InputParam.control_guidance_start(),
+            InputParam.control_guidance_end(),
+            InputParam.controlnet_conditioning_scale(),
             InputParam("control_image_latents", required=True),
             InputParam(
                 "timesteps",
@@ -12,10 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import List, Union
+from typing import List

-import numpy as np
 import PIL
 import torch

 from ...configuration_utils import FrozenDict

@@ -91,7 +89,7 @@ class QwenImageLayeredAfterDenoiseStep(ModularPipelineBlocks):
             InputParam("latents", required=True, type_hint=torch.Tensor),
             InputParam("height", required=True, type_hint=int),
             InputParam("width", required=True, type_hint=int),
-            InputParam("layers", required=True, type_hint=int),
+            InputParam("layers", default=4, description="Number of layers to extract from the image"),
         ]

     @torch.no_grad()

@@ -140,13 +138,7 @@ class QwenImageDecoderStep(ModularPipelineBlocks):

     @property
     def intermediate_outputs(self) -> List[str]:
-        return [
-            OutputParam(
-                "images",
-                type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
-                description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array",
-            )
-        ]
+        return [OutputParam.images()]

     @torch.no_grad()
     def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:

@@ -198,14 +190,19 @@ class QwenImageLayeredDecoderStep(ModularPipelineBlocks):
     @property
     def inputs(self) -> List[InputParam]:
         return [
-            InputParam("latents", required=True, type_hint=torch.Tensor),
-            InputParam("output_type", default="pil", type_hint=str),
+            InputParam(
+                "latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The latents to decode, can be generated in the denoise step",
+            ),
+            InputParam.output_type(),
         ]

     @property
     def intermediate_outputs(self) -> List[OutputParam]:
         return [
-            OutputParam(name="images", type_hint=List[List[PIL.Image.Image]]),
+            OutputParam.images(),
         ]

     @torch.no_grad()

@@ -273,12 +270,7 @@ class QwenImageProcessImagesOutputStep(ModularPipelineBlocks):
     def inputs(self) -> List[InputParam]:
         return [
             InputParam("images", required=True, description="the generated image from decoders step"),
-            InputParam(
-                name="output_type",
-                default="pil",
-                type_hint=str,
-                description="The type of the output images, can be 'pil', 'np', 'pt'",
-            ),
+            InputParam.output_type(),
         ]

     @staticmethod

@@ -323,12 +315,7 @@ class QwenImageInpaintProcessImagesOutputStep(ModularPipelineBlocks):
     def inputs(self) -> List[InputParam]:
         return [
             InputParam("images", required=True, description="the generated image from decoders step"),
-            InputParam(
-                name="output_type",
-                default="pil",
-                type_hint=str,
-                description="The type of the output images, can be 'pil', 'np', 'pt'",
-            ),
+            InputParam.output_type(),
             InputParam("mask_overlay_kwargs"),
         ]
@@ -218,7 +218,7 @@ class QwenImageLoopDenoiser(ModularPipelineBlocks):
     @property
     def inputs(self) -> List[InputParam]:
         return [
-            InputParam("attention_kwargs"),
+            InputParam.attention_kwargs(),
             InputParam(
                 "latents",
                 required=True,

@@ -231,10 +231,7 @@ class QwenImageLoopDenoiser(ModularPipelineBlocks):
                 type_hint=int,
                 description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
             ),
-            InputParam(
-                kwargs_type="denoiser_input_fields",
-                description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
-            ),
+            InputParam.denoiser_input_fields(),
             InputParam(
                 "img_shapes",
                 required=True,

@@ -322,7 +319,7 @@ class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
     @property
     def inputs(self) -> List[InputParam]:
         return [
-            InputParam("attention_kwargs"),
+            InputParam.attention_kwargs(),
             InputParam(
                 "latents",
                 required=True,

@@ -335,10 +332,7 @@ class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
                 type_hint=int,
                 description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
             ),
-            InputParam(
-                kwargs_type="denoiser_input_fields",
-                description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
-            ),
+            InputParam.denoiser_input_fields(),
             InputParam(
                 "img_shapes",
                 required=True,

@@ -424,7 +418,7 @@ class QwenImageLoopAfterDenoiser(ModularPipelineBlocks):
     @property
     def intermediate_outputs(self) -> List[OutputParam]:
         return [
-            OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents."),
+            OutputParam.latents(),
        ]

     @torch.no_grad()
@@ -301,8 +301,12 @@ class QwenImageEditResizeStep(ModularPipelineBlocks):
     @property
     def inputs(self) -> List[InputParam]:
         return [
-            InputParam(
-                name=self._image_input_name, required=True, type_hint=torch.Tensor, description="The image to resize"
-            ),
+            InputParam.template(self._image_input_name)
+            or InputParam(
+                name=self._image_input_name,
+                required=True,
+                type_hint=torch.Tensor,
+                description="Input image for conditioning",
+            ),
         ]
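Because `template()` returns `None` on a miss, the `or`-fallback above reduces to a one-liner that the later hunks reuse. Standalone (hedged sketch):

```py
def resolve_param(name: str) -> InputParam:
    # Prefer the shared template when one exists; otherwise fall back
    # to a bare required input with no extra metadata.
    return InputParam.template(name) or InputParam(name=name, required=True)
```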
@@ -381,7 +385,8 @@ class QwenImageLayeredResizeStep(ModularPipelineBlocks):
     @property
     def inputs(self) -> List[InputParam]:
         return [
-            InputParam(
+            InputParam.template(self._image_input_name)
+            or InputParam(
                 name=self._image_input_name, required=True, type_hint=torch.Tensor, description="The image to resize"
             ),
             InputParam(

@@ -484,7 +489,8 @@ class QwenImageEditPlusResizeStep(ModularPipelineBlocks):
     @property
     def inputs(self) -> List[InputParam]:
         return [
-            InputParam(
+            InputParam.template(self._image_input_name)
+            or InputParam(
                 name=self._image_input_name,
                 required=True,
                 type_hint=torch.Tensor,

@@ -564,7 +570,9 @@ class QwenImageLayeredGetImagePromptStep(ModularPipelineBlocks):
     @property
     def inputs(self) -> List[InputParam]:
         return [
-            InputParam(name="prompt", type_hint=str, description="The prompt to encode"),
+            InputParam(
+                name="prompt", type_hint=str, description="The prompt to encode"
+            ),  # it is not required for qwenimage-layered, unlike other pipelines
             InputParam(
                 name="resized_image",
                 required=True,

@@ -647,11 +655,9 @@ class QwenImageTextEncoderStep(ModularPipelineBlocks):
     @property
     def inputs(self) -> List[InputParam]:
         return [
-            InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"),
-            InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"),
-            InputParam(
-                name="max_sequence_length", type_hint=int, description="The max sequence length to use", default=1024
-            ),
+            InputParam.prompt(),
+            InputParam.negative_prompt(),
+            InputParam.max_sequence_length(1024),
         ]

     @property

@@ -772,8 +778,8 @@ class QwenImageEditTextEncoderStep(ModularPipelineBlocks):
     @property
     def inputs(self) -> List[InputParam]:
         return [
-            InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"),
-            InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"),
+            InputParam.prompt(),
+            InputParam.negative_prompt(),
             InputParam(
                 name="resized_image",
                 required=True,

@@ -895,8 +901,8 @@ class QwenImageEditPlusTextEncoderStep(ModularPipelineBlocks):
     @property
     def inputs(self) -> List[InputParam]:
         return [
-            InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"),
-            InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"),
+            InputParam.prompt(),
+            InputParam.negative_prompt(),
             InputParam(
                 name="resized_cond_image",
                 required=True,

@@ -1010,11 +1016,11 @@ class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks):
     @property
     def inputs(self) -> List[InputParam]:
         return [
-            InputParam("mask_image", required=True),
-            InputParam("image", required=True),
-            InputParam("height"),
-            InputParam("width"),
-            InputParam("padding_mask_crop"),
+            InputParam.mask_image(),
+            InputParam.image(),
+            InputParam.height(),
+            InputParam.width(),
+            InputParam.padding_mask_crop(),
         ]

     @property

@@ -1082,9 +1088,14 @@ class QwenImageEditInpaintProcessImagesInputStep(ModularPipelineBlocks):
     @property
     def inputs(self) -> List[InputParam]:
         return [
-            InputParam("mask_image", required=True),
-            InputParam("resized_image", required=True),
-            InputParam("padding_mask_crop"),
+            InputParam.mask_image(),
+            InputParam(
+                "resized_image",
+                required=True,
+                type_hint=PIL.Image.Image,
+                description="The resized image. should be generated using a resize step",
+            ),
+            InputParam.padding_mask_crop(),
         ]

     @property

@@ -1140,9 +1151,9 @@ class QwenImageProcessImagesInputStep(ModularPipelineBlocks):
     @property
     def inputs(self) -> List[InputParam]:
         return [
-            InputParam("image", required=True),
-            InputParam("height"),
-            InputParam("width"),
+            InputParam.image(),
+            InputParam.height(),
+            InputParam.width(),
         ]

     @property

@@ -1312,7 +1323,10 @@ class QwenImageVaeEncoderStep(ModularPipelineBlocks):

     @property
     def inputs(self) -> List[InputParam]:
-        return [InputParam(self._image_input_name, required=True), InputParam("generator")]
+        return [
+            InputParam.template(self._image_input_name) or InputParam(name=self._image_input_name, required=True),
+            InputParam.generator(),
+        ]

     @property
     def intermediate_outputs(self) -> List[OutputParam]:

@@ -1383,10 +1397,10 @@ class QwenImageControlNetVaeEncoderStep(ModularPipelineBlocks):
     @property
     def inputs(self) -> List[InputParam]:
         inputs = [
-            InputParam("control_image", required=True),
-            InputParam("height"),
-            InputParam("width"),
-            InputParam("generator"),
+            InputParam.control_image(),
+            InputParam.height(),
+            InputParam.width(),
+            InputParam.generator(),
         ]
         return inputs
@@ -129,7 +129,7 @@ class QwenImageTextInputsStep(ModularPipelineBlocks):
     @property
     def inputs(self) -> List[InputParam]:
         return [
-            InputParam(name="num_images_per_prompt", default=1),
+            InputParam.num_images_per_prompt(),
             InputParam(name="prompt_embeds", required=True, kwargs_type="denoiser_input_fields"),
             InputParam(name="prompt_embeds_mask", required=True, kwargs_type="denoiser_input_fields"),
             InputParam(name="negative_prompt_embeds", kwargs_type="denoiser_input_fields"),

@@ -269,17 +269,17 @@ class QwenImageAdditionalInputsStep(ModularPipelineBlocks):
     @property
     def inputs(self) -> List[InputParam]:
         inputs = [
-            InputParam(name="num_images_per_prompt", default=1),
+            InputParam.num_images_per_prompt(),
             InputParam(name="batch_size", required=True),
-            InputParam(name="height"),
-            InputParam(name="width"),
+            InputParam.height(),
+            InputParam.width(),
         ]

         for image_latent_input_name in self._image_latent_inputs:
-            inputs.append(InputParam(name=image_latent_input_name))
+            inputs.append(InputParam.template(image_latent_input_name) or InputParam(name=image_latent_input_name))

         for input_name in self._additional_batch_inputs:
-            inputs.append(InputParam(name=input_name))
+            inputs.append(InputParam.template(input_name) or InputParam(name=input_name))

         return inputs

@@ -398,17 +398,17 @@ class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks):
     @property
     def inputs(self) -> List[InputParam]:
         inputs = [
-            InputParam(name="num_images_per_prompt", default=1),
+            InputParam.num_images_per_prompt(),
             InputParam(name="batch_size", required=True),
-            InputParam(name="height"),
-            InputParam(name="width"),
+            InputParam.height(),
+            InputParam.width(),
         ]

         for image_latent_input_name in self._image_latent_inputs:
-            inputs.append(InputParam(name=image_latent_input_name))
+            inputs.append(InputParam.template(image_latent_input_name) or InputParam(name=image_latent_input_name))

         for input_name in self._additional_batch_inputs:
-            inputs.append(InputParam(name=input_name))
+            inputs.append(InputParam.template(input_name) or InputParam(name=input_name))

         return inputs

@@ -544,15 +544,15 @@ class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks):
     @property
     def inputs(self) -> List[InputParam]:
         inputs = [
-            InputParam(name="num_images_per_prompt", default=1),
+            InputParam.num_images_per_prompt(),
             InputParam(name="batch_size", required=True),
         ]

         for image_latent_input_name in self._image_latent_inputs:
-            inputs.append(InputParam(name=image_latent_input_name))
+            inputs.append(InputParam.template(image_latent_input_name) or InputParam(name=image_latent_input_name))

         for input_name in self._additional_batch_inputs:
-            inputs.append(InputParam(name=input_name))
+            inputs.append(InputParam.template(input_name) or InputParam(name=input_name))

         return inputs

@@ -638,9 +638,9 @@ class QwenImageControlNetInputsStep(ModularPipelineBlocks):
         return [
             InputParam(name="control_image_latents", required=True),
             InputParam(name="batch_size", required=True),
-            InputParam(name="num_images_per_prompt", default=1),
-            InputParam(name="height"),
-            InputParam(name="width"),
+            InputParam.num_images_per_prompt(),
+            InputParam.height(),
+            InputParam.width(),
         ]

     @torch.no_grad()
[File diff suppressed because it is too large]
@@ -12,10 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import List, Optional
-
-import PIL.Image
-import torch
+from typing import Optional

 from ...utils import logging
 from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks
@@ -59,8 +56,61 @@ logger = logging.get_logger(__name__)
|
||||
# ====================
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class QwenImageEditVLEncoderStep(SequentialPipelineBlocks):
|
||||
"""VL encoder that takes both image and text prompts."""
|
||||
"""
|
||||
class QwenImageEditVLEncoderStep
|
||||
|
||||
QwenImage-Edit VL encoder step that encode the image and text prompts together.
|
||||
|
||||
Components:
|
||||
|
||||
image_resize_processor (`VaeImageProcessor`) [subfolder=]
|
||||
|
||||
text_encoder (`Qwen2_5_VLForConditionalGeneration`) [subfolder=]
|
||||
|
||||
processor (`Qwen2VLProcessor`) [subfolder=]
|
||||
|
||||
guider (`ClassifierFreeGuidance`) [subfolder=]
|
||||
|
||||
Configs:
|
||||
|
||||
prompt_template_encode (default: <|im_start|>system
|
||||
Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how
|
||||
the user's text instruction should alter or modify the image. Generate a new image that meets the user's
|
||||
requirements while maintaining consistency with the original input where appropriate.<|im_end|> <|im_start|>user
|
||||
<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|> <|im_start|>assistant )
|
||||
|
||||
prompt_template_encode_start_idx (default: 64)
|
||||
|
||||
Inputs:
|
||||
|
||||
image (`Image`):
|
||||
Input image for img2img, editing, or conditioning.
|
||||
|
||||
prompt (`str`):
|
||||
The prompt or prompts to guide image generation.
|
||||
|
||||
negative_prompt (`str`, *optional*):
|
||||
The prompt or prompts not to guide the image generation.
|
||||
|
||||
Outputs:
|
||||
|
||||
resized_image (`List`):
|
||||
The resized images
|
||||
|
||||
prompt_embeds (`Tensor`):
|
||||
The prompt embeddings
|
||||
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
The encoder attention mask
|
||||
|
||||
negative_prompt_embeds (`Tensor`):
|
||||
The negative prompt embeddings
|
||||
|
||||
negative_prompt_embeds_mask (`Tensor`):
|
||||
The negative prompt embeddings mask
|
||||
"""
|
||||
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
@@ -80,7 +130,40 @@ class QwenImageEditVLEncoderStep(SequentialPipelineBlocks):
|
||||
|
||||
|
||||
# Edit VAE encoder
|
||||
# auto_docstring
|
||||
class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
class QwenImageEditVaeEncoderStep
|
||||
|
||||
Vae encoder step that encode the image inputs into their latent representations.
|
||||
|
||||
Components:
|
||||
|
||||
image_resize_processor (`VaeImageProcessor`) [subfolder=]
|
||||
|
||||
image_processor (`VaeImageProcessor`) [subfolder=]
|
||||
|
||||
vae (`AutoencoderKLQwenImage`) [subfolder=]
|
||||
|
||||
Inputs:
|
||||
|
||||
image (`Image`):
|
||||
Input image for img2img, editing, or conditioning.
|
||||
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
|
||||
Outputs:
|
||||
|
||||
resized_image (`List`):
|
||||
The resized images
|
||||
|
||||
processed_image (`None`):
|
||||
|
||||
image_latents (`Tensor`):
|
||||
The latents representing the reference image(s). Single tensor or list depending on input.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageEditResizeStep(),
|
||||
@@ -95,7 +178,54 @@ class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks):
|
||||
|
||||
|
||||
# Edit Inpaint VAE encoder
|
||||
# auto_docstring
|
||||
class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
class QwenImageEditInpaintVaeEncoderStep
|
||||
|
||||
This step is used for processing image and mask inputs for QwenImage-Edit inpaint tasks. It:
|
||||
- resize the image for target area (1024 * 1024) while maintaining the aspect ratio.
|
||||
- process the resized image and mask image.
|
||||
- create image latents.
|
||||
|
||||
Components:
|
||||
|
||||
image_resize_processor (`VaeImageProcessor`) [subfolder=]
|
||||
|
||||
image_mask_processor (`InpaintProcessor`) [subfolder=]
|
||||
|
||||
vae (`AutoencoderKLQwenImage`) [subfolder=]
|
||||
|
||||
Inputs:
|
||||
|
||||
image (`Image`):
|
||||
Input image for img2img, editing, or conditioning.
|
||||
|
||||
mask_image (`Image`):
|
||||
Mask image for inpainting.
|
||||
|
||||
padding_mask_crop (`int`, *optional*):
|
||||
Padding for mask cropping in inpainting.
|
||||
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
|
||||
Outputs:
|
||||
|
||||
resized_image (`List`):
|
||||
The resized images
|
||||
|
||||
processed_image (`None`):
|
||||
|
||||
processed_mask_image (`None`):
|
||||
|
||||
mask_overlay_kwargs (`Dict`):
|
||||
The kwargs for the postprocess step to apply the mask overlay
|
||||
|
||||
image_latents (`Tensor`):
|
||||
The latents representing the reference image(s). Single tensor or list depending on input.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageEditResizeStep(),
|
||||
@@ -137,7 +267,55 @@ class QwenImageEditAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
|
||||
|
||||
# assemble input steps
|
||||
# auto_docstring
|
||||
class QwenImageEditInputStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
class QwenImageEditInputStep
|
||||
|
||||
Input step that prepares the inputs for the edit denoising step. It:
|
||||
- make sure the text embeddings have consistent batch size as well as the additional inputs.
|
||||
- update height/width based `image_latents`, patchify `image_latents`.
|
||||
|
||||
Components:
|
||||
|
||||
pachifier (`QwenImagePachifier`) [subfolder=]
|
||||
|
||||
Inputs:
|
||||
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
|
||||
prompt_embeds (`None`):
|
||||
|
||||
prompt_embeds_mask (`None`):
|
||||
|
||||
negative_prompt_embeds (`None`, *optional*):
|
||||
|
||||
negative_prompt_embeds_mask (`None`, *optional*):
|
||||
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image.
|
||||
|
||||
image_latents (`None`, *optional*):
|
||||
|
||||
Outputs:
|
||||
|
||||
batch_size (`int`):
|
||||
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt
|
||||
|
||||
dtype (`dtype`):
|
||||
Data type of model tensor inputs (determined by `prompt_embeds`)
|
||||
|
||||
image_height (`int`):
|
||||
The image height calculated from the image latents dimension
|
||||
|
||||
image_width (`int`):
|
||||
The image width calculated from the image latents dimension
|
||||
"""
|
||||
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageTextInputsStep(),
|
||||
@@ -154,7 +332,57 @@ class QwenImageEditInputStep(SequentialPipelineBlocks):
|
||||
)
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class QwenImageEditInpaintInputStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
class QwenImageEditInpaintInputStep
|
||||
|
||||
Input step that prepares the inputs for the edit inpaint denoising step. It:
|
||||
- make sure the text embeddings have consistent batch size as well as the additional inputs.
|
||||
- update height/width based `image_latents`, patchify `image_latents`.
|
||||
|
||||
Components:
|
||||
|
||||
pachifier (`QwenImagePachifier`) [subfolder=]
|
||||
|
||||
Inputs:
|
||||
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
|
||||
prompt_embeds (`None`):
|
||||
|
||||
prompt_embeds_mask (`None`):
|
||||
|
||||
negative_prompt_embeds (`None`, *optional*):
|
||||
|
||||
negative_prompt_embeds_mask (`None`, *optional*):
|
||||
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image.
|
||||
|
||||
image_latents (`None`, *optional*):
|
||||
|
||||
processed_mask_image (`None`, *optional*):
|
||||
|
||||
Outputs:
|
||||
|
||||
batch_size (`int`):
|
||||
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt
|
||||
|
||||
dtype (`dtype`):
|
||||
Data type of model tensor inputs (determined by `prompt_embeds`)
|
||||
|
||||
image_height (`int`):
|
||||
The image height calculated from the image latents dimension
|
||||
|
||||
image_width (`int`):
|
||||
The image width calculated from the image latents dimension
|
||||
"""
|
||||
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageTextInputsStep(),
|
||||
@@ -174,7 +402,51 @@ class QwenImageEditInpaintInputStep(SequentialPipelineBlocks):
|
||||
|
||||
|
||||
# assemble prepare latents steps
|
||||
# auto_docstring
|
||||
class QwenImageEditInpaintPrepareLatentsStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
class QwenImageEditInpaintPrepareLatentsStep
|
||||
|
||||
This step prepares the latents/image_latents and mask inputs for the edit inpainting denoising step. It:
|
||||
- Add noise to the image latents to create the latents input for the denoiser.
|
||||
- Create the patchified latents `mask` based on the processed mask image.
|
||||
|
||||
Components:
|
||||
|
||||
scheduler (`FlowMatchEulerDiscreteScheduler`) [subfolder=]
|
||||
|
||||
pachifier (`QwenImagePachifier`) [subfolder=]
|
||||
|
||||
Inputs:
|
||||
|
||||
latents (`Tensor`):
|
||||
The initial random noised, can be generated in prepare latent step.
|
||||
|
||||
image_latents (`Tensor`):
|
||||
The image latents to use for the denoising process. Can be generated in vae encoder and packed in input
|
||||
step.
|
||||
|
||||
timesteps (`Tensor`):
|
||||
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
|
||||
|
||||
processed_mask_image (`Tensor`):
|
||||
The processed mask to use for the inpainting process.
|
||||
|
||||
height (`None`):
|
||||
|
||||
width (`None`):
|
||||
|
||||
dtype (`None`):
|
||||
|
||||
Outputs:
|
||||
|
||||
initial_noise (`Tensor`):
|
||||
The initial random noised used for inpainting denoising.
|
||||
|
||||
mask (`Tensor`):
|
||||
The mask to use for the inpainting process.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()]
|
||||
block_names = ["add_noise_to_latents", "create_mask_latents"]
@@ -189,7 +461,68 @@ class QwenImageEditInpaintPrepareLatentsStep(SequentialPipelineBlocks):


# Qwen Image Edit (image2image) core denoise step
# auto_docstring
class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks):
    """
    class QwenImageEditCoreDenoiseStep

    Core denoising workflow for QwenImage-Edit edit (img2img) task.

    Components:

        pachifier (`QwenImagePachifier`) [subfolder=]

        scheduler (`FlowMatchEulerDiscreteScheduler`) [subfolder=]

        guider (`ClassifierFreeGuidance`) [subfolder=]

        transformer (`QwenImageTransformer2DModel`) [subfolder=]

    Inputs:

        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.

        prompt_embeds (`None`):

        prompt_embeds_mask (`None`):

        negative_prompt_embeds (`None`, *optional*):

        negative_prompt_embeds_mask (`None`, *optional*):

        height (`int`, *optional*):
            The height in pixels of the generated image.

        width (`int`, *optional*):
            The width in pixels of the generated image.

        image_latents (`None`, *optional*):

        latents (`Tensor`, *optional*):
            Pre-generated noisy latents for image generation.

        generator (`Generator`, *optional*):
            Torch generator for deterministic generation.

        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps.

        sigmas (`List`, *optional*):
            Custom sigmas for the denoising process.

        attention_kwargs (`Dict`, *optional*):
            Additional kwargs for attention processors.

        **denoiser_input_fields (`Tensor`, *optional*):
            conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.

    Outputs:

        latents (`Tensor`):
            Denoised latents.
    """

    model_name = "qwenimage-edit"
    block_classes = [
        QwenImageEditInputStep(),
@@ -212,9 +545,81 @@ class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks):
    def description(self):
        return "Core denoising workflow for QwenImage-Edit edit (img2img) task."

    @property
    def outputs(self):
        return [
            OutputParam.latents(),
        ]


# Qwen Image Edit (inpainting) core denoise step
# auto_docstring
class QwenImageEditInpaintCoreDenoiseStep(SequentialPipelineBlocks):
    """
    class QwenImageEditInpaintCoreDenoiseStep

    Core denoising workflow for QwenImage-Edit edit inpaint task.

    Components:

        pachifier (`QwenImagePachifier`) [subfolder=]

        scheduler (`FlowMatchEulerDiscreteScheduler`) [subfolder=]

        guider (`ClassifierFreeGuidance`) [subfolder=]

        transformer (`QwenImageTransformer2DModel`) [subfolder=]

    Inputs:

        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.

        prompt_embeds (`None`):

        prompt_embeds_mask (`None`):

        negative_prompt_embeds (`None`, *optional*):

        negative_prompt_embeds_mask (`None`, *optional*):

        height (`int`, *optional*):
            The height in pixels of the generated image.

        width (`int`, *optional*):
            The width in pixels of the generated image.

        image_latents (`None`, *optional*):

        processed_mask_image (`None`, *optional*):

        latents (`Tensor`, *optional*):
            Pre-generated noisy latents for image generation.

        generator (`Generator`, *optional*):
            Torch generator for deterministic generation.

        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps.

        sigmas (`List`, *optional*):
            Custom sigmas for the denoising process.

        strength (`float`, *optional*, defaults to 0.9):
            Strength for img2img/inpainting.

        attention_kwargs (`Dict`, *optional*):
            Additional kwargs for attention processors.

        **denoiser_input_fields (`Tensor`, *optional*):
            conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.

    Outputs:

        latents (`Tensor`):
            Denoised latents.
    """

    model_name = "qwenimage-edit"
    block_classes = [
        QwenImageEditInpaintInputStep(),
@@ -239,6 +644,12 @@ class QwenImageEditInpaintCoreDenoiseStep(SequentialPipelineBlocks):
    def description(self):
        return "Core denoising workflow for QwenImage-Edit edit inpaint task."

    @property
    def outputs(self):
        return [
            OutputParam.latents(),
        ]


# Auto core denoise step for QwenImage Edit
class QwenImageEditAutoCoreDenoiseStep(ConditionalPipelineBlocks):
@@ -267,6 +678,12 @@ class QwenImageEditAutoCoreDenoiseStep(ConditionalPipelineBlocks):
            "Supports edit (img2img) and edit inpainting tasks for QwenImage-Edit."
        )

    @property
    def outputs(self):
        return [
            OutputParam.latents(),
        ]


# ====================
# 4. DECODE
@@ -274,7 +691,33 @@ class QwenImageEditAutoCoreDenoiseStep(ConditionalPipelineBlocks):


# Decode step (standard)
# auto_docstring
class QwenImageEditDecodeStep(SequentialPipelineBlocks):
    """
    class QwenImageEditDecodeStep

    Decode step that decodes the latents to images and postprocesses the generated image.

    Components:

        vae (`AutoencoderKLQwenImage`) [subfolder=]

        image_processor (`VaeImageProcessor`) [subfolder=]

    Inputs:

        latents (`Tensor`):
            The latents to decode; can be generated in the denoise step.

        output_type (`str`, *optional*, defaults to pil):
            Output format: 'pil', 'np', 'pt'.

    Outputs:

        images (`List`):
            Generated images.
    """

    model_name = "qwenimage-edit"
    block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
    block_names = ["decode", "postprocess"]
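
What decode + postprocess boils down to, as a hedged sketch (the QwenImage VAE's exact scaling attributes are assumptions here):

import torch

@torch.no_grad()
def decode_latents(vae, image_processor, latents, output_type="pil"):
    # Rescale latents into the VAE's expected range before decoding
    # (the scaling factor lives on the VAE config; the exact name is assumed).
    latents = latents / vae.config.scaling_factor
    image = vae.decode(latents).sample
    # Convert the decoded tensor to PIL/np/pt via the shared image processor.
    return image_processor.postprocess(image, output_type=output_type)
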
@@ -285,7 +728,36 @@ class QwenImageEditDecodeStep(SequentialPipelineBlocks):


# Inpaint decode step
# auto_docstring
class QwenImageEditInpaintDecodeStep(SequentialPipelineBlocks):
    """
    class QwenImageEditInpaintDecodeStep

    Decode step that decodes the latents to images and postprocesses the generated image, optionally applying the
    mask overlay to the original image.

    Components:

        vae (`AutoencoderKLQwenImage`) [subfolder=]

        image_mask_processor (`InpaintProcessor`) [subfolder=]

    Inputs:

        latents (`Tensor`):
            The latents to decode; can be generated in the denoise step.

        output_type (`str`, *optional*, defaults to pil):
            Output format: 'pil', 'np', 'pt'.

        mask_overlay_kwargs (`None`, *optional*):

    Outputs:

        images (`List`):
            Generated images.
    """

    model_name = "qwenimage-edit"
    block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()]
    block_names = ["decode", "postprocess"]
@@ -313,9 +785,7 @@ class QwenImageEditAutoDecodeStep(AutoPipelineBlocks):
    @property
    def outputs(self):
        return [
            OutputParam(
                name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step"
            ),
            OutputParam.latents(),
        ]


@@ -333,7 +803,110 @@ EDIT_AUTO_BLOCKS = InsertableDict(
)


# auto_docstring
class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
    """
    class QwenImageEditAutoBlocks

    Auto Modular pipeline for edit (img2img) and edit inpaint tasks using QwenImage-Edit.
      - for edit (img2img) generation, you need to provide `image`
      - for edit inpainting, you need to provide `mask_image` and `image`; optionally you can provide
        `padding_mask_crop`

    Components:

        image_resize_processor (`VaeImageProcessor`) [subfolder=]

        text_encoder (`Qwen2_5_VLForConditionalGeneration`) [subfolder=]

        processor (`Qwen2VLProcessor`) [subfolder=]

        guider (`ClassifierFreeGuidance`) [subfolder=]

        image_mask_processor (`InpaintProcessor`) [subfolder=]

        vae (`AutoencoderKLQwenImage`) [subfolder=]

        image_processor (`VaeImageProcessor`) [subfolder=]

        pachifier (`QwenImagePachifier`) [subfolder=]

        scheduler (`FlowMatchEulerDiscreteScheduler`) [subfolder=]

        transformer (`QwenImageTransformer2DModel`) [subfolder=]

    Configs:

        prompt_template_encode (default: <|im_start|>system
        Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how
        the user's text instruction should alter or modify the image. Generate a new image that meets the user's
        requirements while maintaining consistency with the original input where appropriate.<|im_end|> <|im_start|>user
        <|vision_start|><|image_pad|><|vision_end|>{}<|im_end|> <|im_start|>assistant )

        prompt_template_encode_start_idx (default: 64)

    Inputs:

        image (`Image`):
            Input image for img2img, editing, or conditioning.

        prompt (`str`):
            The prompt or prompts to guide image generation.

        negative_prompt (`str`, *optional*):
            The prompt or prompts not to guide the image generation.

        mask_image (`Image`, *optional*):
            Mask image for inpainting.

        padding_mask_crop (`int`, *optional*):
            Padding for mask cropping in inpainting.

        generator (`Generator`, *optional*):
            Torch generator for deterministic generation.

        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.

        height (`int`):
            The height in pixels of the generated image.

        width (`int`):
            The width in pixels of the generated image.

        image_latents (`None`):

        processed_mask_image (`None`, *optional*):

        latents (`Tensor`):
            Pre-generated noisy latents for image generation.

        num_inference_steps (`int`):
            The number of denoising steps.

        sigmas (`List`, *optional*):
            Custom sigmas for the denoising process.

        strength (`float`, *optional*, defaults to 0.9):
            Strength for img2img/inpainting.

        attention_kwargs (`Dict`, *optional*):
            Additional kwargs for attention processors.

        **denoiser_input_fields (`Tensor`, *optional*):
            conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.

        output_type (`str`, *optional*, defaults to pil):
            Output format: 'pil', 'np', 'pt'.

        mask_overlay_kwargs (`None`, *optional*):

    Outputs:

        images (`List`):
            Generated images.
    """

    model_name = "qwenimage-edit"
    block_classes = EDIT_AUTO_BLOCKS.values()
    block_names = EDIT_AUTO_BLOCKS.keys()
@@ -349,5 +922,5 @@ class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
    @property
    def outputs(self):
        return [
            OutputParam(name="images", type_hint=List[List[PIL.Image.Image]], description="The generated images"),
            OutputParam.images(),
        ]
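
For orientation, a minimal usage sketch of these auto blocks (the repo id, loading calls, and `output="images"` keyword follow the modular-pipeline docs but are assumptions here, not guaranteed API):

import torch
from diffusers.utils import load_image
from diffusers.modular_pipelines import QwenImageEditAutoBlocks  # import path assumed

blocks = QwenImageEditAutoBlocks()
pipe = blocks.init_pipeline("Qwen/Qwen-Image-Edit")  # modular repo id assumed
pipe.load_components(torch_dtype=torch.bfloat16)
pipe.to("cuda")

init_image = load_image("input.png")  # illustrative local path
# edit (img2img): pass `image`; edit inpainting: also pass `mask_image`
# (and optionally `padding_mask_crop`), as described in the docstring above.
images = pipe(image=init_image, prompt="replace the sky with aurora", output="images")
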

@@ -12,10 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import PIL.Image
import torch

from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
@@ -53,8 +49,63 @@ logger = logging.get_logger(__name__)
# ====================


# auto_docstring
class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks):
    """VL encoder that takes both image and text prompts. Uses 384x384 target area."""
    """
    class QwenImageEditPlusVLEncoderStep

    QwenImage-Edit Plus VL encoder step that encodes the image and text prompts together.

    Components:

        image_resize_processor (`VaeImageProcessor`) [subfolder=]

        text_encoder (`Qwen2_5_VLForConditionalGeneration`) [subfolder=]

        processor (`Qwen2VLProcessor`) [subfolder=]

        guider (`ClassifierFreeGuidance`) [subfolder=]

    Configs:

        prompt_template_encode (default: <|im_start|>system
        Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how
        the user's text instruction should alter or modify the image. Generate a new image that meets the user's
        requirements while maintaining consistency with the original input where appropriate.<|im_end|> <|im_start|>user
        {}<|im_end|> <|im_start|>assistant )

        img_template_encode (default: Picture {}: <|vision_start|><|image_pad|><|vision_end|>)

        prompt_template_encode_start_idx (default: 64)

    Inputs:

        image (`Image`):
            Input image for img2img, editing, or conditioning.

        prompt (`str`):
            The prompt or prompts to guide image generation.

        negative_prompt (`str`, *optional*):
            The prompt or prompts not to guide the image generation.

    Outputs:

        resized_cond_image (`List`):
            The resized images

        prompt_embeds (`Tensor`):
            The prompt embeddings

        prompt_embeds_mask (`Tensor`):
            The encoder attention mask

        negative_prompt_embeds (`Tensor`):
            The negative prompt embeddings

        negative_prompt_embeds_mask (`Tensor`):
            The negative prompt embeddings mask
    """

    model_name = "qwenimage-edit-plus"
    block_classes = [
@@ -73,8 +124,40 @@ class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks):
# ====================


# auto_docstring
class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):
    """VAE encoder that handles multiple images with different sizes. Uses 1024x1024 target area."""
    """
    class QwenImageEditPlusVaeEncoderStep

    VAE encoder step that encodes image inputs into latent representations. Each image is resized independently based
    on its own aspect ratio to 1024x1024 target area.

    Components:

        image_resize_processor (`VaeImageProcessor`) [subfolder=]

        image_processor (`VaeImageProcessor`) [subfolder=]

        vae (`AutoencoderKLQwenImage`) [subfolder=]

    Inputs:

        image (`Image`):
            Input image for img2img, editing, or conditioning.

        generator (`Generator`, *optional*):
            Torch generator for deterministic generation.

    Outputs:

        resized_image (`List`):
            The resized images

        processed_image (`None`):

        image_latents (`Tensor`):
            The latents representing the reference image(s). Single tensor or list depending on input.
    """

    model_name = "qwenimage-edit-plus"
    block_classes = [
@@ -98,7 +181,57 @@ class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):


# assemble input steps
# auto_docstring
class QwenImageEditPlusInputStep(SequentialPipelineBlocks):
    """
    class QwenImageEditPlusInputStep

    Input step that prepares the inputs for the Edit Plus denoising step. It:
      - Standardizes text embeddings batch size.
      - Processes list of image latents: patchifies, concatenates along dim=1, expands batch.
      - Outputs lists of image_height/image_width for RoPE calculation.
      - Defaults height/width from last image in the list.

    Components:

        pachifier (`QwenImagePachifier`) [subfolder=]

    Inputs:

        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.

        prompt_embeds (`None`):

        prompt_embeds_mask (`None`):

        negative_prompt_embeds (`None`, *optional*):

        negative_prompt_embeds_mask (`None`, *optional*):

        height (`int`, *optional*):
            The height in pixels of the generated image.

        width (`int`, *optional*):
            The width in pixels of the generated image.

        image_latents (`None`, *optional*):

    Outputs:

        batch_size (`int`):
            Number of prompts; the final batch size of model inputs should be batch_size * num_images_per_prompt.

        dtype (`dtype`):
            Data type of model tensor inputs (determined by `prompt_embeds`)

        image_height (`List`):
            The image heights calculated from the image latents dimension

        image_width (`List`):
            The image widths calculated from the image latents dimension
    """

    model_name = "qwenimage-edit-plus"
    block_classes = [
        QwenImageTextInputsStep(),
@@ -118,7 +251,68 @@ class QwenImageEditPlusInputStep(SequentialPipelineBlocks):


# Qwen Image Edit Plus (image2image) core denoise step
# auto_docstring
class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
    """
    class QwenImageEditPlusCoreDenoiseStep

    Core denoising workflow for QwenImage-Edit Plus edit (img2img) task.

    Components:

        pachifier (`QwenImagePachifier`) [subfolder=]

        scheduler (`FlowMatchEulerDiscreteScheduler`) [subfolder=]

        guider (`ClassifierFreeGuidance`) [subfolder=]

        transformer (`QwenImageTransformer2DModel`) [subfolder=]

    Inputs:

        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.

        prompt_embeds (`None`):

        prompt_embeds_mask (`None`):

        negative_prompt_embeds (`None`, *optional*):

        negative_prompt_embeds_mask (`None`, *optional*):

        height (`int`, *optional*):
            The height in pixels of the generated image.

        width (`int`, *optional*):
            The width in pixels of the generated image.

        image_latents (`None`, *optional*):

        latents (`Tensor`, *optional*):
            Pre-generated noisy latents for image generation.

        generator (`Generator`, *optional*):
            Torch generator for deterministic generation.

        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps.

        sigmas (`List`, *optional*):
            Custom sigmas for the denoising process.

        attention_kwargs (`Dict`, *optional*):
            Additional kwargs for attention processors.

        **denoiser_input_fields (`Tensor`, *optional*):
            conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.

    Outputs:

        latents (`Tensor`):
            Denoised latents.
    """

    model_name = "qwenimage-edit-plus"
    block_classes = [
        QwenImageEditPlusInputStep(),
@@ -144,9 +338,7 @@ class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
    @property
    def outputs(self):
        return [
            OutputParam(
                name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step"
            ),
            OutputParam.latents(),
        ]


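
The Edit Plus input step above patchifies each reference latent and concatenates the token sequences along dim=1; a rough standalone sketch (the 2x2 packing layout is an assumption, and this helper is illustrative, not the library API):

import torch

def pack_reference_latents(latents_list, batch_size):
    # Each entry: (1, C, H, W) latents for one reference image.
    packed = []
    for lat in latents_list:
        b, c, h, w = lat.shape
        # Patchify into a (B, H/2 * W/2, C * 4) token sequence.
        lat = lat.view(b, c, h // 2, 2, w // 2, 2)
        lat = lat.permute(0, 2, 4, 1, 3, 5).reshape(b, (h // 2) * (w // 2), c * 4)
        packed.append(lat)
    # Concatenate all reference tokens along dim=1, then expand the batch.
    tokens = torch.cat(packed, dim=1)
    return tokens.expand(batch_size, -1, -1)
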
@@ -155,7 +347,33 @@ class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
# ====================


# auto_docstring
class QwenImageEditPlusDecodeStep(SequentialPipelineBlocks):
    """
    class QwenImageEditPlusDecodeStep

    Decode step that decodes the latents to images and postprocesses the generated image.

    Components:

        vae (`AutoencoderKLQwenImage`) [subfolder=]

        image_processor (`VaeImageProcessor`) [subfolder=]

    Inputs:

        latents (`Tensor`):
            The latents to decode; can be generated in the denoise step.

        output_type (`str`, *optional*, defaults to pil):
            Output format: 'pil', 'np', 'pt'.

    Outputs:

        images (`List`):
            Generated images.
    """

    model_name = "qwenimage-edit-plus"
    block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
    block_names = ["decode", "postprocess"]
@@ -179,7 +397,95 @@ EDIT_PLUS_AUTO_BLOCKS = InsertableDict(
)


# auto_docstring
class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks):
    """
    class QwenImageEditPlusAutoBlocks

    Auto Modular pipeline for edit (img2img) tasks using QwenImage-Edit Plus.
      - `image` is a required input (can be a single image or a list of images).
      - Each image is resized independently based on its own aspect ratio.
      - VL encoder uses 384x384 target area, VAE encoder uses 1024x1024 target area.

    Components:

        image_resize_processor (`VaeImageProcessor`) [subfolder=]

        text_encoder (`Qwen2_5_VLForConditionalGeneration`) [subfolder=]

        processor (`Qwen2VLProcessor`) [subfolder=]

        guider (`ClassifierFreeGuidance`) [subfolder=]

        image_processor (`VaeImageProcessor`) [subfolder=]

        vae (`AutoencoderKLQwenImage`) [subfolder=]

        pachifier (`QwenImagePachifier`) [subfolder=]

        scheduler (`FlowMatchEulerDiscreteScheduler`) [subfolder=]

        transformer (`QwenImageTransformer2DModel`) [subfolder=]

    Configs:

        prompt_template_encode (default: <|im_start|>system
        Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how
        the user's text instruction should alter or modify the image. Generate a new image that meets the user's
        requirements while maintaining consistency with the original input where appropriate.<|im_end|> <|im_start|>user
        {}<|im_end|> <|im_start|>assistant )

        img_template_encode (default: Picture {}: <|vision_start|><|image_pad|><|vision_end|>)

        prompt_template_encode_start_idx (default: 64)

    Inputs:

        image (`Image`):
            Input image for img2img, editing, or conditioning.

        prompt (`str`):
            The prompt or prompts to guide image generation.

        negative_prompt (`str`, *optional*):
            The prompt or prompts not to guide the image generation.

        generator (`Generator`, *optional*):
            Torch generator for deterministic generation.

        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.

        height (`int`, *optional*):
            The height in pixels of the generated image.

        width (`int`, *optional*):
            The width in pixels of the generated image.

        latents (`Tensor`, *optional*):
            Pre-generated noisy latents for image generation.

        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps.

        sigmas (`List`, *optional*):
            Custom sigmas for the denoising process.

        attention_kwargs (`Dict`, *optional*):
            Additional kwargs for attention processors.

        **denoiser_input_fields (`Tensor`, *optional*):
            conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.

        output_type (`str`, *optional*, defaults to pil):
            Output format: 'pil', 'np', 'pt'.

    Outputs:

        images (`List`):
            Generated images.
    """

    model_name = "qwenimage-edit-plus"
    block_classes = EDIT_PLUS_AUTO_BLOCKS.values()
    block_names = EDIT_PLUS_AUTO_BLOCKS.keys()
@@ -196,5 +502,5 @@ class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks):
    @property
    def outputs(self):
        return [
            OutputParam(name="images", type_hint=List[List[PIL.Image.Image]], description="The generated images"),
            OutputParam.images(),
        ]
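
The Edit Plus variant accepts a list of reference images; a hedged sketch along the same lines (repo id and loading calls assumed, not guaranteed API):

import torch
from diffusers.utils import load_image
from diffusers.modular_pipelines import QwenImageEditPlusAutoBlocks  # import path assumed

blocks = QwenImageEditPlusAutoBlocks()
pipe = blocks.init_pipeline("Qwen/Qwen-Image-Edit-Plus")  # repo id assumed
pipe.load_components(torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Each reference image is resized on its own aspect ratio (384x384 target
# area for the VL encoder, 1024x1024 for the VAE encoder).
refs = [load_image("subject.png"), load_image("scene.png")]  # illustrative paths
images = pipe(image=refs, prompt="place the subject in the scene", output="images")
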

@@ -13,11 +13,6 @@
# limitations under the License.


from typing import List

import PIL.Image
import torch

from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict, OutputParam
@@ -55,8 +50,102 @@ logger = logging.get_logger(__name__)
# ====================


# auto_docstring
class QwenImageLayeredTextEncoderStep(SequentialPipelineBlocks):
    """Text encoder that takes text prompt, will generate a prompt based on image if not provided."""
    """
    class QwenImageLayeredTextEncoderStep

    QwenImage-Layered text encoder step that encodes the text prompt and will generate a prompt based on the image
    if one is not provided.

    Components:

        image_resize_processor (`VaeImageProcessor`) [subfolder=]

        text_encoder (`Qwen2_5_VLForConditionalGeneration`) [subfolder=]

        processor (`Qwen2VLProcessor`) [subfolder=]

        tokenizer (`Qwen2Tokenizer`): The tokenizer to use [subfolder=]

        guider (`ClassifierFreeGuidance`) [subfolder=]

    Configs:

        image_caption_prompt_en (default: <|im_start|>system
        You are a helpful assistant.<|im_end|> <|im_start|>user # Image Annotator You are a professional image annotator.
        Please write an image caption based on the input image:
        1. Write the caption using natural, descriptive language without structured formats or rich text.
        2. Enrich caption details by including:
          - Object attributes, such as quantity, color, shape, size, material, state, position, actions, and so on
          - Vision Relations between objects, such as spatial relations, functional relations, possessive relations,
            attachment relations, action relations, comparative relations, causal relations, and so on
          - Environmental details, such as weather, lighting, colors, textures, atmosphere, and so on
          - Identify the text clearly visible in the image, without translation or explanation, and highlight it in the
            caption with quotation marks
        3. Maintain authenticity and accuracy:
          - Avoid generalizations
          - Describe all visible information in the image, while do not add information not explicitly shown in the image
        <|vision_start|><|image_pad|><|vision_end|><|im_end|> <|im_start|>assistant )

        image_caption_prompt_cn (default: <|im_start|>system
        You are a helpful assistant.<|im_end|> <|im_start|>user # 图像标注器 你是一个专业的图像标注器。请基于输入图像,撰写图注:
        1. 使用自然、描述性的语言撰写图注,不要使用结构化形式或富文本形式。
        2. 通过加入以下内容,丰富图注细节:
          - 对象的属性:如数量、颜色、形状、大小、位置、材质、状态、动作等
          - 对象间的视觉关系:如空间关系、功能关系、动作关系、从属关系、比较关系、因果关系等
          - 环境细节:例如天气、光照、颜色、纹理、气氛等
          - 文字内容:识别图像中清晰可见的文字,不做翻译和解释,用引号在图注中强调
        3. 保持真实性与准确性:
          - 不要使用笼统的描述
          - 描述图像中所有可见的信息,但不要加入没有在图像中出现的内容
        <|vision_start|><|image_pad|><|vision_end|><|im_end|> <|im_start|>assistant )

        prompt_template_encode (default: <|im_start|>system
        Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the
        objects and background:<|im_end|> <|im_start|>user {}<|im_end|> <|im_start|>assistant )

        prompt_template_encode_start_idx (default: 34)

        tokenizer_max_length (default: 1024)

    Inputs:

        image (`Image`):
            Input image for img2img, editing, or conditioning.

        resolution (`int`, *optional*, defaults to 640):
            The target area to resize the image to; can be 1024 or 640.

        prompt (`str`, *optional*):
            The prompt to encode

        use_en_prompt (`bool`, *optional*, defaults to False):
            Whether to use the English prompt template

        negative_prompt (`str`, *optional*):
            The prompt or prompts not to guide the image generation.

        max_sequence_length (`int`, *optional*, defaults to 1024):
            Maximum sequence length for prompt encoding.

    Outputs:

        resized_image (`List`):
            The resized images

        prompt_embeds (`Tensor`):
            The prompt embeddings

        prompt_embeds_mask (`Tensor`):
            The encoder attention mask

        negative_prompt_embeds (`Tensor`):
            The negative prompt embeddings

        negative_prompt_embeds_mask (`Tensor`):
            The negative prompt embeddings mask
    """

    model_name = "qwenimage-layered"
    block_classes = [
@@ -77,7 +166,43 @@ class QwenImageLayeredTextEncoderStep(SequentialPipelineBlocks):


# Edit VAE encoder
# auto_docstring
class QwenImageLayeredVaeEncoderStep(SequentialPipelineBlocks):
    """
    class QwenImageLayeredVaeEncoderStep

    VAE encoder step that encodes the image inputs into their latent representations.

    Components:

        image_resize_processor (`VaeImageProcessor`) [subfolder=]

        image_processor (`VaeImageProcessor`) [subfolder=]

        vae (`AutoencoderKLQwenImage`) [subfolder=]

    Inputs:

        image (`Image`):
            Input image for img2img, editing, or conditioning.

        resolution (`int`, *optional*, defaults to 640):
            The target area to resize the image to; can be 1024 or 640.

        generator (`Generator`, *optional*):
            Torch generator for deterministic generation.

    Outputs:

        resized_image (`List`):
            The resized images

        processed_image (`None`):

        image_latents (`Tensor`):
            The latents representing the reference image(s). Single tensor or list depending on input.
    """

    model_name = "qwenimage-layered"
    block_classes = [
        QwenImageLayeredResizeStep(),
@@ -98,7 +223,55 @@ class QwenImageLayeredVaeEncoderStep(SequentialPipelineBlocks):


# assemble input steps
# auto_docstring
class QwenImageLayeredInputStep(SequentialPipelineBlocks):
    """
    class QwenImageLayeredInputStep

    Input step that prepares the inputs for the layered denoising step. It:
      - makes sure the text embeddings, as well as the additional inputs, have a consistent batch size.
      - updates height/width based on `image_latents` and patchifies `image_latents`.

    Components:

        pachifier (`QwenImageLayeredPachifier`) [subfolder=]

    Inputs:

        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.

        prompt_embeds (`None`):

        prompt_embeds_mask (`None`):

        negative_prompt_embeds (`None`, *optional*):

        negative_prompt_embeds_mask (`None`, *optional*):

        image_latents (`None`, *optional*):

    Outputs:

        batch_size (`int`):
            Number of prompts; the final batch size of model inputs should be batch_size * num_images_per_prompt.

        dtype (`dtype`):
            Data type of model tensor inputs (determined by `prompt_embeds`)

        image_height (`int`):
            The image height calculated from the image latents dimension

        image_width (`int`):
            The image width calculated from the image latents dimension

        height (`int`):
            The height of the image output

        width (`int`):
            The width of the image output
    """

    model_name = "qwenimage-layered"
    block_classes = [
        QwenImageTextInputsStep(),
@@ -116,7 +289,65 @@ class QwenImageLayeredInputStep(SequentialPipelineBlocks):


# Qwen Image Layered (image2image) core denoise step
# auto_docstring
class QwenImageLayeredCoreDenoiseStep(SequentialPipelineBlocks):
    """
    class QwenImageLayeredCoreDenoiseStep

    Core denoising workflow for QwenImage-Layered img2img task.

    Components:

        pachifier (`QwenImageLayeredPachifier`) [subfolder=]

        scheduler (`FlowMatchEulerDiscreteScheduler`) [subfolder=]

        guider (`ClassifierFreeGuidance`) [subfolder=]

        transformer (`QwenImageTransformer2DModel`) [subfolder=]

    Inputs:

        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.

        prompt_embeds (`None`):

        prompt_embeds_mask (`None`):

        negative_prompt_embeds (`None`, *optional*):

        negative_prompt_embeds_mask (`None`, *optional*):

        image_latents (`None`, *optional*):

        latents (`Tensor`, *optional*):
            Pre-generated noisy latents for image generation.

        layers (`int`, *optional*, defaults to 4):
            Number of layers to extract from the image

        generator (`Generator`, *optional*):
            Torch generator for deterministic generation.

        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps.

        sigmas (`List`, *optional*):
            Custom sigmas for the denoising process.

        attention_kwargs (`Dict`, *optional*):
            Additional kwargs for attention processors.

        **denoiser_input_fields (`Tensor`, *optional*):
            conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.

    Outputs:

        latents (`Tensor`):
            Denoised latents.
    """

    model_name = "qwenimage-layered"
    block_classes = [
        QwenImageLayeredInputStep(),
@@ -142,9 +373,7 @@ class QwenImageLayeredCoreDenoiseStep(SequentialPipelineBlocks):
    @property
    def outputs(self):
        return [
            OutputParam(
                name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step"
            ),
            OutputParam.latents(),
        ]


@@ -162,7 +391,127 @@ LAYERED_AUTO_BLOCKS = InsertableDict(
)


# auto_docstring
class QwenImageLayeredAutoBlocks(SequentialPipelineBlocks):
    """
    class QwenImageLayeredAutoBlocks

    Auto Modular pipeline for layered denoising tasks using QwenImage-Layered.

    Components:

        image_resize_processor (`VaeImageProcessor`) [subfolder=]

        text_encoder (`Qwen2_5_VLForConditionalGeneration`) [subfolder=]

        processor (`Qwen2VLProcessor`) [subfolder=]

        tokenizer (`Qwen2Tokenizer`): The tokenizer to use [subfolder=]

        guider (`ClassifierFreeGuidance`) [subfolder=]

        image_processor (`VaeImageProcessor`) [subfolder=]

        vae (`AutoencoderKLQwenImage`) [subfolder=]

        pachifier (`QwenImageLayeredPachifier`) [subfolder=]

        scheduler (`FlowMatchEulerDiscreteScheduler`) [subfolder=]

        transformer (`QwenImageTransformer2DModel`) [subfolder=]

    Configs:

        image_caption_prompt_en (default: <|im_start|>system
        You are a helpful assistant.<|im_end|> <|im_start|>user # Image Annotator You are a professional image annotator.
        Please write an image caption based on the input image:
        1. Write the caption using natural, descriptive language without structured formats or rich text.
        2. Enrich caption details by including:
          - Object attributes, such as quantity, color, shape, size, material, state, position, actions, and so on
          - Vision Relations between objects, such as spatial relations, functional relations, possessive relations,
            attachment relations, action relations, comparative relations, causal relations, and so on
          - Environmental details, such as weather, lighting, colors, textures, atmosphere, and so on
          - Identify the text clearly visible in the image, without translation or explanation, and highlight it in the
            caption with quotation marks
        3. Maintain authenticity and accuracy:
          - Avoid generalizations
          - Describe all visible information in the image, while do not add information not explicitly shown in the image
        <|vision_start|><|image_pad|><|vision_end|><|im_end|> <|im_start|>assistant )

        image_caption_prompt_cn (default: <|im_start|>system
        You are a helpful assistant.<|im_end|> <|im_start|>user # 图像标注器 你是一个专业的图像标注器。请基于输入图像,撰写图注:
        1. 使用自然、描述性的语言撰写图注,不要使用结构化形式或富文本形式。
        2. 通过加入以下内容,丰富图注细节:
          - 对象的属性:如数量、颜色、形状、大小、位置、材质、状态、动作等
          - 对象间的视觉关系:如空间关系、功能关系、动作关系、从属关系、比较关系、因果关系等
          - 环境细节:例如天气、光照、颜色、纹理、气氛等
          - 文字内容:识别图像中清晰可见的文字,不做翻译和解释,用引号在图注中强调
        3. 保持真实性与准确性:
          - 不要使用笼统的描述
          - 描述图像中所有可见的信息,但不要加入没有在图像中出现的内容
        <|vision_start|><|image_pad|><|vision_end|><|im_end|> <|im_start|>assistant )

        prompt_template_encode (default: <|im_start|>system
        Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the
        objects and background:<|im_end|> <|im_start|>user {}<|im_end|> <|im_start|>assistant )

        prompt_template_encode_start_idx (default: 34)

        tokenizer_max_length (default: 1024)

    Inputs:

        image (`Image`):
            Input image for img2img, editing, or conditioning.

        resolution (`int`, *optional*, defaults to 640):
            The target area to resize the image to; can be 1024 or 640.

        prompt (`str`, *optional*):
            The prompt to encode

        use_en_prompt (`bool`, *optional*, defaults to False):
            Whether to use the English prompt template

        negative_prompt (`str`, *optional*):
            The prompt or prompts not to guide the image generation.

        max_sequence_length (`int`, *optional*, defaults to 1024):
            Maximum sequence length for prompt encoding.

        generator (`Generator`, *optional*):
            Torch generator for deterministic generation.

        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.

        latents (`Tensor`, *optional*):
            Pre-generated noisy latents for image generation.

        layers (`int`, *optional*, defaults to 4):
            Number of layers to extract from the image

        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps.

        sigmas (`List`, *optional*):
            Custom sigmas for the denoising process.

        attention_kwargs (`Dict`, *optional*):
            Additional kwargs for attention processors.

        **denoiser_input_fields (`Tensor`, *optional*):
            conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.

        output_type (`str`, *optional*, defaults to pil):
            Output format: 'pil', 'np', 'pt'.

    Outputs:

        images (`List`):
            Generated images.
    """

    model_name = "qwenimage-layered"
    block_classes = LAYERED_AUTO_BLOCKS.values()
    block_names = LAYERED_AUTO_BLOCKS.keys()
@@ -174,5 +523,5 @@ class QwenImageLayeredAutoBlocks(SequentialPipelineBlocks):
    @property
    def outputs(self):
        return [
            OutputParam(name="images", type_hint=List[List[PIL.Image.Image]], description="The generated images"),
            OutputParam.images(),
        ]
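
A quick usage sketch of the layered auto blocks (repo id and loading calls assumed, mirroring the other modular pipelines):

import torch
from diffusers.utils import load_image
from diffusers.modular_pipelines import QwenImageLayeredAutoBlocks  # import path assumed

blocks = QwenImageLayeredAutoBlocks()
pipe = blocks.init_pipeline("Qwen/Qwen-Image-Layered")  # repo id assumed
pipe.load_components(torch_dtype=torch.bfloat16)
pipe.to("cuda")

# If no prompt is given, the text-encoder step captions the image itself
# (English or Chinese template, controlled by `use_en_prompt`).
image = load_image("poster.png")  # illustrative local path
layers = pipe(image=image, layers=4, resolution=640, output="images")
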

@@ -129,10 +129,7 @@ class ZImageLoopDenoiser(ModularPipelineBlocks):
                type_hint=int,
                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                kwargs_type="denoiser_input_fields",
                description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
            ),
            InputParam.denoiser_input_fields(),
        ]
    guider_input_names = []
    uncond_guider_input_names = []

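
The refactor above replaces verbose `InputParam(...)` declarations with classmethod shortcuts; a sketch of how such a shortcut might look (this is an assumed illustration, not the actual diffusers implementation):

from dataclasses import dataclass

@dataclass
class InputParam:
    # Hypothetical fields; the real class carries more metadata.
    name: str = None
    kwargs_type: str = None
    description: str = ""

    @classmethod
    def denoiser_input_fields(cls) -> "InputParam":
        # The canonical definition lives in one place instead of being
        # copy-pasted into every block that accepts these kwargs.
        return cls(
            kwargs_type="denoiser_input_fields",
            description="conditional model inputs for the denoiser: "
            "e.g. prompt_embeds, negative_prompt_embeds, etc.",
        )
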
@@ -14,7 +14,7 @@
# limitations under the License.
#
# Modifications by Decart AI Team:
# - Based on pipeline_wan.py, but with supports receiving a condition video appended to the channel dimension.
# - Based on pipeline_wan.py, but with supports recieving a condition video appended to the channel dimension.

import html
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

@@ -76,8 +76,6 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    text_encoder_target_modules = ["q", "k", "v", "o"]
    denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0", "linear_1"]

    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        return (1, 8, 8, 3)
@@ -116,3 +114,23 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    @unittest.skip("Not supported in AuraFlow.")
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass
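
These per-model hunks all follow the same pattern: the `supports_text_encoder_loras = False` class flag is dropped in favor of explicit `@unittest.skip` overrides. A minimal sketch of the pattern (class and model names here are hypothetical):

import unittest

class ExampleLoRATests(unittest.TestCase):
    # Before: a mixin flag such as `supports_text_encoder_loras = False`
    # silently gated shared tests. After: each unsupported test is skipped
    # explicitly, so the skip reason shows up in the test report.
    @unittest.skip("Text encoder LoRA is not supported in Example.")
    def test_simple_inference_with_text_lora(self):
        pass
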
@@ -87,8 +87,6 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):

    text_encoder_target_modules = ["q", "k", "v", "o"]

    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        return (1, 9, 16, 16, 3)
@@ -149,6 +147,26 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass

    @unittest.skip("Not supported in CogVideoX.")
    def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
        pass
@@ -85,8 +85,6 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
        "text_encoder",
    )

    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        return (1, 32, 32, 3)
@@ -164,3 +162,23 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    @unittest.skip("Not supported in CogView4.")
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass
@@ -66,8 +66,6 @@ class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    text_encoder_cls, text_encoder_id = Mistral3ForConditionalGeneration, "hf-internal-testing/tiny-mistral3-diffusers"
    denoiser_target_modules = ["to_qkv_mlp_proj", "to_k"]

    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        return (1, 8, 8, 3)
@@ -148,3 +146,23 @@ class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    @unittest.skip("Not supported in Flux2.")
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Flux2.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Flux2.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Flux2.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Flux2.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Flux2.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass
@@ -117,8 +117,6 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
        "text_encoder_2",
    )

    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        return (1, 9, 32, 32, 3)
@@ -174,6 +172,26 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass


@nightly
@require_torch_accelerator
@@ -150,8 +150,6 @@ class LTX2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):

    denoiser_target_modules = ["to_q", "to_k", "to_out.0"]

    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        return (1, 5, 32, 32, 3)
@@ -269,3 +267,27 @@ class LTX2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    @unittest.skip("Not supported in LTX2.")
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTX2.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTX2.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTX2.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTX2.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTX2.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTX2.")
    def test_simple_inference_save_pretrained_with_text_lora(self):
        pass
@@ -76,8 +76,6 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):

    text_encoder_target_modules = ["q", "k", "v", "o"]

    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        return (1, 9, 32, 32, 3)
@@ -127,3 +125,23 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    @unittest.skip("Not supported in LTXVideo.")
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass
@@ -74,8 +74,6 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/dummy-gemma"
    text_encoder_cls, text_encoder_id = GemmaForCausalLM, "hf-internal-testing/dummy-gemma-diffusers"

    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        return (1, 4, 4, 3)
@@ -115,6 +113,26 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass

    @skip_mps
    @pytest.mark.xfail(
        condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
@@ -67,8 +67,6 @@ class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):

    text_encoder_target_modules = ["q", "k", "v", "o"]

    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        return (1, 7, 16, 16, 3)
@@ -119,6 +117,26 @@ class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass

    @unittest.skip("Not supported in Mochi.")
    def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
        pass
@@ -69,8 +69,6 @@ class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    )
    denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]

    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        return (1, 8, 8, 3)
@@ -109,3 +107,23 @@ class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    @unittest.skip("Not supported in Qwen Image.")
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass
@@ -75,8 +75,6 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    tokenizer_cls, tokenizer_id = GemmaTokenizer, "hf-internal-testing/dummy-gemma"
    text_encoder_cls, text_encoder_id = Gemma2Model, "hf-internal-testing/dummy-gemma-for-diffusers"

    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        return (1, 32, 32, 3)
@@ -119,6 +117,26 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in SANA.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in SANA.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in SANA.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in SANA.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in SANA.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass

    @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
    def test_layerwise_casting_inference_denoiser(self):
        return super().test_layerwise_casting_inference_denoiser()
@@ -73,8 +73,6 @@ class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):

    text_encoder_target_modules = ["q", "k", "v", "o"]

    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        return (1, 9, 32, 32, 3)
@@ -123,3 +121,23 @@ class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    @unittest.skip("Not supported in Wan.")
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass
@@ -85,8 +85,6 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):

    text_encoder_target_modules = ["q", "k", "v", "o"]

    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        return (1, 9, 16, 16, 3)
@@ -141,6 +139,26 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass

    def test_layerwise_casting_inference_denoiser(self):
        super().test_layerwise_casting_inference_denoiser()

@@ -75,8 +75,6 @@ class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    text_encoder_cls, text_encoder_id = Qwen3Model, None  # Will be created inline
    denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]

    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        return (1, 32, 32, 3)
@@ -265,3 +263,23 @@ class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    @unittest.skip("Not supported in ZImage.")
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass

@@ -117,7 +117,6 @@ class PeftLoraLoaderMixinTests:
    tokenizer_cls, tokenizer_id, tokenizer_subfolder = None, None, ""
    tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = None, None, ""
    tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder = None, None, ""
    supports_text_encoder_loras = True

    unet_kwargs = None
    transformer_cls = None
@@ -334,9 +333,6 @@ class PeftLoraLoaderMixinTests:
        Tests a simple inference with lora attached on the text encoder
        and makes sure it works as expected
        """
        if not self.supports_text_encoder_loras:
            pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")

        components, text_lora_config, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
@@ -461,9 +457,6 @@ class PeftLoraLoaderMixinTests:
        Tests a simple inference with lora attached on the text encoder + scale argument
        and makes sure it works as expected
        """
        if not self.supports_text_encoder_loras:
            pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")

        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
        components, text_lora_config, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
@@ -501,9 +494,6 @@ class PeftLoraLoaderMixinTests:
        Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
        and makes sure it works as expected
        """
        if not self.supports_text_encoder_loras:
            pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")

        components, text_lora_config, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
@@ -565,9 +555,6 @@ class PeftLoraLoaderMixinTests:
        """
        Tests a simple usecase where users could use saving utilities for LoRA.
        """
        if not self.supports_text_encoder_loras:
            pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")

        components, text_lora_config, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
@@ -606,9 +593,6 @@ class PeftLoraLoaderMixinTests:
        with different ranks and some adapters removed
        and makes sure it works as expected
        """
        if not self.supports_text_encoder_loras:
            pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")

        components, _, _ = self.get_dummy_components()
        # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
        text_lora_config = LoraConfig(
@@ -667,9 +651,6 @@ class PeftLoraLoaderMixinTests:
        """
        Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
        """
        if not self.supports_text_encoder_loras:
            pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")

        components, text_lora_config, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)

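Taken together, the hunks above drop the short-lived supports_text_encoder_loras flag (removed from PeftLoraLoaderMixinTests and each per-model test class) in favor of explicit @unittest.skip decorators. A minimal sketch of the two patterns, using a hypothetical MyModel test class for illustration (not code from this commit):

import unittest

import pytest


class FlagBasedLoRATests(unittest.TestCase):
    # Removed pattern: a class-level capability flag that every test must check at runtime.
    supports_text_encoder_loras = False

    def test_simple_inference_with_text_lora(self):
        if not self.supports_text_encoder_loras:
            pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
        # ... the actual test body would run here ...


class DecoratorBasedLoRATests(unittest.TestCase):
    # Added pattern: the skip is declared on the test itself, so the reason is
    # visible at collection time and no per-test guard is needed.
    @unittest.skip("Text encoder LoRA is not supported in MyModel.")
    def test_simple_inference_with_text_lora(self):
        pass

The decorator form also keeps each model's skip reason next to the test it silences, which is how the diffs above read.
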
286 utils/modular_auto_docstring.py Normal file
@@ -0,0 +1,286 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Auto Docstring Generator for Modular Pipeline Blocks

This script scans Python files for classes that have a `# auto_docstring` comment above them
and inserts/updates the docstring from the class's `doc` property.

Run from the root of the repo:
    python utils/modular_auto_docstring.py [path] [--fix_and_overwrite]

Examples:
    # Check for auto_docstring markers (will error if found without proper docstring)
    python utils/modular_auto_docstring.py

    # Check specific directory
    python utils/modular_auto_docstring.py src/diffusers/modular_pipelines/

    # Fix and overwrite the docstrings
    python utils/modular_auto_docstring.py --fix_and_overwrite

Usage in code:
    # auto_docstring
    class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks):
        # docstring will be automatically inserted here

        @property
        def doc(self):
            return "Your docstring content..."
"""

import argparse
import ast
import glob
import importlib
import os
import re
import sys


# All paths are set with the intent you should run this script from the root of the repo
DIFFUSERS_PATH = "src/diffusers"
REPO_PATH = "."

# Pattern to match the auto_docstring comment
AUTO_DOCSTRING_PATTERN = re.compile(r"^\s*#\s*auto_docstring\s*$")


def setup_diffusers_import():
    """Setup import path to use the local diffusers module."""
    src_path = os.path.join(REPO_PATH, "src")
    if src_path not in sys.path:
        sys.path.insert(0, src_path)


def get_module_from_filepath(filepath: str) -> str:
    """Convert a filepath to a module name."""
    filepath = os.path.normpath(filepath)

    if filepath.startswith("src" + os.sep):
        filepath = filepath[4:]

    if filepath.endswith(".py"):
        filepath = filepath[:-3]

    module_name = filepath.replace(os.sep, ".")
    return module_name


def load_module(filepath: str):
    """Load a module from filepath."""
    setup_diffusers_import()
    module_name = get_module_from_filepath(filepath)

    try:
        module = importlib.import_module(module_name)
        return module
    except Exception as e:
        print(f"Warning: Could not import module {module_name}: {e}")
        return None


def get_doc_from_class(module, class_name: str) -> str:
    """Get the doc property from an instantiated class."""
    if module is None:
        return None

    cls = getattr(module, class_name, None)
    if cls is None:
        return None

    try:
        instance = cls()
        if hasattr(instance, "doc"):
            return instance.doc
    except Exception as e:
        print(f"Warning: Could not instantiate {class_name}: {e}")

    return None


def find_auto_docstring_classes(filepath: str) -> list:
    """
    Find all classes in a file that have a # auto_docstring comment above them.

    Returns a list of (class_name, class_line_number, has_existing_docstring, docstring_end_line).
    """
    with open(filepath, "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()

    # Parse AST to find class locations and their docstrings
    content = "".join(lines)
    try:
        tree = ast.parse(content)
    except SyntaxError as e:
        print(f"Syntax error in {filepath}: {e}")
        return []

    # Build a map of class_name -> (class_line, has_docstring, docstring_end_line)
    class_info = {}
    for node in ast.walk(tree):
        if isinstance(node, ast.ClassDef):
            has_docstring = False
            docstring_end_line = node.lineno  # default to class line

            if node.body and isinstance(node.body[0], ast.Expr):
                first_stmt = node.body[0]
                if isinstance(first_stmt.value, ast.Constant) and isinstance(first_stmt.value.value, str):
                    has_docstring = True
                    docstring_end_line = first_stmt.end_lineno or first_stmt.lineno

            class_info[node.name] = (node.lineno, has_docstring, docstring_end_line)

    # Now scan for # auto_docstring comments
    classes_to_update = []

    for i, line in enumerate(lines):
        if AUTO_DOCSTRING_PATTERN.match(line):
            # Found the marker, look for a class definition on the next non-empty, non-comment line
            j = i + 1
            while j < len(lines):
                next_line = lines[j].strip()
                if next_line and not next_line.startswith("#"):
                    break
                j += 1

            if j < len(lines) and lines[j].strip().startswith("class "):
                # Extract class name
                match = re.match(r"class\s+(\w+)", lines[j].strip())
                if match:
                    class_name = match.group(1)
                    if class_name in class_info:
                        class_line, has_docstring, docstring_end_line = class_info[class_name]
                        classes_to_update.append((class_name, class_line, has_docstring, docstring_end_line))

    return classes_to_update


def format_docstring(doc: str, indent: str = "    ") -> str:
    """Format a doc string as a properly indented docstring."""
    lines = doc.strip().split("\n")

    if len(lines) == 1:
        return f'{indent}"""{lines[0]}"""\n'
    else:
        result = [f'{indent}"""\n']
        for line in lines:
            if line.strip():
                result.append(f"{indent}{line}\n")
            else:
                result.append("\n")
        result.append(f'{indent}"""\n')
        return "".join(result)


def process_file(filepath: str, overwrite: bool = False) -> list:
    """
    Process a file and find/insert docstrings for # auto_docstring marked classes.

    Returns a list of classes that need updating.
    """
    classes_to_update = find_auto_docstring_classes(filepath)

    if not classes_to_update:
        return []

    if not overwrite:
        # Just return the list of classes that need updating
        return [(filepath, cls_name, line) for cls_name, line, _, _ in classes_to_update]

    # Load the module to get doc properties
    module = load_module(filepath)

    with open(filepath, "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()

    # Process in reverse order to maintain line numbers
    updated = False
    for class_name, class_line, has_docstring, docstring_end_line in reversed(classes_to_update):
        doc = get_doc_from_class(module, class_name)

        if doc is None:
            print(f"Warning: Could not get doc for {class_name} in {filepath}")
            continue

        # Format the new docstring with 4-space indent
        new_docstring = format_docstring(doc, "    ")

        if has_docstring:
            # Replace existing docstring (line after class definition to docstring_end_line)
            # class_line is 1-indexed, we want to replace from class_line+1 to docstring_end_line
            lines = lines[:class_line] + [new_docstring] + lines[docstring_end_line:]
        else:
            # Insert new docstring right after class definition line
            # class_line is 1-indexed, so lines[class_line-1] is the class line
            # Insert at position class_line (which is right after the class line)
            lines = lines[:class_line] + [new_docstring] + lines[class_line:]

        updated = True
        print(f"Updated docstring for {class_name} in {filepath}")

    if updated:
        with open(filepath, "w", encoding="utf-8", newline="\n") as f:
            f.writelines(lines)

    return [(filepath, cls_name, line) for cls_name, line, _, _ in classes_to_update]


def check_auto_docstrings(path: str = None, overwrite: bool = False):
    """
    Check all files for # auto_docstring markers and optionally fix them.
    """
    if path is None:
        path = DIFFUSERS_PATH

    if os.path.isfile(path):
        all_files = [path]
    else:
        all_files = glob.glob(os.path.join(path, "**/*.py"), recursive=True)

    all_markers = []

    for filepath in all_files:
        markers = process_file(filepath, overwrite)
        all_markers.extend(markers)

    if not overwrite and len(all_markers) > 0:
        message = "\n".join([f"- {f}: {cls} at line {line}" for f, cls, line in all_markers])
        raise ValueError(
            f"Found the following # auto_docstring markers that need docstrings:\n{message}\n\n"
            f"Run `python utils/modular_auto_docstring.py --fix_and_overwrite` to fix them."
        )

    if overwrite and len(all_markers) > 0:
        print(f"\nUpdated {len(all_markers)} docstring(s).")
    elif len(all_markers) == 0:
        print("No # auto_docstring markers found.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Check and fix # auto_docstring markers in modular pipeline blocks",
    )
    parser.add_argument("path", nargs="?", default=None, help="File or directory to process (default: src/diffusers)")
    parser.add_argument(
        "--fix_and_overwrite",
        action="store_true",
        help="Whether to fix the docstrings by inserting them from doc property.",
    )

    args = parser.parse_args()

    check_auto_docstrings(args.path, args.fix_and_overwrite)
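To make the new script's effect concrete, here is a minimal before/after sketch of a marked class; MyBlock, its doc text, and the file layout are hypothetical stand-ins, not taken from the repository:

# before.py -- as checked in, with the marker but no docstring yet
# auto_docstring
class MyBlock:
    @property
    def doc(self):
        return "Runs the hypothetical MyBlock step."


# after.py -- what `python utils/modular_auto_docstring.py --fix_and_overwrite`
# would produce: the class is instantiated, its `doc` property is read, and the
# (single-line) result is inserted as the class docstring directly below the
# `class` line, replacing any docstring that was already there.
# auto_docstring
class MyBlock:
    """Runs the hypothetical MyBlock step."""
    @property
    def doc(self):
        return "Runs the hypothetical MyBlock step."

Because the check mode raises a ValueError whenever markers are found, the script is suited to run in CI with the fix applied locally via --fix_and_overwrite.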