Compare commits

32 Commits

Author SHA1 Message Date
YiYi Xu
b73cc50e48 Merge branch 'main' into modular-workflow 2026-01-31 09:51:11 -10:00
YiYi Xu
769a1f3a12 [Modular]add a real quick start guide (#13029)
* add a real quick start guide

* Update docs/source/en/modular_diffusers/quickstart.md

* update a bit more

* fix

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/modular_diffusers/quickstart.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/modular_diffusers/quickstart.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* update more

* Apply suggestions from code review

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* address more feedbacks: move components amnager earlier, explain blocks vs sub-blocks etc

* more

* remove the link to mellon guide, not exist in this PR yet

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-01-31 09:43:20 -10:00
Mikko Lauri
ec6b2bcccb Fix aiter availability check (#13059)
Update import_utils.py
2026-01-30 19:24:05 +05:30
yiyixuxu
20c35da75c up up 2026-01-25 12:11:37 +01:00
yiyixuxu
6a549f5f55 initial support: workflow 2026-01-25 11:40:52 +01:00
Sayak Paul
412e51c856 include auto-docstring check in the modular ci. (#13004) 2026-01-23 22:34:24 -10:00
github-actions[bot]
23d06423ab Apply style fixes 2026-01-19 09:23:31 +00:00
YiYi Xu
aba551c868 Merge branch 'main' into modular-doc-improv 2026-01-18 23:20:36 -10:00
yiyixuxu
1f9576a2ca fix 2026-01-19 09:56:14 +01:00
yiyixuxu
d75fbc43c7 Merge branch 'modular-doc-improv' of github.com:huggingface/diffusers into modular-doc-improv 2026-01-19 09:54:46 +01:00
yiyixuxu
b7127ce7a7 revert change in z 2026-01-19 09:54:40 +01:00
YiYi Xu
7e9d2b954e Apply suggestions from code review 2026-01-18 22:44:44 -10:00
yiyixuxu
94525200fd rmove space in make docstring 2026-01-19 09:35:39 +01:00
yiyixuxu
f056af1fbb make style 2026-01-19 09:27:40 +01:00
yiyixuxu
8d45ff5bf6 apply auto docstring 2026-01-19 09:22:04 +01:00
yiyixuxu
fb15752d55 up up up 2026-01-19 08:10:31 +01:00
yiyixuxu
1f2dbc9dd2 up 2026-01-19 04:10:17 +01:00
yiyixuxu
002c3e8239 add template method 2026-01-19 03:24:34 +01:00
yiyixuxu
de03d7f100 refactor based on dhruv's feedback: remove the class method 2026-01-18 00:35:01 +01:00
yiyixuxu
25c968a38f add TODO in the description for empty docstring 2026-01-17 09:57:56 +01:00
yiyixuxu
aea0d046f6 address feedbacks 2026-01-17 09:36:58 +01:00
yiyixuxu
1c90ce33f2 up 2026-01-10 12:21:26 +01:00
yiyixuxu
507953f415 more more 2026-01-10 12:19:14 +01:00
yiyixuxu
f0555af1c6 up up up 2026-01-10 12:15:53 +01:00
yiyixuxu
2a81f2ec54 style 2026-01-10 12:15:36 +01:00
yiyixuxu
d20f413f78 more auto docstring 2026-01-10 12:11:28 +01:00
yiyixuxu
ff09bf1a63 add modular_auto_docstring! 2026-01-10 11:55:03 +01:00
yiyixuxu
34a743e2dc style 2026-01-10 10:57:27 +01:00
yiyixuxu
43ab14845d update outputs 2026-01-10 10:56:54 +01:00
YiYi Xu
fbfe5c8d6b Merge branch 'main' into modular-doc-improv 2026-01-09 23:54:23 -10:00
yiyixuxu
b29873dee7 up up 2026-01-10 10:52:53 +01:00
yiyixuxu
7b499de6d0 up 2026-01-10 03:35:15 +01:00
6 changed files with 537 additions and 489 deletions

@@ -24,7 +24,7 @@ The Modular Diffusers docs are organized as shown below.
## Quickstart
- A [quickstart](./quickstart) demonstrating how to implement an example workflow with Modular Diffusers.
- The [quickstart](./quickstart) shows you how to run a modular pipeline, understand its structure, and customize it by modifying the blocks that compose it.
## ModularPipelineBlocks

@@ -12,333 +12,248 @@ specific language governing permissions and limitations under the License.
# Quickstart
Modular Diffusers is a framework for quickly building flexible and customizable pipelines. At the core of Modular Diffusers are [`ModularPipelineBlocks`] that can be combined with other blocks to adapt to new workflows. The blocks are converted into a [`ModularPipeline`], a friendly user-facing interface developers can use.
Modular Diffusers is a framework for quickly building flexible and customizable pipelines. These pipelines can go beyond what standard `DiffusionPipeline`s can do. At the core of Modular Diffusers are [`ModularPipelineBlocks`] that can be combined with other blocks to adapt to new workflows. The blocks are converted into a [`ModularPipeline`], a friendly user-facing interface for running generation tasks.
This doc will show you how to implement a [Differential Diffusion](https://differential-diffusion.github.io/) pipeline with the modular framework.
This guide shows you how to run a modular pipeline, understand its structure, and customize it by modifying the blocks that compose it.
## ModularPipelineBlocks
[`ModularPipelineBlocks`] are *definitions* that specify the components, inputs, outputs, and computation logic for a single step in a pipeline. There are four types of blocks.
- [`ModularPipelineBlocks`] is the most basic block for a single step.
- [`SequentialPipelineBlocks`] is a multi-block that composes other blocks linearly. The outputs of one block are the inputs to the next block.
- [`LoopSequentialPipelineBlocks`] is a multi-block that runs iteratively and is designed for iterative workflows.
- [`AutoPipelineBlocks`] is a collection of blocks for different workflows and it selects which block to run based on the input. It is designed to conveniently package multiple workflows into a single pipeline.
[Differential Diffusion](https://differential-diffusion.github.io/) is an image-to-image workflow. Start with the `IMAGE2IMAGE_BLOCKS` preset, a collection of `ModularPipelineBlocks` for image-to-image generation.
```py
from diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS
IMAGE2IMAGE_BLOCKS = InsertableDict([
("text_encoder", StableDiffusionXLTextEncoderStep),
("image_encoder", StableDiffusionXLVaeEncoderStep),
("input", StableDiffusionXLInputStep),
("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
("denoise", StableDiffusionXLDenoiseStep),
("decode", StableDiffusionXLDecodeStep)
])
```
## Pipeline and block states
Modular Diffusers uses *state* to communicate data between blocks. There are two types of states.
- [`PipelineState`] is a global state that can be used to track all inputs and outputs across all blocks.
- [`BlockState`] is a local view of relevant variables from [`PipelineState`] for an individual block.
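The relationship between the two states can be pictured with a small, library-free sketch. `ToyPipelineState` and `toy_block_view` are hypothetical names invented for illustration; they are not the Diffusers classes, only a rough model of "one global store, one narrow view per block".

```python
from types import SimpleNamespace


class ToyPipelineState:
    """Hypothetical stand-in for a global state: one dict shared by all blocks."""

    def __init__(self):
        self.values = {}

    def set(self, name, value):
        self.values[name] = value

    def get(self, name, default=None):
        return self.values.get(name, default)


def toy_block_view(state, input_names):
    """Hypothetical local view exposing only the variables a block declares."""
    return SimpleNamespace(**{name: state.get(name) for name in input_names})


state = ToyPipelineState()
state.set("prompt", "a green pear")
state.set("num_inference_steps", 25)

# A block that only declares `num_inference_steps` never sees `prompt`.
view = toy_block_view(state, ["num_inference_steps"])
```

The point of the narrow view is that each block reads and writes only the names it declares, which keeps blocks independent of one another.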
## Customizing blocks
[Differential Diffusion](https://differential-diffusion.github.io/) differs from standard image-to-image in its `prepare_latents` and `denoise` blocks. All the other blocks can be reused, but you'll need to modify these two.
Create placeholder `ModularPipelineBlocks` for `prepare_latents` and `denoise` by copying and modifying the existing ones.
Print the `denoise` block to see that it is composed of [`LoopSequentialPipelineBlocks`] with three sub-blocks, `before_denoiser`, `denoiser`, and `after_denoiser`. Only the `before_denoiser` sub-block needs to be modified to prepare the latent input for the denoiser based on the change map.
```py
denoise_blocks = IMAGE2IMAGE_BLOCKS["denoise"]()
print(denoise_blocks)
```
Replace the `StableDiffusionXLLoopBeforeDenoiser` sub-block with the new `SDXLDiffDiffLoopBeforeDenoiser` block.
```py
# Copy existing blocks as placeholders
class SDXLDiffDiffPrepareLatentsStep(ModularPipelineBlocks):
    """Copied from StableDiffusionXLImg2ImgPrepareLatentsStep - will modify later"""
    # ... same implementation as StableDiffusionXLImg2ImgPrepareLatentsStep

class SDXLDiffDiffDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
    block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLLoopDenoiser, StableDiffusionXLLoopAfterDenoiser]
    block_names = ["before_denoiser", "denoiser", "after_denoiser"]
```
### prepare_latents
The `prepare_latents` block requires the following changes.
- a processor to process the change map
- new `inputs` to accept the user-provided change map (`diffdiff_map`), `timesteps` for precomputing all the latents, and `num_inference_steps` to create the masks for updating the image regions
- an updated `__call__` method that processes the change map, creates the masks, and stores them in the [`BlockState`]
```diff
class SDXLDiffDiffPrepareLatentsStep(ModularPipelineBlocks):
    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKL),
            ComponentSpec("scheduler", EulerDiscreteScheduler),
+           ComponentSpec("mask_processor", VaeImageProcessor, config=FrozenDict({"do_normalize": False, "do_convert_grayscale": True}))
        ]

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            InputParam("generator"),
+           InputParam("diffdiff_map", required=True),
-           InputParam("latent_timestep", required=True, type_hint=torch.Tensor),
+           InputParam("timesteps", type_hint=torch.Tensor),
+           InputParam("num_inference_steps", type_hint=int),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
+           OutputParam("original_latents", type_hint=torch.Tensor),
+           OutputParam("diffdiff_masks", type_hint=torch.Tensor),
        ]

    def __call__(self, components, state: PipelineState):
        # ... existing logic ...
+       # Process change map and create masks
+       diffdiff_map = components.mask_processor.preprocess(block_state.diffdiff_map, height=latent_height, width=latent_width)
+       thresholds = torch.arange(block_state.num_inference_steps, dtype=diffdiff_map.dtype) / block_state.num_inference_steps
+       block_state.diffdiff_masks = diffdiff_map > (thresholds + (block_state.denoising_start or 0))
+       block_state.original_latents = block_state.latents
```
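To make the threshold comparison above concrete, here is a minimal pure-Python sketch of the same idea, assuming the change map is a flat list of values in `[0, 1]`. The function name `toy_diffdiff_masks` is hypothetical, and the tensor shapes and broadcasting of the real block are not modeled.

```python
def toy_diffdiff_masks(change_map, num_inference_steps, denoising_start=None):
    """Return one boolean mask per step: True where the map exceeds that step's threshold."""
    offset = denoising_start or 0
    masks = []
    for step in range(num_inference_steps):
        # Same formula as arange(n) / n in the block above, plus the optional offset.
        threshold = step / num_inference_steps + offset
        masks.append([value > threshold for value in change_map])
    return masks


masks = toy_diffdiff_masks([0.0, 0.3, 0.6, 1.0], num_inference_steps=4)
```

With four steps the thresholds are 0, 0.25, 0.5, and 0.75, so regions with a higher map value stay masked for more steps.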
### denoise
The `before_denoiser` sub-block requires the following changes.
- a new `inputs` to accept a `denoising_start` parameter, `original_latents` and `diffdiff_masks` from the `prepare_latents` block
- update the computation in the `__call__` method for applying Differential Diffusion
```diff
class SDXLDiffDiffLoopBeforeDenoiser(ModularPipelineBlocks):
    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop for differential diffusion that prepares the latent input for the denoiser"
        )

    @property
    def inputs(self) -> List[str]:
        return [
            InputParam("latents", required=True, type_hint=torch.Tensor),
+           InputParam("denoising_start"),
+           InputParam("original_latents", type_hint=torch.Tensor),
+           InputParam("diffdiff_masks", type_hint=torch.Tensor),
        ]

    def __call__(self, components, block_state, i, t):
+       # Apply differential diffusion logic
+       if i == 0 and block_state.denoising_start is None:
+           block_state.latents = block_state.original_latents[:1]
+       else:
+           block_state.mask = block_state.diffdiff_masks[i].unsqueeze(0).unsqueeze(1)
+           block_state.latents = block_state.original_latents[i] * block_state.mask + block_state.latents * (1 - block_state.mask)
        # ... rest of existing logic ...
```
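The blend in the `else` branch above is a per-element mix of two latents. The sketch below shows the arithmetic on plain lists; `toy_blend_latents` is a hypothetical helper, and the real block operates on tensors with extra batch and channel dimensions.

```python
def toy_blend_latents(original, current, mask):
    """Keep `original` where mask is 1 (or True) and `current` where it is 0 (or False)."""
    return [o * m + c * (1 - m) for o, c, m in zip(original, current, mask)]


blended = toy_blend_latents(
    original=[1.0, 2.0, 3.0],
    current=[10.0, 20.0, 30.0],
    mask=[1, 0, 1],
)
```

Positions where the mask is set are reset to the precomputed noisy original at that step, while the remaining positions keep the latents produced so far by the denoising loop.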
## Assembling the blocks
You should have all the blocks you need at this point to create a [`ModularPipeline`].
Copy the existing `IMAGE2IMAGE_BLOCKS` preset and for the `set_timesteps` block, use the `set_timesteps` from the `TEXT2IMAGE_BLOCKS` because Differential Diffusion doesn't require a `strength` parameter.
Set the `prepare_latents` and `denoise` blocks to the `SDXLDiffDiffPrepareLatentsStep` and `SDXLDiffDiffDenoiseStep` blocks you just modified.
Call [`SequentialPipelineBlocks.from_blocks_dict`] on the blocks to create a `SequentialPipelineBlocks`.
```py
DIFFDIFF_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
DIFFDIFF_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]
DIFFDIFF_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
DIFFDIFF_BLOCKS["denoise"] = SDXLDiffDiffDenoiseStep
dd_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_BLOCKS)
print(dd_blocks)
```
## ModularPipeline
Convert the [`SequentialPipelineBlocks`] into a [`ModularPipeline`] with the [`ModularPipeline.init_pipeline`] method. This initializes the expected components to load from a `modular_model_index.json` file. Explicitly load the components by calling [`ModularPipeline.load_components`].
It is a good idea to initialize the [`ComponentsManager`] with the pipeline to help manage the different components. Once you call [`~ModularPipeline.load_components`], the components are registered to the [`ComponentsManager`] and can be shared between workflows. The example below uses the `collection` argument to assign the components a `"diffdiff"` label for better organization.
```py
import torch
from diffusers.modular_pipelines import ComponentsManager

components = ComponentsManager()
dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", components_manager=components, collection="diffdiff")
dd_pipeline.load_components(torch_dtype=torch.float16)
dd_pipeline.to("cuda")
```
## Adding workflows
Other workflows can be added to the [`ModularPipeline`] to support additional features without rewriting the entire pipeline from scratch.
This section demonstrates how to add an IP-Adapter or ControlNet.
### IP-Adapter
Stable Diffusion XL already has a preset IP-Adapter block that you can use without making any changes to the existing Differential Diffusion pipeline.
```py
from diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLAutoIPAdapterStep
ip_adapter_block = StableDiffusionXLAutoIPAdapterStep()
```
Use the [`sub_blocks.insert`] method to insert it into the [`ModularPipeline`]. The example below inserts the `ip_adapter_block` at position `0`. Print the pipeline to see that the `ip_adapter_block` is added and it requires an `ip_adapter_image`. This also added two components to the pipeline, the `image_encoder` and `feature_extractor`.
```py
dd_blocks.sub_blocks.insert("ip_adapter", ip_adapter_block, 0)
```
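Positional insertion into an ordered mapping of blocks can be sketched with a plain `dict`. This is only an illustration of the `insert(name, block, index)` idea under assumed semantics; `toy_insert_at` is a hypothetical helper, not the library's `InsertableDict`.

```python
def toy_insert_at(blocks, name, block, index):
    """Return a new dict with (name, block) placed at `index`, preserving the other entries' order."""
    # Drop any existing entry with the same name so the key appears only once.
    items = [(key, value) for key, value in blocks.items() if key != name]
    items.insert(index, (name, block))
    return dict(items)


toy_blocks = {"text_encoder": "TextEncoderStep", "denoise": "DenoiseStep", "decode": "DecodeStep"}
toy_blocks = toy_insert_at(toy_blocks, "ip_adapter", "IPAdapterStep", 0)
order = list(toy_blocks)
```

Because block order determines execution order in a sequential pipeline, inserting at index `0` makes the new block run before everything else.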
Call [`~ModularPipeline.init_pipeline`] to initialize a [`ModularPipeline`] and use [`~ModularPipeline.load_components`] to load the model components. Load and set the IP-Adapter to run the pipeline.
```py
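import torch
from diffusers.utils import load_image

device = "cuda"  # assumed target device; the rest of this snippet uses `device` without defining it
dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")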
dd_pipeline.load_components(torch_dtype=torch.float16)
dd_pipeline.loader.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
dd_pipeline.loader.set_ip_adapter_scale(0.6)
dd_pipeline = dd_pipeline.to(device)
ip_adapter_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_orange.jpeg")
image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
prompt = "a green pear"
negative_prompt = "blurry"
generator = torch.Generator(device=device).manual_seed(42)
image = dd_pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
num_inference_steps=25,
generator=generator,
ip_adapter_image=ip_adapter_image,
diffdiff_map=mask,
image=image,
output="images"
)[0]
```
### ControlNet
Stable Diffusion XL already has a preset ControlNet block that can readily be used.
```py
from diffusers.modular_pipelines.stable_diffusion_xl.modular_blocks import StableDiffusionXLAutoControlNetInputStep
control_input_block = StableDiffusionXLAutoControlNetInputStep()
```
However, it requires modifying the `denoise` block because that's where the ControlNet injects the control information into the UNet.
Modify the `denoise` block by replacing the `StableDiffusionXLLoopDenoiser` sub-block with the `StableDiffusionXLControlNetLoopDenoiser`.
```py
class SDXLDiffDiffControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
    block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLControlNetLoopDenoiser, StableDiffusionXLLoopAfterDenoiser]
    block_names = ["before_denoiser", "denoiser", "after_denoiser"]

controlnet_denoise_block = SDXLDiffDiffControlNetDenoiseStep()
```
Insert the `controlnet_input` block and replace the `denoise` block with the new `controlnet_denoise_block`. Initialize a [`ModularPipeline`] and [`~ModularPipeline.load_components`] into it.
```py
dd_blocks.sub_blocks.insert("controlnet_input", control_input_block, 7)
dd_blocks.sub_blocks["denoise"] = controlnet_denoise_block
dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
dd_pipeline.load_components(torch_dtype=torch.float16)
dd_pipeline = dd_pipeline.to(device)
control_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_tomato_canny.jpeg")
image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
prompt = "a green pear"
negative_prompt = "blurry"
generator = torch.Generator(device=device).manual_seed(42)
image = dd_pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
num_inference_steps=25,
generator=generator,
control_image=control_image,
controlnet_conditioning_scale=0.5,
diffdiff_map=mask,
image=image,
output="images"
)[0]
```
### AutoPipelineBlocks
The Differential Diffusion, IP-Adapter, and ControlNet workflows can be bundled into a single [`ModularPipeline`] by using [`AutoPipelineBlocks`]. This automatically selects which sub-blocks to run based on inputs like `control_image` or `ip_adapter_image`. If none of these inputs are passed, the pipeline defaults to Differential Diffusion.
Use `block_trigger_inputs` to only run the `SDXLDiffDiffControlNetDenoiseStep` block if a `control_image` input is provided. Otherwise, the `SDXLDiffDiffDenoiseStep` is used.
```py
class SDXLDiffDiffAutoDenoiseStep(AutoPipelineBlocks):
    block_classes = [SDXLDiffDiffControlNetDenoiseStep, SDXLDiffDiffDenoiseStep]
    block_names = ["controlnet_denoise", "denoise"]
    block_trigger_inputs = ["controlnet_cond", None]
```
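The trigger-based selection can be modeled in a few lines of plain Python. This is a simplified sketch of the idea, not the library's implementation: `toy_select_block` is a hypothetical function, and it assumes triggers are checked in order with `None` acting as the default.

```python
def toy_select_block(block_names, block_trigger_inputs, **inputs):
    """Return the first block whose trigger input is present; a `None` trigger acts as the default."""
    for name, trigger in zip(block_names, block_trigger_inputs):
        if trigger is None:
            return name
        if inputs.get(trigger) is not None:
            return name
    return None


names = ["controlnet_denoise", "denoise"]
triggers = ["controlnet_cond", None]

with_control = toy_select_block(names, triggers, controlnet_cond=[0.1, 0.2])
without_control = toy_select_block(names, triggers, prompt="a green pear")
```

Listing the default block last matters in this model, since the first matching entry wins.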
Add the `ip_adapter` and `controlnet_input` blocks.
```py
DIFFDIFF_AUTO_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
DIFFDIFF_AUTO_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
DIFFDIFF_AUTO_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]
DIFFDIFF_AUTO_BLOCKS["denoise"] = SDXLDiffDiffAutoDenoiseStep
DIFFDIFF_AUTO_BLOCKS.insert("ip_adapter", StableDiffusionXLAutoIPAdapterStep, 0)
DIFFDIFF_AUTO_BLOCKS.insert("controlnet_input", StableDiffusionXLAutoControlNetInputStep, 7)
```
Call [`SequentialPipelineBlocks.from_blocks_dict`] to create a [`SequentialPipelineBlocks`] and create a [`ModularPipeline`] and load in the model components to run.
```py
dd_auto_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_AUTO_BLOCKS)
dd_pipeline = dd_auto_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
dd_pipeline.load_components(torch_dtype=torch.float16)
```
## Share
Add your [`ModularPipeline`] to the Hub with [`~ModularPipeline.save_pretrained`] and set the `push_to_hub` argument to `True`.
```py
dd_pipeline.save_pretrained("YiYiXu/test_modular_doc", push_to_hub=True)
```
Other users can load the [`ModularPipeline`] with [`~ModularPipeline.from_pretrained`].
## Run a pipeline
[`ModularPipeline`] is the main interface for loading, running, and managing modular pipelines.
```py
import torch
from diffusers.modular_pipelines import ModularPipeline, ComponentsManager
from diffusers import ModularPipeline, ComponentsManager
components = ComponentsManager()
# Use ComponentsManager to enable auto CPU offloading for memory efficiency
manager = ComponentsManager()
manager.enable_auto_cpu_offload(device="cuda:0")
diffdiff_pipeline = ModularPipeline.from_pretrained("YiYiXu/modular-diffdiff-0704", trust_remote_code=True, components_manager=components, collection="diffdiff")
diffdiff_pipeline.load_components(torch_dtype=torch.float16)
pipe = ModularPipeline.from_pretrained("Qwen/Qwen-Image", components_manager=manager)
pipe.load_components(torch_dtype=torch.bfloat16)
image = pipe(
prompt="cat wizard with red hat, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney",
).images[0]
image
```
[`~ModularPipeline.from_pretrained`] uses lazy loading - it reads the configuration to learn where to load each component from, but doesn't actually load the model weights until you call [`~ModularPipeline.load_components`]. This gives you control over when and how components are loaded.
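The lazy-loading pattern described here can be sketched without any model code. `ToyLazyPipeline` and `make_loader` are hypothetical names used only to illustrate "record where components come from first, load later"; the `names` argument is likewise an assumption of this sketch.

```python
class ToyLazyPipeline:
    """Hypothetical pipeline that stores loaders up front and runs them only on request."""

    def __init__(self, specs):
        self.specs = specs  # component name -> zero-argument loader function
        self.components = {}

    def load_components(self, names=None):
        # Load everything by default, or only the requested component names.
        for name in names or self.specs:
            self.components[name] = self.specs[name]()


load_log = []


def make_loader(name):
    def loader():
        load_log.append(name)
        return f"<{name} weights>"

    return loader


toy_pipe = ToyLazyPipeline({"vae": make_loader("vae"), "transformer": make_loader("transformer")})
loaded_before = list(load_log)  # nothing has been loaded at construction time
toy_pipe.load_components(names=["vae"])
loaded_after = list(load_log)
```

Separating configuration from loading is what lets you choose the dtype, or load only a subset of components, before any weights are read.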
> [!TIP]
> [`ComponentsManager`] with `enable_auto_cpu_offload` automatically moves models between CPU and GPU as needed, reducing memory usage for large models like Qwen-Image. Learn more in the [ComponentsManager](./components_manager) guide.
Learn more about creating and loading pipelines in the [Creating a pipeline](https://huggingface.co/docs/diffusers/modular_diffusers/modular_pipeline#creating-a-pipeline) and [Loading components](https://huggingface.co/docs/diffusers/modular_diffusers/modular_pipeline#loading-components) guides.
## Understand the structure
A [`ModularPipeline`] has two parts:
- **State**: the loaded components (models, schedulers, processors) and configuration
- **Definition**: the [`ModularPipelineBlocks`] that specify inputs, outputs, expected components and computation logic
The blocks define *what* the pipeline does. Access them through `pipe.blocks`.
```py
print(pipe.blocks)
```
```
QwenImageAutoBlocks(
Class: SequentialPipelineBlocks
Description: Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.
Supported workflows:
- `text2image`: requires `prompt`
- `image2image`: requires `prompt`, `image`
- `inpainting`: requires `prompt`, `mask_image`, `image`
- `controlnet_text2image`: requires `prompt`, `control_image`
...
Components:
text_encoder (`Qwen2_5_VLForConditionalGeneration`)
vae (`AutoencoderKLQwenImage`)
transformer (`QwenImageTransformer2DModel`)
...
Sub-Blocks:
[0] text_encoder (QwenImageAutoTextEncoderStep)
[1] vae_encoder (QwenImageAutoVaeEncoderStep)
[2] controlnet_vae_encoder (QwenImageOptionalControlNetVaeEncoderStep)
[3] denoise (QwenImageAutoCoreDenoiseStep)
[4] decode (QwenImageAutoDecodeStep)
)
```
The output returns:
- The supported workflows (text2image, image2image, inpainting, etc.)
- The Sub-Blocks it's composed of (text_encoder, vae_encoder, denoise, decode)
### Workflows
`QwenImageAutoBlocks` is a [`ConditionalPipelineBlocks`], so this pipeline supports multiple workflows and adapts its behavior based on the inputs you provide. For example, if you pass `image` to the pipeline, it runs an image-to-image workflow instead of text-to-image. Let's see this in action with an example.
```py
from diffusers.utils import load_image
input_image = load_image("https://github.com/Trgtuan10/Image_storage/blob/main/cute_cat.png?raw=true")
image = pipe(
prompt="cat wizard with red hat, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney",
image=input_image,
).images[0]
```
Use `get_workflow()` to extract the blocks for a specific workflow. Pass the workflow name (e.g., `"image2image"`, `"inpainting"`, `"controlnet_text2image"`) to get only the blocks relevant to that workflow.
```py
img2img_blocks = pipe.blocks.get_workflow("image2image")
```
Conditional blocks are convenient for users, but their conditional logic adds complexity when customizing or debugging. Extracting a workflow gives you the specific blocks relevant to your workflow, making it easier to work with. Learn more in the [AutoPipelineBlocks](https://huggingface.co/docs/diffusers/modular_diffusers/auto_pipeline_blocks) guide.
### Sub-blocks
Blocks can contain other blocks. `pipe.blocks` gives you the top-level block definition (here, `QwenImageAutoBlocks`), while `sub_blocks` lets you access the smaller blocks inside it.
`QwenImageAutoBlocks` is composed of: `text_encoder`, `vae_encoder`, `controlnet_vae_encoder`, `denoise`, and `decode`. Access them through the `sub_blocks` property.
The `doc` property is useful for seeing the full documentation of any block, including its inputs, outputs, and components.
```py
vae_encoder_block = pipe.blocks.sub_blocks["vae_encoder"]
print(vae_encoder_block.doc)
```
This block can be converted to a pipeline so that it can run on its own with [`~ModularPipelineBlocks.init_pipeline`].
```py
vae_encoder_pipe = vae_encoder_block.init_pipeline()
# Reuse the VAE that is already loaded by passing it to update_components()
vae_encoder_pipe.update_components(vae=pipe.vae)
# Run just this block
image_latents = vae_encoder_pipe(image=input_image).image_latents
print(image_latents.shape)
```
It reuses the VAE from our original pipeline instead of reloading it, keeping memory usage efficient. Learn more in the [Loading components](https://huggingface.co/docs/diffusers/modular_diffusers/modular_pipeline#loading-components) guide.
Since blocks are composable, you can modify the pipeline's definition by adding, removing, or swapping blocks to create new workflows. In the next section, we'll add a canny edge detection block to a ControlNet pipeline, so you can pass a regular image instead of a pre-processed canny edge map.
## Compose new workflows
Let's add a canny edge detection block to a ControlNet pipeline. First, load a pre-built canny block from the Hub (see [Building Custom Blocks](https://huggingface.co/docs/diffusers/modular_diffusers/custom_blocks) to create your own).
```py
from diffusers.modular_pipelines import ModularPipelineBlocks
# Load a canny block from the Hub
canny_block = ModularPipelineBlocks.from_pretrained(
"diffusers-internal-dev/canny-filtering",
trust_remote_code=True,
)
print(canny_block.doc)
```
```
class CannyBlock
Inputs:
image (`Union[Image, ndarray]`):
Image to compute canny filter on
low_threshold (`int`, *optional*, defaults to 50):
Low threshold for the canny filter.
high_threshold (`int`, *optional*, defaults to 200):
High threshold for the canny filter.
...
Outputs:
control_image (`PIL.Image`):
Canny map for input image
```
Use `get_workflow` to extract the ControlNet workflow from [`QwenImageAutoBlocks`].
```py
# Get the controlnet workflow that we want to work with
blocks = pipe.blocks.get_workflow("controlnet_text2image")
print(blocks.doc)
```
```
class SequentialPipelineBlocks
Inputs:
prompt (`str`):
The prompt or prompts to guide image generation.
control_image (`Image`):
Control image for ControlNet conditioning.
...
```
The extracted workflow is a [`SequentialPipelineBlocks`](./sequential_pipeline_blocks) - a multi-block type where blocks run one after another and data flows linearly from one block to the next. Each block's `intermediate_outputs` become available as `inputs` to subsequent blocks.
Currently this workflow requires `control_image` as input. Let's insert the canny block at the beginning so the pipeline accepts a regular image instead.
```py
# Insert canny at the beginning
blocks.sub_blocks.insert("canny", canny_block, 0)
# Check the updated structure: CannyBlock is now listed as first sub-block
print(blocks)
# Check the updated doc
print(blocks.doc)
```
```
class SequentialPipelineBlocks
Inputs:
image (`Union[Image, ndarray]`):
Image to compute canny filter on
low_threshold (`int`, *optional*, defaults to 50):
Low threshold for the canny filter.
high_threshold (`int`, *optional*, defaults to 200):
High threshold for the canny filter.
prompt (`str`):
The prompt or prompts to guide image generation.
...
```
Now the pipeline takes `image` as input instead of `control_image`. Because blocks in a sequence share data automatically, the canny block's output (`control_image`) flows to the denoise block that needs it, and the canny block's input (`image`) becomes a pipeline input since no earlier block provides it.
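The rule that decides which names become pipeline inputs can be sketched as follows. This is a simplified model with hypothetical names (`toy_pipeline_inputs`, and block tuples of `(name, inputs, outputs)`); it ignores defaults, optional inputs, and types.

```python
def toy_pipeline_inputs(blocks):
    """Collect inputs that no earlier block produces; these must come from the caller."""
    available = set()
    required = []
    for _name, inputs, outputs in blocks:
        for input_name in inputs:
            if input_name not in available and input_name not in required:
                required.append(input_name)
        available.update(outputs)
    return required


controlnet_workflow = [("denoise", ["prompt", "control_image"], ["latents"])]
inputs_before = toy_pipeline_inputs(controlnet_workflow)

# Prepending a canny block that outputs `control_image` changes what the caller must provide.
with_canny = [("canny", ["image"], ["control_image"])] + controlnet_workflow
inputs_after = toy_pipeline_inputs(with_canny)
```

In this model `control_image` drops out of the required inputs once an earlier block produces it, and `image` takes its place.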
Create a pipeline from the modified blocks and load a ControlNet model.
```py
pipeline = blocks.init_pipeline("Qwen/Qwen-Image", components_manager=manager)
pipeline.load_components(torch_dtype=torch.bfloat16)
# Load the ControlNet model
controlnet_spec = pipeline.get_component_spec("controlnet")
controlnet_spec.pretrained_model_name_or_path = "InstantX/Qwen-Image-ControlNet-Union"
controlnet = controlnet_spec.load(torch_dtype=torch.bfloat16)
pipeline.update_components(controlnet=controlnet)
```
Now run the pipeline - the canny block preprocesses the image for ControlNet.
```py
from diffusers.utils import load_image
prompt = "cat wizard with red hat, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney"
image = load_image("https://github.com/Trgtuan10/Image_storage/blob/main/cute_cat.png?raw=true")
output = pipeline(
prompt=prompt,
image=image,
).images[0]
output
```
## Next steps
<hfoptions id="next">
<hfoption id="Build custom blocks">
Learn how to create your own blocks with custom logic in the [Building Custom Blocks](./custom_blocks) guide.
</hfoption>
<hfoption id="Share components">
Use [`ComponentsManager`](./components_manager) to share models across multiple pipelines and manage memory efficiently.
</hfoption>
<hfoption id="Visual interface">
Connect modular pipelines to [Mellon](https://github.com/cubiq/Mellon), a visual node-based interface for building workflows. Custom blocks built with Modular Diffusers work out of the box with Mellon - no UI code required. Read more in the Mellon guide.
</hfoption>
</hfoptions>

@@ -39,8 +39,11 @@ from .modular_pipeline_utils import (
InputParam,
InsertableDict,
OutputParam,
combine_inputs,
combine_outputs,
format_components,
format_configs,
format_workflow,
make_doc_string,
)
@@ -243,6 +246,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
    config_name = "modular_config.json"
    model_name = None
    _workflow_map = None

    @classmethod
    def _get_signature_keys(cls, obj):
@@ -298,6 +302,35 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
    def outputs(self) -> List[OutputParam]:
        return self._get_outputs()

    # currently only ConditionalPipelineBlocks and SequentialPipelineBlocks support `get_execution_blocks`
    def get_execution_blocks(self, **kwargs):
        """
        Get the block(s) that would execute given the inputs. Must be implemented by subclasses that support
        conditional block selection.

        Args:
            **kwargs: Input names and values. Only trigger inputs affect block selection.
        """
        raise NotImplementedError(f"`get_execution_blocks` is not implemented for {self.__class__.__name__}")

    # currently only SequentialPipelineBlocks supports workflows
    @property
    def workflow_names(self):
        """
        Returns a list of available workflow names. Must be implemented by subclasses that define `_workflow_map`.
        """
        raise NotImplementedError(f"`workflow_names` is not implemented for {self.__class__.__name__}")

    def get_workflow(self, workflow_name: str):
        """
        Get the execution blocks for a specific workflow. Must be implemented by subclasses that define
        `_workflow_map`.

        Args:
            workflow_name: Name of the workflow to retrieve.
        """
        raise NotImplementedError(f"`get_workflow` is not implemented for {self.__class__.__name__}")

    @classmethod
    def from_pretrained(
        cls,
@@ -435,72 +468,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
if current_value is not param: # Using identity comparison to check if object was modified
state.set(param_name, param, input_param.kwargs_type)
    @staticmethod
    def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
        """
        Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if
        current default value is None and new default value is not None. Warns if multiple non-None default values
        exist for the same input.

        Args:
            named_input_lists: List of tuples containing (block_name, input_param_list) pairs

        Returns:
            List[InputParam]: Combined list of unique InputParam objects
        """
        combined_dict = {}  # name -> InputParam
        value_sources = {}  # name -> block_name

        for block_name, inputs in named_input_lists:
            for input_param in inputs:
                if input_param.name is None and input_param.kwargs_type is not None:
                    input_name = "*_" + input_param.kwargs_type
                else:
                    input_name = input_param.name
                if input_name in combined_dict:
                    current_param = combined_dict[input_name]
                    if (
                        current_param.default is not None
                        and input_param.default is not None
                        and current_param.default != input_param.default
                    ):
                        warnings.warn(
                            f"Multiple different default values found for input '{input_name}': "
                            f"{current_param.default} (from block '{value_sources[input_name]}') and "
                            f"{input_param.default} (from block '{block_name}'). Using {current_param.default}."
                        )
                    if current_param.default is None and input_param.default is not None:
                        combined_dict[input_name] = input_param
                        value_sources[input_name] = block_name
                else:
                    combined_dict[input_name] = input_param
                    value_sources[input_name] = block_name

        return list(combined_dict.values())

    @staticmethod
    def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]:
        """
        Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, keeps the first
        occurrence of each output name.

        Args:
            named_output_lists: List of tuples containing (block_name, output_param_list) pairs

        Returns:
            List[OutputParam]: Combined list of unique OutputParam objects
        """
        combined_dict = {}  # name -> OutputParam

        for block_name, outputs in named_output_lists:
            for output_param in outputs:
                if (output_param.name not in combined_dict) or (
                    combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None
                ):
                    combined_dict[output_param.name] = output_param

        return list(combined_dict.values())

    @property
    def input_names(self) -> List[str]:
        return [input_param.name for input_param in self.inputs if input_param.name is not None]
@@ -532,7 +499,8 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
class ConditionalPipelineBlocks(ModularPipelineBlocks):
"""
A Pipeline Blocks that conditionally selects a block to run based on the inputs. Subclasses must implement the
`select_block` method to define the logic for selecting the block.
`select_block` method to define the logic for selecting the block. Currently, we only support selection logic based
on the presence or absence of inputs (i.e., whether they are `None` or not).
This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
library implements for all the pipeline blocks (such as loading or saving etc.)
@@ -540,15 +508,20 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
> [!WARNING] > This is an experimental feature and is likely to change in the future.
Attributes:
block_classes: List of block classes to be used
block_names: List of prefixes for each block
block_trigger_inputs: List of input names that select_block() uses to determine which block to run
block_classes: List of block classes to be used. Must have the same length as `block_names`.
block_names: List of names for each block. Must have the same length as `block_classes`.
block_trigger_inputs: List of input names that `select_block()` uses to determine which block to run.
For `ConditionalPipelineBlocks`, this does not need to correspond to `block_names` and `block_classes`. For
`AutoPipelineBlocks`, this must have the same length as `block_names` and `block_classes`, where each
element specifies the trigger input for the corresponding block.
default_block_name: Name of the default block to run when no trigger inputs match.
If None, this block can be skipped entirely when no trigger inputs are provided.
"""
block_classes = []
block_names = []
block_trigger_inputs = []
default_block_name = None # name of the default block if no trigger inputs are provided, if None, this block can be skipped if no trigger inputs are provided
default_block_name = None
def __init__(self):
sub_blocks = InsertableDict()
@@ -612,7 +585,7 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
@property
def inputs(self) -> List[Tuple[str, Any]]:
named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()]
combined_inputs = self.combine_inputs(*named_inputs)
combined_inputs = combine_inputs(*named_inputs)
# mark Required inputs only if that input is required by all the blocks
for input_param in combined_inputs:
if input_param.name in self.required_inputs:
@@ -624,15 +597,16 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
@property
def intermediate_outputs(self) -> List[str]:
named_outputs = [(name, block.intermediate_outputs) for name, block in self.sub_blocks.items()]
combined_outputs = self.combine_outputs(*named_outputs)
combined_outputs = combine_outputs(*named_outputs)
return combined_outputs
@property
def outputs(self) -> List[str]:
named_outputs = [(name, block.outputs) for name, block in self.sub_blocks.items()]
combined_outputs = self.combine_outputs(*named_outputs)
combined_outputs = combine_outputs(*named_outputs)
return combined_outputs
# used for `__repr__`
def _get_trigger_inputs(self) -> set:
"""
Returns a set of all unique trigger input values found in this block and nested blocks.
@@ -661,11 +635,6 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
return all_triggers
@property
def trigger_inputs(self):
"""All trigger inputs including from nested blocks."""
return self._get_trigger_inputs()
def select_block(self, **kwargs) -> Optional[str]:
"""
Select the block to run based on the trigger inputs. Subclasses must implement this method to define the logic
@@ -705,6 +674,39 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
logger.error(error_msg)
raise
def get_execution_blocks(self, **kwargs) -> Optional["ModularPipelineBlocks"]:
"""
Get the block(s) that would execute given the inputs.
Recursively resolves nested ConditionalPipelineBlocks until reaching either:
- A leaf block (no sub_blocks) → returns single `ModularPipelineBlocks`
- A `SequentialPipelineBlocks` → delegates to its `get_execution_blocks()` which returns
a `SequentialPipelineBlocks` containing the resolved execution blocks
Args:
**kwargs: Input names and values. Only trigger inputs affect block selection.
Returns:
- `ModularPipelineBlocks`: A leaf block or resolved `SequentialPipelineBlocks`
- `None`: If this block would be skipped (no trigger matched and no default)
"""
trigger_kwargs = {name: kwargs.get(name) for name in self.block_trigger_inputs if name is not None}
block_name = self.select_block(**trigger_kwargs)
if block_name is None:
block_name = self.default_block_name
if block_name is None:
return None
block = self.sub_blocks[block_name]
# Recursively resolve until we hit a leaf block or a SequentialPipelineBlocks
if block.sub_blocks:
return block.get_execution_blocks(**kwargs)
return block
def __repr__(self):
class_name = self.__class__.__name__
base_class = self.__class__.__bases__[0].__name__
@@ -712,11 +714,11 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
f"{class_name}(\n Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n"
)
if self.trigger_inputs:
if self._get_trigger_inputs():
header += "\n"
header += " " + "=" * 100 + "\n"
header += " This pipeline contains blocks that are selected at runtime based on inputs.\n"
header += f" Trigger Inputs: {sorted(self.trigger_inputs)}\n"
header += f" Trigger Inputs: {sorted(self._get_trigger_inputs())}\n"
header += " " + "=" * 100 + "\n\n"
# Format description with proper indentation
@@ -783,24 +785,56 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
class AutoPipelineBlocks(ConditionalPipelineBlocks):
"""
A Pipeline Blocks that automatically selects a block to run based on the presence of trigger inputs.
A Pipeline Blocks that automatically selects a block to run based on the presence of trigger inputs.
This is a specialized version of `ConditionalPipelineBlocks` where:
- Each block has one corresponding trigger input (1:1 mapping)
- Block selection is automatic: the first block whose trigger input is present gets selected
- `block_trigger_inputs` must have the same length as `block_names` and `block_classes`
- Use `None` in `block_trigger_inputs` to specify the default block, i.e., the block that will run if no trigger
inputs are present
Attributes:
block_classes:
List of block classes to be used. Must have the same length as `block_names` and
`block_trigger_inputs`.
block_names:
List of names for each block. Must have the same length as `block_classes` and `block_trigger_inputs`.
block_trigger_inputs:
List of input names where each element specifies the trigger input for the corresponding block. Use
`None` to mark the default block.
Example:
```python
class MyAutoBlock(AutoPipelineBlocks):
block_classes = [InpaintEncoderBlock, ImageEncoderBlock, TextEncoderBlock]
block_names = ["inpaint", "img2img", "text2img"]
block_trigger_inputs = ["mask_image", "image", None] # text2img is the default
```
With this definition:
- As long as `mask_image` is provided, "inpaint" block runs (regardless of `image` being provided or not)
- If `mask_image` is not provided but `image` is provided, "img2img" block runs
- Otherwise, "text2img" block runs (default, trigger is `None`)
"""
def __init__(self):
super().__init__()
if self.default_block_name is not None:
raise ValueError(
f"In {self.__class__.__name__}, do not set `default_block_name` for AutoPipelineBlocks. "
f"Use `None` in `block_trigger_inputs` to specify the default block."
)
if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
raise ValueError(
f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same."
)
@property
def default_block_name(self) -> Optional[str]:
"""Derive default_block_name from block_trigger_inputs (None entry)."""
if None in self.block_trigger_inputs:
idx = self.block_trigger_inputs.index(None)
return self.block_names[idx]
return None
self.default_block_name = self.block_names[idx]
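The selection rule described in the `AutoPipelineBlocks` docstring can be sketched in isolation. The helper below, `select_block_name`, is a hypothetical stand-in and not the library's `select_block`. It assumes only what the docstring states: triggers are checked in order, the first trigger whose input is not `None` wins, and a `None` trigger marks the default block.

```python
from typing import List, Optional


def select_block_name(
    block_names: List[str],
    block_trigger_inputs: List[Optional[str]],
    **kwargs,
) -> Optional[str]:
    """Hypothetical stand-in for the selection rule in the docstring above."""
    if len(block_names) != len(block_trigger_inputs):
        raise ValueError("block_names and block_trigger_inputs must have the same length")

    default_name = None
    for name, trigger in zip(block_names, block_trigger_inputs):
        if trigger is None:
            # A `None` trigger marks the default block; remember it and keep scanning.
            default_name = name
        elif kwargs.get(trigger) is not None:
            # The first trigger whose input is present wins.
            return name
    return default_name


names = ["inpaint", "img2img", "text2img"]
triggers = ["mask_image", "image", None]

print(select_block_name(names, triggers, mask_image="m", image="i"))  # inpaint
print(select_block_name(names, triggers, image="i"))  # img2img
print(select_block_name(names, triggers, prompt="a cat"))  # text2img
```

Because the order of `block_trigger_inputs` decides ties, putting `mask_image` before `image` is what makes inpainting take priority when both are passed.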
def select_block(self, **kwargs) -> Optional[str]:
"""Select block based on which trigger input is present (not None)."""
@@ -854,6 +888,29 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
expected_configs.append(config)
return expected_configs
@property
def workflow_names(self):
if self._workflow_map is None:
raise NotImplementedError(
f"workflows are not supported because _workflow_map is not set for {self.__class__.__name__}"
)
return list(self._workflow_map.keys())
def get_workflow(self, workflow_name: str):
if self._workflow_map is None:
raise NotImplementedError(
f"workflows are not supported because _workflow_map is not set for {self.__class__.__name__}"
)
if workflow_name not in self._workflow_map:
raise ValueError(f"Workflow {workflow_name} not found in {self.__class__.__name__}")
trigger_inputs = self._workflow_map[workflow_name]
workflow_blocks = self.get_execution_blocks(**trigger_inputs)
return workflow_blocks
@classmethod
def from_blocks_dict(
cls, blocks_dict: Dict[str, Any], description: Optional[str] = None
@@ -949,7 +1006,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
# filter out them here so they do not end up as intermediate_outputs
if name not in inp_names:
named_outputs.append((name, block.intermediate_outputs))
combined_outputs = self.combine_outputs(*named_outputs)
combined_outputs = combine_outputs(*named_outputs)
return combined_outputs
# YiYi TODO: I think we can remove the outputs property
@@ -973,6 +1030,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
raise
return pipeline, state
# used for `trigger_inputs` property
def _get_trigger_inputs(self):
"""
Returns a set of all unique trigger input values found in the blocks.
@@ -996,89 +1054,50 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
return fn_recursive_get_trigger(self.sub_blocks)
@property
def trigger_inputs(self):
return self._get_trigger_inputs()
def _traverse_trigger_blocks(self, active_inputs):
def get_execution_blocks(self, **kwargs) -> "SequentialPipelineBlocks":
"""
Traverse blocks and select which ones would run given the active inputs.
Get the blocks that would execute given the specified inputs.
Args:
active_inputs: Dict of input names to values that are "present"
**kwargs: Input names and values. Only trigger inputs affect block selection.
Returns:
OrderedDict of block_name -> block that would execute
SequentialPipelineBlocks containing only the blocks that would execute
"""
# Copy kwargs so we can add outputs as we traverse
active_inputs = dict(kwargs)
def fn_recursive_traverse(block, block_name, active_inputs):
result_blocks = OrderedDict()
# ConditionalPipelineBlocks (includes AutoPipelineBlocks)
if isinstance(block, ConditionalPipelineBlocks):
trigger_kwargs = {name: active_inputs.get(name) for name in block.block_trigger_inputs}
selected_block_name = block.select_block(**trigger_kwargs)
if selected_block_name is None:
selected_block_name = block.default_block_name
if selected_block_name is None:
block = block.get_execution_blocks(**active_inputs)
if block is None:
return result_blocks
selected_block = block.sub_blocks[selected_block_name]
if selected_block.sub_blocks:
result_blocks.update(fn_recursive_traverse(selected_block, block_name, active_inputs))
else:
result_blocks[block_name] = selected_block
if hasattr(selected_block, "outputs"):
for out in selected_block.outputs:
active_inputs[out.name] = True
return result_blocks
# SequentialPipelineBlocks or LoopSequentialPipelineBlocks
if block.sub_blocks:
# Has sub_blocks (SequentialPipelineBlocks/ConditionalPipelineBlocks)
if block.sub_blocks and not isinstance(block, LoopSequentialPipelineBlocks):
for sub_block_name, sub_block in block.sub_blocks.items():
blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_inputs)
blocks_to_update = {f"{block_name}.{k}": v for k, v in blocks_to_update.items()}
result_blocks.update(blocks_to_update)
nested_blocks = fn_recursive_traverse(sub_block, sub_block_name, active_inputs)
nested_blocks = {f"{block_name}.{k}": v for k, v in nested_blocks.items()}
result_blocks.update(nested_blocks)
else:
# Leaf block: single ModularPipelineBlocks or LoopSequentialPipelineBlocks
result_blocks[block_name] = block
if hasattr(block, "outputs"):
for out in block.outputs:
# Add outputs to active_inputs so subsequent blocks can use them as triggers
if hasattr(block, "intermediate_outputs"):
for out in block.intermediate_outputs:
active_inputs[out.name] = True
return result_blocks
all_blocks = OrderedDict()
for block_name, block in self.sub_blocks.items():
blocks_to_update = fn_recursive_traverse(block, block_name, active_inputs)
all_blocks.update(blocks_to_update)
return all_blocks
nested_blocks = fn_recursive_traverse(block, block_name, active_inputs)
all_blocks.update(nested_blocks)
def get_execution_blocks(self, **kwargs):
"""
Get the blocks that would execute given the specified inputs.
Args:
**kwargs: Input names and values. Only trigger inputs affect block selection.
Pass any inputs that would be non-None at runtime.
Returns:
SequentialPipelineBlocks containing only the blocks that would execute
Example:
# Get blocks for inpainting workflow
blocks = pipeline.get_execution_blocks(prompt="a cat", mask=mask, image=image)
# Get blocks for text2image workflow
blocks = pipeline.get_execution_blocks(prompt="a cat")
"""
# Filter out None values
active_inputs = {k: v for k, v in kwargs.items() if v is not None}
blocks_triggered = self._traverse_trigger_blocks(active_inputs)
return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered)
return SequentialPipelineBlocks.from_blocks_dict(all_blocks)
def __repr__(self):
class_name = self.__class__.__name__
@@ -1087,18 +1106,23 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
f"{class_name}(\n Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n"
)
if self.trigger_inputs:
if self._workflow_map is None and self._get_trigger_inputs():
header += "\n"
header += " " + "=" * 100 + "\n"
header += " This pipeline contains blocks that are selected at runtime based on inputs.\n"
header += f" Trigger Inputs: {[inp for inp in self.trigger_inputs if inp is not None]}\n"
header += f" Trigger Inputs: {[inp for inp in self._get_trigger_inputs() if inp is not None]}\n"
# Get first trigger input as example
example_input = next(t for t in self.trigger_inputs if t is not None)
example_input = next(t for t in self._get_trigger_inputs() if t is not None)
header += f" Use `get_execution_blocks()` to see selected blocks (e.g. `get_execution_blocks({example_input}=...)`).\n"
header += " " + "=" * 100 + "\n\n"
description = self.description
if self._workflow_map is not None:
workflow_str = format_workflow(self._workflow_map)
description = f"{self.description}\n\n{workflow_str}"
# Format description with proper indentation
desc_lines = self.description.split("\n")
desc_lines = description.split("\n")
desc = []
# First line with "Description:" label
desc.append(f" Description: {desc_lines[0]}")
@@ -1146,10 +1170,15 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
@property
def doc(self):
description = self.description
if self._workflow_map is not None:
workflow_str = format_workflow(self._workflow_map)
description = f"{self.description}\n\n{workflow_str}"
return make_doc_string(
self.inputs,
self.outputs,
self.description,
description=description,
class_name=self.__class__.__name__,
expected_components=self.expected_components,
expected_configs=self.expected_configs,
@@ -1282,7 +1311,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
@property
def intermediate_outputs(self) -> List[str]:
named_outputs = [(name, block.intermediate_outputs) for name, block in self.sub_blocks.items()]
combined_outputs = self.combine_outputs(*named_outputs)
combined_outputs = combine_outputs(*named_outputs)
for output in self.loop_intermediate_outputs:
if output.name not in {output.name for output in combined_outputs}:
combined_outputs.append(output)
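The traversal in the new `SequentialPipelineBlocks.get_execution_blocks` relies on one idea: once a block is selected, the names of its intermediate outputs are marked as present, so later conditional blocks can trigger on them. The toy model below illustrates that idea only. `Leaf`, `Auto`, `execution_blocks`, and the stage names are all hypothetical and do not mirror the library's classes.

```python
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class Leaf:
    """Toy leaf block: only records the names of the outputs it produces."""
    outputs: List[str] = field(default_factory=list)


@dataclass
class Auto:
    """Toy conditional block: ordered (trigger, leaf) pairs; a None trigger is the default."""
    choices: List[tuple]

    def resolve(self, active: Dict[str, object]) -> Optional[Leaf]:
        default = None
        for trigger, leaf in self.choices:
            if trigger is None:
                default = leaf
            elif active.get(trigger) is not None:
                return leaf
        return default


def execution_blocks(stages, **kwargs):
    active = dict(kwargs)  # copy so the caller's dict is not mutated
    selected = OrderedDict()
    for name, stage in stages.items():
        leaf = stage.resolve(active) if isinstance(stage, Auto) else stage
        if leaf is None:
            continue  # no trigger matched and no default: the stage is skipped
        selected[name] = leaf
        for out in leaf.outputs:
            active[out] = True  # later stages can trigger on this output
    return selected


encode = Leaf(outputs=["image_latents"])
img2img = Leaf(outputs=["latents"])
text2img = Leaf(outputs=["latents"])
stages = OrderedDict(
    vae_encoder=Auto([("image", encode)]),  # no default: skipped without `image`
    denoise=Auto([("image_latents", img2img), (None, text2img)]),
)

print(list(execution_blocks(stages, prompt="a cat", image="x")))  # ['vae_encoder', 'denoise']
print(list(execution_blocks(stages, prompt="a cat")))  # ['denoise']
```

In the first call the `denoise` stage resolves to `img2img` even though the caller never passed `image_latents`; the name became active because the selected encoder produces it.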

@@ -14,9 +14,10 @@
import inspect
import re
import warnings
from collections import OrderedDict
from dataclasses import dataclass, field, fields
from typing import Any, Dict, List, Literal, Optional, Type, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
import PIL.Image
import torch
@@ -860,6 +861,30 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines
return "\n".join(formatted_configs)
def format_workflow(workflow_map):
"""Format a workflow map into a readable string representation.
Args:
workflow_map: Dictionary mapping workflow names to trigger inputs
Returns:
A formatted string representing all workflows
"""
if workflow_map is None:
return ""
lines = ["Supported workflows:"]
for workflow_name, trigger_inputs in workflow_map.items():
required_inputs = [k for k, v in trigger_inputs.items() if v]
if required_inputs:
inputs_str = ", ".join(f"`{t}`" for t in required_inputs)
lines.append(f" - `{workflow_name}`: requires {inputs_str}")
else:
lines.append(f" - `{workflow_name}`: default (no additional inputs required)")
return "\n".join(lines)
def make_doc_string(
inputs,
outputs,
@@ -914,3 +939,69 @@ def make_doc_string(
output += format_output_params(outputs, indent_level=2)
return output
def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
"""
Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if current
default value is None and new default value is not None. Warns if multiple non-None default values exist for the
same input.
Args:
named_input_lists: List of tuples containing (block_name, input_param_list) pairs
Returns:
List[InputParam]: Combined list of unique InputParam objects
"""
combined_dict = {} # name -> InputParam
value_sources = {} # name -> block_name
for block_name, inputs in named_input_lists:
for input_param in inputs:
if input_param.name is None and input_param.kwargs_type is not None:
input_name = "*_" + input_param.kwargs_type
else:
input_name = input_param.name
if input_name in combined_dict:
current_param = combined_dict[input_name]
if (
current_param.default is not None
and input_param.default is not None
and current_param.default != input_param.default
):
warnings.warn(
f"Multiple different default values found for input '{input_name}': "
f"{current_param.default} (from block '{value_sources[input_name]}') and "
f"{input_param.default} (from block '{block_name}'). Using {current_param.default}."
)
if current_param.default is None and input_param.default is not None:
combined_dict[input_name] = input_param
value_sources[input_name] = block_name
else:
combined_dict[input_name] = input_param
value_sources[input_name] = block_name
return list(combined_dict.values())
def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]:
"""
Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, keeps the first
occurrence of each output name.
Args:
named_output_lists: List of tuples containing (block_name, output_param_list) pairs
Returns:
List[OutputParam]: Combined list of unique OutputParam objects
"""
combined_dict = {} # name -> OutputParam
for block_name, outputs in named_output_lists:
for output_param in outputs:
if (output_param.name not in combined_dict) or (
combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None
):
combined_dict[output_param.name] = output_param
return list(combined_dict.values())
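The default-value precedence in `combine_inputs` is easy to misread, so here is a reduced sketch of it. `Param` and `merge_params` are hypothetical stand-ins: `Param` carries only `name` and `default`, and the `kwargs_type` naming branch of the real function is left out.

```python
import warnings
from dataclasses import dataclass
from typing import Any, List, Tuple


@dataclass
class Param:
    """Hypothetical stand-in for `InputParam` with only the fields used here."""
    name: str
    default: Any = None


def merge_params(*named_lists: Tuple[str, List[Param]]) -> List[Param]:
    merged = {}   # name -> Param
    sources = {}  # name -> block that supplied the kept Param
    for block_name, params in named_lists:
        for p in params:
            if p.name not in merged:
                merged[p.name] = p
                sources[p.name] = block_name
                continue
            current = merged[p.name]
            if current.default is not None and p.default is not None and current.default != p.default:
                # Two different non-None defaults: warn and keep the earlier one.
                warnings.warn(
                    f"Conflicting defaults for '{p.name}': keeping {current.default} "
                    f"from block '{sources[p.name]}'"
                )
            if current.default is None and p.default is not None:
                # A later block may fill in a default that was missing.
                merged[p.name] = p
                sources[p.name] = block_name
    return list(merged.values())


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = merge_params(
        ("encode", [Param("height"), Param("steps", 50)]),
        ("denoise", [Param("height", 1024), Param("steps", 30)]),
    )

print({p.name: p.default for p in result})  # {'height': 1024, 'steps': 50}
print(len(caught))  # 1
```

The first non-None default always wins; a later block can only supply a default where none existed.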

@@ -1113,10 +1113,14 @@ AUTO_BLOCKS = InsertableDict(
class QwenImageAutoBlocks(SequentialPipelineBlocks):
"""
Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.
- for image-to-image generation, you need to provide `image`
- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`.
- to run the controlnet workflow, you need to provide `control_image`
- for text-to-image generation, all you need to provide is `prompt`
Supported workflows:
- `text2image`: requires `prompt`
- `image2image`: requires `prompt`, `image`
- `inpainting`: requires `prompt`, `mask_image`, `image`
- `controlnet_text2image`: requires `prompt`, `control_image`
- `controlnet_image2image`: requires `prompt`, `image`, `control_image`
- `controlnet_inpainting`: requires `prompt`, `mask_image`, `image`, `control_image`
Components:
text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use tokenizer (`Qwen2Tokenizer`):
@@ -1197,15 +1201,24 @@ class QwenImageAutoBlocks(SequentialPipelineBlocks):
block_classes = AUTO_BLOCKS.values()
block_names = AUTO_BLOCKS.keys()
# Workflow map defines the trigger conditions for each workflow.
# How to define:
# - Only include required inputs and trigger inputs (inputs that determine which blocks run)
# - `True` means the workflow triggers when the input is not None (most common case)
# - Use specific values (e.g., `{"strength": 0.5}`) if your `select_block` logic depends on the value
_workflow_map = {
"text2image": {"prompt": True},
"image2image": {"prompt": True, "image": True},
"inpainting": {"prompt": True, "mask_image": True, "image": True},
"controlnet_text2image": {"prompt": True, "control_image": True},
"controlnet_image2image": {"prompt": True, "image": True, "control_image": True},
"controlnet_inpainting": {"prompt": True, "mask_image": True, "image": True, "control_image": True},
}
@property
def description(self):
return (
"Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.\n"
+ "- for image-to-image generation, you need to provide `image`\n"
+ "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`.\n"
+ "- to run the controlnet workflow, you need to provide `control_image`\n"
+ "- for text-to-image generation, all you need to provide is `prompt`"
)
return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage."
@property
def outputs(self):
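To see how a `_workflow_map` turns into the "Supported workflows" text in the class docstring, the helper from this diff can be run on its own. The function body below repeats the logic of `format_workflow` as shown above; the two-space list indentation is an assumption, and the `unconditional` entry is hypothetical, added only to exercise the "default" branch.

```python
def format_workflow(workflow_map):
    """Same logic as the helper added in this diff; list indentation is an assumption."""
    if workflow_map is None:
        return ""
    lines = ["Supported workflows:"]
    for workflow_name, trigger_inputs in workflow_map.items():
        # Only inputs mapped to a truthy value are listed as required.
        required_inputs = [k for k, v in trigger_inputs.items() if v]
        if required_inputs:
            inputs_str = ", ".join(f"`{t}`" for t in required_inputs)
            lines.append(f"  - `{workflow_name}`: requires {inputs_str}")
        else:
            lines.append(f"  - `{workflow_name}`: default (no additional inputs required)")
    return "\n".join(lines)


workflow_map = {
    "text2image": {"prompt": True},
    "inpainting": {"prompt": True, "mask_image": True, "image": True},
    "unconditional": {},  # hypothetical entry, not part of QwenImageAutoBlocks
}
text = format_workflow(workflow_map)
print(text)
# Supported workflows:
#   - `text2image`: requires `prompt`
#   - `inpainting`: requires `prompt`, `mask_image`, `image`
#   - `unconditional`: default (no additional inputs required)
```

Since every entry in the QwenImage map includes `prompt`, none of its workflows is rendered through the "default" branch.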

@@ -227,7 +227,7 @@ _cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("
_sageattention_available, _sageattention_version = _is_package_available("sageattention")
_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
_aiter_available, _aiter_version = _is_package_available("aiter")
_aiter_available, _aiter_version = _is_package_available("aiter", get_dist_name=True)
_kornia_available, _kornia_version = _is_package_available("kornia")
_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
_av_available, _av_version = _is_package_available("av")
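The `get_dist_name=True` change matters when a package's import name differs from the name of the distribution that installs it, because version metadata is keyed by distribution name. The sketch below shows that general pattern with the standard library only. `package_available` is a hypothetical function, not the body of `_is_package_available`, and it assumes Python 3.10+ for `importlib.metadata.packages_distributions`.

```python
import importlib.metadata
import importlib.util
from typing import Tuple


def package_available(import_name: str, use_dist_name: bool = False) -> Tuple[bool, str]:
    """Hypothetical availability check; returns (available, version)."""
    if importlib.util.find_spec(import_name) is None:
        return False, "N/A"

    lookup_name = import_name
    if use_dist_name:
        # Map the top-level import name to the distribution(s) that provide it.
        dists = importlib.metadata.packages_distributions().get(import_name, [])
        if dists:
            lookup_name = dists[0]

    try:
        return True, importlib.metadata.version(lookup_name)
    except importlib.metadata.PackageNotFoundError:
        # Importable, but no metadata under this name: treat as unavailable.
        return False, "N/A"


print(package_available("definitely_not_installed_pkg_xyz"))  # (False, 'N/A')
```

Without the distribution-name lookup, a package that imports fine can still be reported as unavailable when its metadata is registered under a different name.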