Compare commits

...

22 Commits

Author SHA1 Message Date
yiyixuxu
741685def1 add a warn for mellon and add new guides to overview 2026-01-31 10:09:31 +01:00
yiyixuxu
3cd9ff4296 style 2026-01-31 09:51:35 +01:00
yiyixuxu
9d7f6db9ec update 2026-01-31 09:51:13 +01:00
yiyixuxu
bf6a07e665 update custom block guide 2026-01-31 09:36:26 +01:00
yiyixuxu
e3a4cc5730 fix components manager 2026-01-31 03:53:29 +01:00
yiyixuxu
391c410368 up up 2026-01-31 02:19:27 +01:00
yiyixuxu
3985c43031 fix more 2026-01-30 13:09:46 +01:00
yiyixuxu
8c5b119e52 add quant info to components manager 2026-01-30 10:05:05 +01:00
yiyixuxu
46a713a6fa update doc 2026-01-30 02:10:31 +01:00
yiyixuxu
d4f2a8979f mellon_type -> inpnt_types + output_types 2026-01-30 02:10:19 +01:00
yiyixuxu
5c7273ff99 Merge branch 'more-mellon-related' of github.com:huggingface/diffusers into more-mellon-related 2026-01-29 21:27:02 +01:00
yiyixuxu
3fe2711691 style 2026-01-29 21:26:49 +01:00
yiyixuxu
48160f6f5e add mellon_types 2026-01-29 21:26:11 +01:00
YiYi Xu
3393ef0177 Merge branch 'main' into more-mellon-related 2026-01-29 09:16:13 -10:00
yiyixuxu
a71d86b9ae style 2026-01-29 05:30:24 +01:00
yiyixuxu
26f59f1aa9 add to toctree 2026-01-29 05:21:47 +01:00
yiyixuxu
29c5741c2a add mellon guide 2026-01-29 05:20:56 +01:00
yiyixuxu
5ad83903f9 up up fix 2026-01-29 03:45:11 +01:00
yiyixuxu
ffc5708b78 style 2026-01-29 02:07:56 +01:00
yiyixuxu
c5c732b87b add from_custom_block 2026-01-29 02:05:53 +01:00
yiyixuxu
d2bee6a57e refactor mellonparam: move the template outside, add metaclass, define some generic template for custom node 2026-01-28 22:24:16 +01:00
yiyixuxu
2890dd8480 add metadata field to input/output param 2026-01-28 22:23:17 +01:00
7 changed files with 941 additions and 907 deletions

View File

@@ -114,6 +114,8 @@
title: Guiders
- local: modular_diffusers/custom_blocks
title: Building Custom Blocks
- local: modular_diffusers/mellon
title: Mellon Guide
title: Modular Diffusers
- isExpanded: false
sections:

View File

@@ -16,7 +16,7 @@ specific language governing permissions and limitations under the License.
[ModularPipelineBlocks](./pipeline_block) are the fundamental building blocks of a [`ModularPipeline`]. You can create custom blocks by defining their inputs, outputs, and computation logic. This guide demonstrates how to create and use a custom block.
> [!TIP]
> Explore the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for official custom modular blocks like Nano Banana.
> Explore the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for official custom blocks.
## Project Structure
@@ -31,18 +31,58 @@ Your custom block project should use the following structure:
- `block.py` contains the custom block implementation
- `modular_config.json` contains the metadata needed to load the block
## Example: Florence 2 Inpainting Block
## Quick Start with Template
In this example we will create a custom block that uses the [Florence 2](https://huggingface.co/docs/transformers/model_doc/florence2) model to process an input image and generate a mask for inpainting.
The fastest way to create a custom block is to start from our template:
The first step is to define the components that the block will use. In this case, we will need to use the `Florence2ForConditionalGeneration` model and its corresponding processor `AutoProcessor`. When defining components, we must specify the name of the component within our pipeline, model class via `type_hint`, and provide a `pretrained_model_name_or_path` for the component if we intend to load the model weights from a specific repository on the Hub.
### 1. Download the template
```py
```python
from diffusers import ModularPipelineBlocks
model_id = "diffusers/custom-block-template"
local_dir = model_id.split("/")[-1]
blocks = ModularPipelineBlocks.from_pretrained(
model_id,
trust_remote_code=True,
local_dir=local_dir
)
```
This saves the template files to `custom-block-template/` locally. Feel free to use a custom `local_dir`.
### 2. Edit locally
Open `block.py` and implement your custom block. The template includes commented examples showing how to define each property. See the [Florence-2 example](#example-florence-2-image-annotator) below for a complete implementation.
### 3. Test your block
```python
from diffusers import ModularPipelineBlocks
blocks = ModularPipelineBlocks.from_pretrained(local_dir, trust_remote_code=True)
pipeline = blocks.init_pipeline()
output = pipeline(...) # your inputs here
```
### 4. Upload to the Hub
```python
pipeline.save_pretrained(local_dir, repo_id="your-username/your-block-name", push_to_hub=True)
```
## Example: Florence-2 Image Annotator
This example creates a custom block that uses [Florence-2](https://huggingface.co/docs/transformers/model_doc/florence2) to process an input image and generate a mask for inpainting.
### Define components
First, define the components the block needs. Here we use `Florence2ForConditionalGeneration` and its processor. When defining components, specify the `name` (how you'll access it in code), `type_hint` (the model class), and `pretrained_model_name_or_path` (where to load weights from).
```python
# Inside block.py
from diffusers.modular_pipelines import (
ModularPipelineBlocks,
ComponentSpec,
)
from diffusers.modular_pipelines import ModularPipelineBlocks, ComponentSpec
from transformers import AutoProcessor, Florence2ForConditionalGeneration
@@ -64,40 +104,19 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
]
```
Next, we define the inputs and outputs of the block. The inputs include the image to be annotated, the annotation task, and the annotation prompt. The outputs include the generated mask image and annotations.
### Define inputs and outputs
```py
Next, define the block's interface. Inputs include the image, annotation task, and prompt. Outputs include the generated mask and annotations.
```python
from typing import List, Union
from PIL import Image, ImageDraw
import torch
import numpy as np
from diffusers.modular_pipelines import (
PipelineState,
ModularPipelineBlocks,
InputParam,
ComponentSpec,
OutputParam,
)
from transformers import AutoProcessor, Florence2ForConditionalGeneration
from PIL import Image
from diffusers.modular_pipelines import InputParam, OutputParam
class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
@property
def expected_components(self):
return [
ComponentSpec(
name="image_annotator",
type_hint=Florence2ForConditionalGeneration,
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
),
ComponentSpec(
name="image_annotator_processor",
type_hint=AutoProcessor,
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
),
]
# ... expected_components from above ...
@property
def inputs(self) -> List[InputParam]:
@@ -110,51 +129,21 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
),
InputParam(
"annotation_task",
type_hint=Union[str, List[str]],
required=True,
type_hint=str,
default="<REFERRING_EXPRESSION_SEGMENTATION>",
description="""Annotation Task to perform on the image.
Supported Tasks:
<OD>
<REFERRING_EXPRESSION_SEGMENTATION>
<CAPTION>
<DETAILED_CAPTION>
<MORE_DETAILED_CAPTION>
<DENSE_REGION_CAPTION>
<CAPTION_TO_PHRASE_GROUNDING>
<OPEN_VOCABULARY_DETECTION>
""",
description="Annotation task to perform (e.g., <OD>, <CAPTION>, <REFERRING_EXPRESSION_SEGMENTATION>)",
),
InputParam(
"annotation_prompt",
type_hint=Union[str, List[str]],
type_hint=str,
required=True,
description="""Annotation Prompt to provide more context to the task.
Can be used to detect or segment out specific elements in the image
""",
description="Prompt to provide context for the annotation task",
),
InputParam(
"annotation_output_type",
type_hint=str,
required=True,
default="mask_image",
description="""Output type from annotation predictions. Available options are
mask_image:
-black and white mask image for the given image based on the task type
mask_overlay:
- mask overlayed on the original image
bounding_box:
- bounding boxes drawn on the original image
""",
),
InputParam(
"annotation_overlay",
type_hint=bool,
required=True,
default=False,
description="",
description="Output type: 'mask_image', 'mask_overlay', or 'bounding_box'",
),
]
@@ -163,225 +152,45 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
return [
OutputParam(
"mask_image",
type_hint=Image,
description="Inpainting Mask for input Image(s)",
type_hint=Image.Image,
description="Inpainting mask for the input image",
),
OutputParam(
"annotations",
type_hint=dict,
description="Annotations Predictions for input Image(s)",
description="Raw annotation predictions",
),
OutputParam(
"image",
type_hint=Image,
description="Annotated input Image(s)",
type_hint=Image.Image,
description="Annotated image",
),
]
```
Now we implement the `__call__` method, which contains the logic for processing the input image and generating the mask.
### Implement the `__call__` method
```py
from typing import List, Union
from PIL import Image, ImageDraw
The `__call__` method contains the block's logic. Access inputs via `block_state`, run your computation, and set outputs back to `block_state`.
```python
import torch
import numpy as np
from diffusers.modular_pipelines import (
PipelineState,
ModularPipelineBlocks,
InputParam,
ComponentSpec,
OutputParam,
)
from transformers import AutoProcessor, Florence2ForConditionalGeneration
from diffusers.modular_pipelines import PipelineState
class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
@property
def expected_components(self):
return [
ComponentSpec(
name="image_annotator",
type_hint=Florence2ForConditionalGeneration,
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
),
ComponentSpec(
name="image_annotator_processor",
type_hint=AutoProcessor,
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
"image",
type_hint=Union[Image.Image, List[Image.Image]],
required=True,
description="Image(s) to annotate",
),
InputParam(
"annotation_task",
type_hint=Union[str, List[str]],
required=True,
default="<REFERRING_EXPRESSION_SEGMENTATION>",
description="""Annotation Task to perform on the image.
Supported Tasks:
<OD>
<REFERRING_EXPRESSION_SEGMENTATION>
<CAPTION>
<DETAILED_CAPTION>
<MORE_DETAILED_CAPTION>
<DENSE_REGION_CAPTION>
<CAPTION_TO_PHRASE_GROUNDING>
<OPEN_VOCABULARY_DETECTION>
""",
),
InputParam(
"annotation_prompt",
type_hint=Union[str, List[str]],
required=True,
description="""Annotation Prompt to provide more context to the task.
Can be used to detect or segment out specific elements in the image
""",
),
InputParam(
"annotation_output_type",
type_hint=str,
required=True,
default="mask_image",
description="""Output type from annotation predictions. Available options are
mask_image:
-black and white mask image for the given image based on the task type
mask_overlay:
- mask overlayed on the original image
bounding_box:
- bounding boxes drawn on the original image
""",
),
InputParam(
"annotation_overlay",
type_hint=bool,
required=True,
default=False,
description="",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"mask_image",
type_hint=Image,
description="Inpainting Mask for input Image(s)",
),
OutputParam(
"annotations",
type_hint=dict,
description="Annotations Predictions for input Image(s)",
),
OutputParam(
"image",
type_hint=Image,
description="Annotated input Image(s)",
),
]
def get_annotations(self, components, images, prompts, task):
task_prompts = [task + prompt for prompt in prompts]
inputs = components.image_annotator_processor(
text=task_prompts, images=images, return_tensors="pt"
).to(components.image_annotator.device, components.image_annotator.dtype)
generated_ids = components.image_annotator.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=1024,
early_stopping=False,
do_sample=False,
num_beams=3,
)
annotations = components.image_annotator_processor.batch_decode(
generated_ids, skip_special_tokens=False
)
outputs = []
for image, annotation in zip(images, annotations):
outputs.append(
components.image_annotator_processor.post_process_generation(
annotation, task=task, image_size=(image.width, image.height)
)
)
return outputs
def prepare_mask(self, images, annotations, overlay=False, fill="white"):
masks = []
for image, annotation in zip(images, annotations):
mask_image = image.copy() if overlay else Image.new("L", image.size, 0)
draw = ImageDraw.Draw(mask_image)
for _, _annotation in annotation.items():
if "polygons" in _annotation:
for polygon in _annotation["polygons"]:
polygon = np.array(polygon).reshape(-1, 2)
if len(polygon) < 3:
continue
polygon = polygon.reshape(-1).tolist()
draw.polygon(polygon, fill=fill)
elif "bbox" in _annotation:
bbox = _annotation["bbox"]
draw.rectangle(bbox, fill="white")
masks.append(mask_image)
return masks
def prepare_bounding_boxes(self, images, annotations):
outputs = []
for image, annotation in zip(images, annotations):
image_copy = image.copy()
draw = ImageDraw.Draw(image_copy)
for _, _annotation in annotation.items():
bbox = _annotation["bbox"]
label = _annotation["label"]
draw.rectangle(bbox, outline="red", width=3)
draw.text((bbox[0], bbox[1] - 20), label, fill="red")
outputs.append(image_copy)
return outputs
def prepare_inputs(self, images, prompts):
prompts = prompts or ""
if isinstance(images, Image.Image):
images = [images]
if isinstance(prompts, str):
prompts = [prompts]
if len(images) != len(prompts):
raise ValueError("Number of images and annotation prompts must match.")
return images, prompts
# ... expected_components, inputs, intermediate_outputs from above ...
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
images, annotation_task_prompt = self.prepare_inputs(
block_state.image, block_state.annotation_prompt
)
task = block_state.annotation_task
fill = block_state.fill
annotations = self.get_annotations(
components, images, annotation_task_prompt, task
)
@@ -400,67 +209,69 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
self.set_block_state(state, block_state)
return components, state
# Helper methods for mask/bounding box generation...
```
Once we have defined our custom block, we can save it to the Hub, using either the CLI or the [`push_to_hub`] method. This will make it easy to share and reuse our custom block with other pipelines.
<hfoptions id="share">
<hfoption id="hf CLI">
```shell
# In the folder with the `block.py` file, run:
diffusers-cli custom_block
```
Then upload the block to the Hub:
```shell
hf upload <your repo id> . .
```
</hfoption>
<hfoption id="push_to_hub">
```py
from block import Florence2ImageAnnotatorBlock
block = Florence2ImageAnnotatorBlock()
block.push_to_hub("<your repo id>")
```
</hfoption>
</hfoptions>
> [!TIP]
> See the complete implementation at [diffusers/Florence2-image-Annotator](https://huggingface.co/diffusers/Florence2-image-Annotator).
## Using Custom Blocks
Load the custom block with [`~ModularPipelineBlocks.from_pretrained`] and set `trust_remote_code=True`.
Load a custom block with [`~ModularPipeline.from_pretrained`] and set `trust_remote_code=True`.
```py
import torch
from diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS
from diffusers import ModularPipeline
from diffusers.utils import load_image
# Fetch the Florence2 image annotator block that will create our mask
image_annotator_block = ModularPipelineBlocks.from_pretrained("diffusers/florence-2-custom-block", trust_remote_code=True)
# Load the Florence-2 annotator pipeline
image_annotator = ModularPipeline.from_pretrained(
"diffusers/Florence2-image-Annotator",
trust_remote_code=True
)
my_blocks = INPAINT_BLOCKS.copy()
# insert the annotation block before the image encoding step
my_blocks.insert("image_annotator", image_annotator_block, 1)
# Check the docstring to see inputs/outputs
print(image_annotator.blocks.doc)
```
# Create our initial set of inpainting blocks
blocks = SequentialPipelineBlocks.from_blocks_dict(my_blocks)
Use the block to generate a mask:
repo_id = "diffusers/modular-stable-diffusion-xl-base-1.0"
pipe = blocks.init_pipeline(repo_id)
pipe.load_components(torch_dtype=torch.float16, device_map="cuda", trust_remote_code=True)
```python
image_annotator.load_components(torch_dtype=torch.bfloat16)
image_annotator.to("cuda")
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true")
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg")
image = image.resize((1024, 1024))
prompt = ["A red car"]
annotation_task = "<REFERRING_EXPRESSION_SEGMENTATION>"
annotation_prompt = ["the car"]
mask_image = image_annotator_node(
prompt=prompt,
image=image,
annotation_task=annotation_task,
annotation_prompt=annotation_prompt,
annotation_output_type="mask_image",
).images
mask_image[0].save("car-mask.png")
```
You can also compose it with other blocks to create a new pipeline:
```python
# Get the annotator block
annotator_block = image_annotator.blocks
# Get an inpainting workflow and insert the annotator at the beginning
inpaint_blocks = ModularPipeline.from_pretrained("Qwen/Qwen-Image").blocks.get_workflow("inpainting")
inpaint_blocks.sub_blocks.insert("image_annotator", annotator_block, 0)
# Initialize the combined pipeline
pipe = inpaint_blocks.init_pipeline()
pipe.load_components(torch_dtype=torch.float16, device="cuda")
# Now the pipeline automatically generates masks from prompts
output = pipe(
prompt=prompt,
image=image,
@@ -477,16 +288,39 @@ output[0].save("florence-inpainting.png")
## Editing Custom Blocks
By default, custom blocks are saved in your cache directory. Use the `local_dir` argument to download and edit a custom block in a specific folder.
You can edit any existing custom block by downloading it locally. This follows the same workflow as the [Quick Start with Template](#quick-start-with-template), but starting from an existing block instead of the template.
```py
import torch
from diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS
from diffusers.utils import load_image
Use the `local_dir` argument to download a custom block to a specific folder:
# Fetch the Florence2 image annotator block that will create our mask
image_annotator_block = ModularPipelineBlocks.from_pretrained("diffusers/florence-2-custom-block", trust_remote_code=True, local_dir="/my-local-folder")
```python
from diffusers import ModularPipelineBlocks
# Download to a local folder for editing
annotator_block = ModularPipelineBlocks.from_pretrained(
"diffusers/Florence2-image-Annotator",
trust_remote_code=True,
local_dir="./my-florence-block"
)
```
Any changes made to the block files in this folder will be reflected when you load the block again.
Any changes made to the block files in this folder will be reflected when you load the block again. When you're ready to share your changes, upload to a new repository:
```python
pipeline = annotator_block.init_pipeline()
pipeline.save_pretrained("./my-florence-block", repo_id="your-username/my-custom-florence", push_to_hub=True)
```
## Next Steps
<hfoptions id="next">
<hfoption id="Use in Mellon">
Make your custom block work with Mellon's visual interface - no UI code required. See the [Mellon Custom Blocks](./mellon_custom_blocks) guide.
</hfoption>
<hfoption id="Explore existing blocks">
Browse the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for inspiration and ready-to-use blocks.
</hfoption>
</hfoptions>

View File

@@ -0,0 +1,236 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
## Using Custom Blocks with Mellon
[Mellon](https://github.com/cubiq/Mellon) is a visual workflow interface that integrates with Modular Diffusers and is designed for node-based workflows.
> [!WARNING]
> Mellon is in early development and not yet ready for production use. Consider this a sneak peek of how the integration works! Custom blocks built with Modular Diffusers work with Mellon out of the box - no UI code required - and we'll ensure compatibility as Mellon evolves.
## Overview
To use a custom block in Mellon, you need a `mellon_pipeline_config.json` file that defines how your block's parameters map to Mellon UI components. Here's how to create one:
1. **Add a "Mellon type" to your block's parameters** - Each `InputParam`/`OutputParam` needs a type that tells Mellon what UI component to render (e.g., `"textbox"`, `"dropdown"`, `"image"`). You can specify types via metadata in your block definitions, or pass them when generating the config.
2. **Generate `mellon_pipeline_config.json`** - Use our utility to generate a default template and push it to your Hub repository
3. **(Optional) Manually adjust the template** - Fine-tune the generated config for your specific needs
## Step 1: Specify Mellon Types for Parameters
Mellon types determine how each parameter renders in the UI. If you don't specify a type for a parameter, it will default to `"custom"`, which renders as a simple connection dot. You can always adjust this later in the generated config.
### Supported Mellon Types
| Type | Input/Output | Description |
|------|--------------|-------------|
| `image` | Both | Image (PIL Image) |
| `video` | Both | Video |
| `text` | Both | Text display |
| `textbox` | Input | Text input |
| `dropdown` | Input | Dropdown selection menu |
| `slider` | Input | Slider for numeric values |
| `number` | Input | Numeric input |
| `checkbox` | Input | Boolean toggle |
### Method 1: Using `metadata` in Block Definitions
If you're defining a custom block from scratch, you can add `metadata={"mellon": "<type>"}` directly to your `InputParam` and `OutputParam` definitions:
```python
class GeminiPromptExpander(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
"prompt",
type_hint=str,
required=True,
description="Prompt to use",
metadata={"mellon": "textbox"}, # Text input
)
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"prompt",
type_hint=str,
description="Expanded prompt by the LLM",
metadata={"mellon": "text"}, # Text output
),
OutputParam(
"old_prompt",
type_hint=str,
description="Old prompt provided by the user",
# No metadata - we don't want to render this in UI
)
]
```
### Method 2: Using `input_types` and `output_types` When Generating Config
If you're working with an existing pipeline or prefer to keep your block definitions clean, you can specify types when generating the config using the `input_types/output_types` argument:
```python
from diffusers.modular_pipelines.mellon_node_utils import MellonPipelineConfig
mellon_config = MellonPipelineConfig.from_custom_block(
blocks,
input_types={"prompt": "textbox"},
output_types={"prompt": "text"}
)
```
> [!NOTE]
> If you specify both `metadata` and `input_types`/`output_types`, the arguments take precedence, allowing you to override metadata when needed.
## Step 2: Generate and Push the Mellon Config
After adding metadata to your block, generate the default Mellon configuration template and push it to the Hub:
```python
from diffusers import ModularPipelineBlocks
from diffusers.modular_pipelines.mellon_node_utils import MellonPipelineConfig
# load your custom blocks from your local dir
blocks = ModularPipelineBlocks.from_pretrained("/path/local/folder", trust_remote_code=True)
# Generate the default config template
mellon_config = MellonPipelineConfig.from_custom_block(blocks)
# push the default template to `repo_id`, you will need to pass the same local folder path so that it will save the config locally first
mellon_config.save(
local_dir="/path/local/folder",
repo_id= repo_id,
push_to_hub=True
)
```
This creates a `mellon_pipeline_config.json` file in your repository.
## Step 3: Review and Adjust the Config (Optional)
The generated template is a starting point - you may want to adjust it for your needs. Let's walk through the generated config for the Gemini Prompt Expander:
```json
{
"label": "Gemini Prompt Expander",
"default_repo": "",
"default_dtype": "",
"node_params": {
"custom": {
"params": {
"prompt": {
"label": "Prompt",
"type": "string",
"display": "textarea",
"default": ""
},
"out_prompt": {
"label": "Prompt",
"type": "string",
"display": "output"
},
"old_prompt": {
"label": "Old Prompt",
"type": "custom",
"display": "output"
},
"doc": {
"label": "Doc",
"type": "string",
"display": "output"
}
},
"input_names": ["prompt"],
"model_input_names": [],
"output_names": ["out_prompt", "old_prompt", "doc"],
"block_name": "custom",
"node_type": "custom"
}
}
}
```
### Understanding the Structure
The `params` dict defines how each UI element renders. The `input_names`, `model_input_names`, and `output_names` lists map these UI elements to the underlying [`ModularPipelineBlocks`]'s I/O interface:
| Mellon Config | ModularPipelineBlocks |
|---------------|----------------------|
| `input_names` | `inputs` property |
| `model_input_names` | `expected_components` property |
| `output_names` | `intermediate_outputs` property |
In this example: `prompt` is the only input, there are no model components, and outputs include `out_prompt`, `old_prompt`, and `doc`.
Now let's look at the `params` dict:
**`prompt`** is an input parameter. It has `display: "textarea"` which renders as a text input box, `label: "Prompt"` shown in the UI, and `default: ""` so it starts empty. The `type: "string"` field is important in Mellon because it determines which nodes can connect together - only matching types can be linked with "noodles".
**`out_prompt`** is the expanded prompt output. The `out_` prefix was automatically added because the input and output share the same name (`prompt`), avoiding naming conflicts in the config. It has `display: "output"` which renders as an output socket.
**`old_prompt`** has `type: "custom"` because we didn't specify metadata. This renders as a simple dot in the UI. Since we don't actually want to expose this in the UI, we can remove it.
**`doc`** is the documentation output, automatically added to all custom blocks.
### Making Adjustments
For the Gemini Prompt Expander, we don't need `old_prompt` in the UI. Remove it from both `params` and `output_names`:
```json
{
"label": "Gemini Prompt Expander",
"default_repo": "",
"default_dtype": "",
"node_params": {
"custom": {
"params": {
"prompt": {
"label": "Prompt",
"type": "string",
"display": "textarea",
"default": ""
},
"out_prompt": {
"label": "Prompt",
"type": "string",
"display": "output"
},
"doc": {
"label": "Doc",
"type": "string",
"display": "output"
}
},
"input_names": ["prompt"],
"model_input_names": [],
"output_names": ["out_prompt", "doc"],
"block_name": "custom",
"node_type": "custom"
}
}
}
```
See the final config at [YiYiXu/gemini-prompt-expander](https://huggingface.co/YiYiXu/gemini-prompt-expander).
## Use in Mellon
1. Start Mellon (see [Mellon installation guide](https://github.com/cubiq/Mellon))
2. In Mellon:
- Drag a **Dynamic Block Node** from the ModularDiffusers section
- Enter your `repo_id` (e.g., `YiYiXu/gemini-prompt-expander`)
- Click **Load Custom Block**
- The node will transform to show your block's inputs and outputs

View File

@@ -33,9 +33,14 @@ The Modular Diffusers docs are organized as shown below.
- [SequentialPipelineBlocks](./sequential_pipeline_blocks) is a type of block that chains multiple blocks so they run one after another, passing data along the chain. This guide shows you how to create [`~modular_pipelines.SequentialPipelineBlocks`] and how they connect and work together.
- [LoopSequentialPipelineBlocks](./loop_sequential_pipeline_blocks) is a type of block that runs a series of blocks in a loop. This guide shows you how to create [`~modular_pipelines.LoopSequentialPipelineBlocks`].
- [AutoPipelineBlocks](./auto_pipeline_blocks) is a type of block that automatically chooses which blocks to run based on the input. This guide shows you how to create [`~modular_pipelines.AutoPipelineBlocks`].
- [Building Custom Blocks](./custom_blocks) shows you how to create your own custom blocks and share them on the Hub.
## ModularPipeline
- [ModularPipeline](./modular_pipeline) shows you how to create and convert pipeline blocks into an executable [`ModularPipeline`].
- [ComponentsManager](./components_manager) shows you how to manage and reuse components across multiple pipelines.
- [Guiders](./guiders) shows you how to use different guidance methods in the pipeline.
- [Guiders](./guiders) shows you how to use different guidance methods in the pipeline.
## Mellon Integration
- [Using Custom Blocks with Mellon](./mellon) shows you how to make your custom blocks work with [Mellon](https://github.com/cubiq/Mellon), a visual node-based interface for building workflows.

View File

@@ -324,6 +324,7 @@ class ComponentsManager:
"has_hook",
"execution_device",
"ip_adapter",
"quantization",
]
def __init__(self):
@@ -356,7 +357,9 @@ class ComponentsManager:
ids_by_name.add(component_id)
else:
ids_by_name = set(components.keys())
if collection:
if collection and collection not in self.collections:
return set()
elif collection and collection in self.collections:
ids_by_collection = set()
for component_id, component in components.items():
if component_id in self.collections[collection]:
@@ -423,7 +426,8 @@ class ComponentsManager:
# add component to components manager
self.components[component_id] = component
self.added_time[component_id] = time.time()
if is_new_component:
self.added_time[component_id] = time.time()
if collection:
if collection not in self.collections:
@@ -760,7 +764,6 @@ class ComponentsManager:
self.model_hooks = None
self._auto_offload_enabled = False
# YiYi TODO: (1) add quantization info
def get_model_info(
self,
component_id: str,
@@ -836,6 +839,17 @@ class ComponentsManager:
if scales:
info["ip_adapter"] = summarize_dict_by_value_and_parts(scales)
# Check for quantization
hf_quantizer = getattr(component, "hf_quantizer", None)
if hf_quantizer is not None:
quant_config = hf_quantizer.quantization_config
if hasattr(quant_config, "to_diff_dict"):
info["quantization"] = quant_config.to_diff_dict()
else:
info["quantization"] = quant_config.to_dict()
else:
info["quantization"] = None
# If fields specified, filter info
if fields is not None:
return {k: v for k, v in info.items() if k in fields}
@@ -966,12 +980,16 @@ class ComponentsManager:
output += "\nAdditional Component Info:\n" + "=" * 50 + "\n"
for name in self.components:
info = self.get_model_info(name)
if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")):
if info is not None and (
info.get("adapters") is not None or info.get("ip_adapter") or info.get("quantization")
):
output += f"\n{name}:\n"
if info.get("adapters") is not None:
output += f" Adapters: {info['adapters']}\n"
if info.get("ip_adapter"):
output += " IP-Adapter: Enabled\n"
if info.get("quantization"):
output += f" Quantization: {info['quantization']}\n"
return output

File diff suppressed because it is too large Load Diff

View File

@@ -520,6 +520,7 @@ class InputParam:
required: bool = False
description: str = ""
kwargs_type: str = None
metadata: Dict[str, Any] = None
def __repr__(self):
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
@@ -553,6 +554,7 @@ class OutputParam:
type_hint: Any = None
description: str = ""
kwargs_type: str = None
metadata: Dict[str, Any] = None
def __repr__(self):
return (