Compare commits

..

4 Commits

Author SHA1 Message Date
Sayak Paul
8da128067c Fix syntax error in quantization configuration 2026-02-04 10:10:14 +05:30
Sayak Paul
1b8fc6c589 [modular] change the template modular pipeline card (#13072)
* start better template for modular pipeline card.

* simplify structure.

* refine.

* style.

* up

* add tests
2026-02-04 10:09:10 +05:30
YiYi Xu
6d4fc6baa0 [Modular] mellon doc etc (#13051)
* add metadata field to input/output param

* refactor mellonparam: move the template outside, add metaclass, define some generic template for custom node

* add from_custom_block

* style

* up up fix

* add mellon guide

* add to toctree

* style

* add mellon_types

* style

* mellon_type -> inpnt_types + output_types

* update doc

* add quant info to components manager

* fix more

* up up

* fix components manager

* update custom block guide

* update

* style

* add a warn for mellon and add new guides to overview

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/modular_diffusers/mellon.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* more update on custom block guide

* Update docs/source/en/modular_diffusers/mellon.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* a few mamual

* apply suggestion: turn into bullets

* support define mellon meta with MellonParam directly, and update doc

* add the video

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: yiyi@huggingface.co <yiyi@ip-26-0-160-103.ec2.internal>
2026-02-03 13:38:57 -10:00
YiYi Xu
ebd06f9b11 [Modular] loader related (#13025)
* tag loader_id from Automodel

* style

* load_components by default only load components that are not already loaded

* by default, skip loading the componeneets does not have the repo id
2026-02-03 05:34:33 -10:00
17 changed files with 1585 additions and 1435 deletions

View File

@@ -114,6 +114,8 @@
title: Guiders
- local: modular_diffusers/custom_blocks
title: Building Custom Blocks
- local: modular_diffusers/mellon
title: Using Custom Blocks with Mellon
title: Modular Diffusers
- isExpanded: false
sections:

View File

@@ -16,7 +16,7 @@ specific language governing permissions and limitations under the License.
[ModularPipelineBlocks](./pipeline_block) are the fundamental building blocks of a [`ModularPipeline`]. You can create custom blocks by defining their inputs, outputs, and computation logic. This guide demonstrates how to create and use a custom block.
> [!TIP]
> Explore the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for official custom modular blocks like Nano Banana.
> Explore the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for official custom blocks.
## Project Structure
@@ -31,18 +31,58 @@ Your custom block project should use the following structure:
- `block.py` contains the custom block implementation
- `modular_config.json` contains the metadata needed to load the block
## Example: Florence 2 Inpainting Block
## Quick Start with Template
In this example we will create a custom block that uses the [Florence 2](https://huggingface.co/docs/transformers/model_doc/florence2) model to process an input image and generate a mask for inpainting.
The fastest way to create a custom block is to start from our template. The template provides a pre-configured project structure with `block.py` and `modular_config.json` files, plus commented examples showing how to define components, inputs, outputs, and the `__call__` method—so you can focus on your custom logic instead of boilerplate setup.
The first step is to define the components that the block will use. In this case, we will need to use the `Florence2ForConditionalGeneration` model and its corresponding processor `AutoProcessor`. When defining components, we must specify the name of the component within our pipeline, model class via `type_hint`, and provide a `pretrained_model_name_or_path` for the component if we intend to load the model weights from a specific repository on the Hub.
### Download the template
```py
```python
from diffusers import ModularPipelineBlocks
model_id = "diffusers/custom-block-template"
local_dir = model_id.split("/")[-1]
blocks = ModularPipelineBlocks.from_pretrained(
model_id,
trust_remote_code=True,
local_dir=local_dir
)
```
This saves the template files to `custom-block-template/` locally or you could use `local_dir` to save to a specific location.
### Edit locally
Open `block.py` and implement your custom block. The template includes commented examples showing how to define each property. See the [Florence-2 example](#example-florence-2-image-annotator) below for a complete implementation.
### Test your block
```python
from diffusers import ModularPipelineBlocks
blocks = ModularPipelineBlocks.from_pretrained(local_dir, trust_remote_code=True)
pipeline = blocks.init_pipeline()
output = pipeline(...) # your inputs here
```
### Upload to the Hub
```python
pipeline.save_pretrained(local_dir, repo_id="your-username/your-block-name", push_to_hub=True)
```
## Example: Florence-2 Image Annotator
This example creates a custom block with [Florence-2](https://huggingface.co/docs/transformers/model_doc/florence2) to process an input image and generate a mask for inpainting.
### Define components
Define the components the block needs, `Florence2ForConditionalGeneration` and its processor. When defining components, specify the `name` (how you'll access it in code), `type_hint` (the model class), and `pretrained_model_name_or_path` (where to load weights from).
```python
# Inside block.py
from diffusers.modular_pipelines import (
ModularPipelineBlocks,
ComponentSpec,
)
from diffusers.modular_pipelines import ModularPipelineBlocks, ComponentSpec
from transformers import AutoProcessor, Florence2ForConditionalGeneration
@@ -64,40 +104,19 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
]
```
Next, we define the inputs and outputs of the block. The inputs include the image to be annotated, the annotation task, and the annotation prompt. The outputs include the generated mask image and annotations.
### Define inputs and outputs
```py
Inputs include the image, annotation task, and prompt. Outputs include the generated mask and annotations.
```python
from typing import List, Union
from PIL import Image, ImageDraw
import torch
import numpy as np
from diffusers.modular_pipelines import (
PipelineState,
ModularPipelineBlocks,
InputParam,
ComponentSpec,
OutputParam,
)
from transformers import AutoProcessor, Florence2ForConditionalGeneration
from PIL import Image
from diffusers.modular_pipelines import InputParam, OutputParam
class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
@property
def expected_components(self):
return [
ComponentSpec(
name="image_annotator",
type_hint=Florence2ForConditionalGeneration,
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
),
ComponentSpec(
name="image_annotator_processor",
type_hint=AutoProcessor,
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
),
]
# ... expected_components from above ...
@property
def inputs(self) -> List[InputParam]:
@@ -110,51 +129,21 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
),
InputParam(
"annotation_task",
type_hint=Union[str, List[str]],
required=True,
type_hint=str,
default="<REFERRING_EXPRESSION_SEGMENTATION>",
description="""Annotation Task to perform on the image.
Supported Tasks:
<OD>
<REFERRING_EXPRESSION_SEGMENTATION>
<CAPTION>
<DETAILED_CAPTION>
<MORE_DETAILED_CAPTION>
<DENSE_REGION_CAPTION>
<CAPTION_TO_PHRASE_GROUNDING>
<OPEN_VOCABULARY_DETECTION>
""",
description="Annotation task to perform (e.g., <OD>, <CAPTION>, <REFERRING_EXPRESSION_SEGMENTATION>)",
),
InputParam(
"annotation_prompt",
type_hint=Union[str, List[str]],
type_hint=str,
required=True,
description="""Annotation Prompt to provide more context to the task.
Can be used to detect or segment out specific elements in the image
""",
description="Prompt to provide context for the annotation task",
),
InputParam(
"annotation_output_type",
type_hint=str,
required=True,
default="mask_image",
description="""Output type from annotation predictions. Available options are
mask_image:
-black and white mask image for the given image based on the task type
mask_overlay:
- mask overlayed on the original image
bounding_box:
- bounding boxes drawn on the original image
""",
),
InputParam(
"annotation_overlay",
type_hint=bool,
required=True,
default=False,
description="",
description="Output type: 'mask_image', 'mask_overlay', or 'bounding_box'",
),
]
@@ -163,225 +152,45 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
return [
OutputParam(
"mask_image",
type_hint=Image,
description="Inpainting Mask for input Image(s)",
type_hint=Image.Image,
description="Inpainting mask for the input image",
),
OutputParam(
"annotations",
type_hint=dict,
description="Annotations Predictions for input Image(s)",
description="Raw annotation predictions",
),
OutputParam(
"image",
type_hint=Image,
description="Annotated input Image(s)",
type_hint=Image.Image,
description="Annotated image",
),
]
```
Now we implement the `__call__` method, which contains the logic for processing the input image and generating the mask.
### Implement the `__call__` method
```py
from typing import List, Union
from PIL import Image, ImageDraw
The `__call__` method contains the block's logic. Access inputs via `block_state`, run your computation, and set outputs back to `block_state`.
```python
import torch
import numpy as np
from diffusers.modular_pipelines import (
PipelineState,
ModularPipelineBlocks,
InputParam,
ComponentSpec,
OutputParam,
)
from transformers import AutoProcessor, Florence2ForConditionalGeneration
from diffusers.modular_pipelines import PipelineState
class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
@property
def expected_components(self):
return [
ComponentSpec(
name="image_annotator",
type_hint=Florence2ForConditionalGeneration,
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
),
ComponentSpec(
name="image_annotator_processor",
type_hint=AutoProcessor,
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
"image",
type_hint=Union[Image.Image, List[Image.Image]],
required=True,
description="Image(s) to annotate",
),
InputParam(
"annotation_task",
type_hint=Union[str, List[str]],
required=True,
default="<REFERRING_EXPRESSION_SEGMENTATION>",
description="""Annotation Task to perform on the image.
Supported Tasks:
<OD>
<REFERRING_EXPRESSION_SEGMENTATION>
<CAPTION>
<DETAILED_CAPTION>
<MORE_DETAILED_CAPTION>
<DENSE_REGION_CAPTION>
<CAPTION_TO_PHRASE_GROUNDING>
<OPEN_VOCABULARY_DETECTION>
""",
),
InputParam(
"annotation_prompt",
type_hint=Union[str, List[str]],
required=True,
description="""Annotation Prompt to provide more context to the task.
Can be used to detect or segment out specific elements in the image
""",
),
InputParam(
"annotation_output_type",
type_hint=str,
required=True,
default="mask_image",
description="""Output type from annotation predictions. Available options are
mask_image:
-black and white mask image for the given image based on the task type
mask_overlay:
- mask overlayed on the original image
bounding_box:
- bounding boxes drawn on the original image
""",
),
InputParam(
"annotation_overlay",
type_hint=bool,
required=True,
default=False,
description="",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"mask_image",
type_hint=Image,
description="Inpainting Mask for input Image(s)",
),
OutputParam(
"annotations",
type_hint=dict,
description="Annotations Predictions for input Image(s)",
),
OutputParam(
"image",
type_hint=Image,
description="Annotated input Image(s)",
),
]
def get_annotations(self, components, images, prompts, task):
task_prompts = [task + prompt for prompt in prompts]
inputs = components.image_annotator_processor(
text=task_prompts, images=images, return_tensors="pt"
).to(components.image_annotator.device, components.image_annotator.dtype)
generated_ids = components.image_annotator.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=1024,
early_stopping=False,
do_sample=False,
num_beams=3,
)
annotations = components.image_annotator_processor.batch_decode(
generated_ids, skip_special_tokens=False
)
outputs = []
for image, annotation in zip(images, annotations):
outputs.append(
components.image_annotator_processor.post_process_generation(
annotation, task=task, image_size=(image.width, image.height)
)
)
return outputs
def prepare_mask(self, images, annotations, overlay=False, fill="white"):
masks = []
for image, annotation in zip(images, annotations):
mask_image = image.copy() if overlay else Image.new("L", image.size, 0)
draw = ImageDraw.Draw(mask_image)
for _, _annotation in annotation.items():
if "polygons" in _annotation:
for polygon in _annotation["polygons"]:
polygon = np.array(polygon).reshape(-1, 2)
if len(polygon) < 3:
continue
polygon = polygon.reshape(-1).tolist()
draw.polygon(polygon, fill=fill)
elif "bbox" in _annotation:
bbox = _annotation["bbox"]
draw.rectangle(bbox, fill="white")
masks.append(mask_image)
return masks
def prepare_bounding_boxes(self, images, annotations):
outputs = []
for image, annotation in zip(images, annotations):
image_copy = image.copy()
draw = ImageDraw.Draw(image_copy)
for _, _annotation in annotation.items():
bbox = _annotation["bbox"]
label = _annotation["label"]
draw.rectangle(bbox, outline="red", width=3)
draw.text((bbox[0], bbox[1] - 20), label, fill="red")
outputs.append(image_copy)
return outputs
def prepare_inputs(self, images, prompts):
prompts = prompts or ""
if isinstance(images, Image.Image):
images = [images]
if isinstance(prompts, str):
prompts = [prompts]
if len(images) != len(prompts):
raise ValueError("Number of images and annotation prompts must match.")
return images, prompts
# ... expected_components, inputs, intermediate_outputs from above ...
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
images, annotation_task_prompt = self.prepare_inputs(
block_state.image, block_state.annotation_prompt
)
task = block_state.annotation_task
fill = block_state.fill
annotations = self.get_annotations(
components, images, annotation_task_prompt, task
)
@@ -400,67 +209,69 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
self.set_block_state(state, block_state)
return components, state
# Helper methods for mask/bounding box generation...
```
Once we have defined our custom block, we can save it to the Hub, using either the CLI or the [`push_to_hub`] method. This will make it easy to share and reuse our custom block with other pipelines.
<hfoptions id="share">
<hfoption id="hf CLI">
```shell
# In the folder with the `block.py` file, run:
diffusers-cli custom_block
```
Then upload the block to the Hub:
```shell
hf upload <your repo id> . .
```
</hfoption>
<hfoption id="push_to_hub">
```py
from block import Florence2ImageAnnotatorBlock
block = Florence2ImageAnnotatorBlock()
block.push_to_hub("<your repo id>")
```
</hfoption>
</hfoptions>
> [!TIP]
> See the complete implementation at [diffusers/Florence2-image-Annotator](https://huggingface.co/diffusers/Florence2-image-Annotator).
## Using Custom Blocks
Load the custom block with [`~ModularPipelineBlocks.from_pretrained`] and set `trust_remote_code=True`.
Load a custom block with [`~ModularPipeline.from_pretrained`] and set `trust_remote_code=True`.
```py
import torch
from diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS
from diffusers import ModularPipeline
from diffusers.utils import load_image
# Fetch the Florence2 image annotator block that will create our mask
image_annotator_block = ModularPipelineBlocks.from_pretrained("diffusers/florence-2-custom-block", trust_remote_code=True)
# Load the Florence-2 annotator pipeline
image_annotator = ModularPipeline.from_pretrained(
"diffusers/Florence2-image-Annotator",
trust_remote_code=True
)
my_blocks = INPAINT_BLOCKS.copy()
# insert the annotation block before the image encoding step
my_blocks.insert("image_annotator", image_annotator_block, 1)
# Check the docstring to see inputs/outputs
print(image_annotator.blocks.doc)
```
# Create our initial set of inpainting blocks
blocks = SequentialPipelineBlocks.from_blocks_dict(my_blocks)
Use the block to generate a mask:
repo_id = "diffusers/modular-stable-diffusion-xl-base-1.0"
pipe = blocks.init_pipeline(repo_id)
pipe.load_components(torch_dtype=torch.float16, device_map="cuda", trust_remote_code=True)
```python
image_annotator.load_components(torch_dtype=torch.bfloat16)
image_annotator.to("cuda")
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true")
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg")
image = image.resize((1024, 1024))
prompt = ["A red car"]
annotation_task = "<REFERRING_EXPRESSION_SEGMENTATION>"
annotation_prompt = ["the car"]
mask_image = image_annotator_node(
prompt=prompt,
image=image,
annotation_task=annotation_task,
annotation_prompt=annotation_prompt,
annotation_output_type="mask_image",
).images
mask_image[0].save("car-mask.png")
```
Compose it with other blocks to create a new pipeline:
```python
# Get the annotator block
annotator_block = image_annotator.blocks
# Get an inpainting workflow and insert the annotator at the beginning
inpaint_blocks = ModularPipeline.from_pretrained("Qwen/Qwen-Image").blocks.get_workflow("inpainting")
inpaint_blocks.sub_blocks.insert("image_annotator", annotator_block, 0)
# Initialize the combined pipeline
pipe = inpaint_blocks.init_pipeline()
pipe.load_components(torch_dtype=torch.float16, device="cuda")
# Now the pipeline automatically generates masks from prompts
output = pipe(
prompt=prompt,
image=image,
@@ -475,18 +286,50 @@ output = pipe(
output[0].save("florence-inpainting.png")
```
## Editing Custom Blocks
## Editing custom blocks
By default, custom blocks are saved in your cache directory. Use the `local_dir` argument to download and edit a custom block in a specific folder.
Edit custom blocks by downloading it locally. This is the same workflow as the [Quick Start with Template](#quick-start-with-template), but starting from an existing block instead of the template.
```py
import torch
from diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS
from diffusers.utils import load_image
Use the `local_dir` argument to download a custom block to a specific folder:
# Fetch the Florence2 image annotator block that will create our mask
image_annotator_block = ModularPipelineBlocks.from_pretrained("diffusers/florence-2-custom-block", trust_remote_code=True, local_dir="/my-local-folder")
```python
from diffusers import ModularPipelineBlocks
# Download to a local folder for editing
annotator_block = ModularPipelineBlocks.from_pretrained(
"diffusers/Florence2-image-Annotator",
trust_remote_code=True,
local_dir="./my-florence-block"
)
```
Any changes made to the block files in this folder will be reflected when you load the block again.
Any changes made to the block files in this folder will be reflected when you load the block again. When you're ready to share your changes, upload to a new repository:
```python
pipeline = annotator_block.init_pipeline()
pipeline.save_pretrained("./my-florence-block", repo_id="your-username/my-custom-florence", push_to_hub=True)
```
## Next Steps
<hfoptions id="next">
<hfoption id="Learn block types">
This guide covered creating a single custom block. Learn how to compose multiple blocks together:
- [SequentialPipelineBlocks](./sequential_pipeline_blocks): Chain blocks to execute in sequence
- [ConditionalPipelineBlocks](./auto_pipeline_blocks): Create conditional blocks that select different execution paths
- [LoopSequentialPipelineBlocks](./loop_sequential_pipeline_blocks): Define an iterative workflows like the denoising loop
</hfoption>
<hfoption id="Use in Mellon">
Make your custom block work with Mellon's visual interface. See the [Mellon Custom Blocks](./mellon) guide.
</hfoption>
<hfoption id="Explore existing blocks">
Browse the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for inspiration and ready-to-use blocks.
</hfoption>
</hfoptions>

View File

@@ -0,0 +1,270 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
## Using Custom Blocks with Mellon
[Mellon](https://github.com/cubiq/Mellon) is a visual workflow interface that integrates with Modular Diffusers and is designed for node-based workflows.
> [!WARNING]
> Mellon is in early development and not ready for production use yet. Consider this a sneak peek of how the integration works!
Custom blocks work in Mellon out of the box - just need to add a `mellon_pipeline_config.json` to your repository. This config file tells Mellon how to render your block's parameters as UI components.
Here's what it looks like in action with the [Gemini Prompt Expander](https://huggingface.co/diffusers/gemini-prompt-expander-mellon) block:
![Mellon custom block demo](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/modular_demo_dynamic.gif)
To use a modular diffusers custom block in Mellon:
1. Drag a **Dynamic Block Node** from the ModularDiffusers section
2. Enter the `repo_id` (e.g., `diffusers/gemini-prompt-expander-mellon`)
3. Click **Load Custom Block**
4. The node transforms to show your block's inputs and outputs
Now let's walk through how to create this config for your own custom block.
## Steps to create a Mellon config
1. **Specify Mellon types for your parameters** - Each `InputParam`/`OutputParam` needs a type that tells Mellon what UI component to render (e.g., `"textbox"`, `"dropdown"`, `"image"`).
2. **Generate `mellon_pipeline_config.json`** - Use our utility to generate a config template and push it to your Hub repository.
3. **(Optional) Manually adjust the config** - Fine-tune the generated config for your specific needs.
## Specify Mellon types for parameters
Mellon types determine how each parameter renders in the UI. If you don't specify a type for a parameter, it will default to `"custom"`, which renders as a simple connection dot. You can always adjust this later in the generated config.
| Type | Input/Output | Description |
|------|--------------|-------------|
| `image` | Both | Image (PIL Image) |
| `video` | Both | Video |
| `text` | Both | Text display |
| `textbox` | Input | Text input |
| `dropdown` | Input | Dropdown selection menu |
| `slider` | Input | Slider for numeric values |
| `number` | Input | Numeric input |
| `checkbox` | Input | Boolean toggle |
For parameters that need more configuration (like dropdowns with options, or sliders with min/max values), pass a `MellonParam` instance directly instead of a string. You can use one of the class methods below, or create a fully custom one with `MellonParam(name, label, type, ...)`.
| Method | Description |
|--------|-------------|
| `MellonParam.Input.image(name)` | Image input |
| `MellonParam.Input.textbox(name, default)` | Text input as textarea |
| `MellonParam.Input.dropdown(name, options, default)` | Dropdown selection |
| `MellonParam.Input.slider(name, default, min, max, step)` | Slider for numeric values |
| `MellonParam.Input.number(name, default, min, max, step)` | Numeric input (no slider) |
| `MellonParam.Input.seed(name, default)` | Seed input with randomize button |
| `MellonParam.Input.checkbox(name, default)` | Boolean checkbox |
| `MellonParam.Input.model(name)` | Model input for diffusers components |
| `MellonParam.Output.image(name)` | Image output |
| `MellonParam.Output.video(name)` | Video output |
| `MellonParam.Output.text(name)` | Text output |
| `MellonParam.Output.model(name)` | Model output for diffusers components |
Choose one of the methods below to specify a Mellon type.
### Using `metadata` in block definitions
If you're defining a custom block from scratch, add `metadata={"mellon": "<type>"}` directly to your `InputParam` and `OutputParam` definitions. If you're editing an existing custom block from the Hub, see [Editing custom blocks](./custom_blocks#editing-custom-blocks) for how to download it locally.
```python
class GeminiPromptExpander(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
"prompt",
type_hint=str,
required=True,
description="Prompt to use",
metadata={"mellon": "textbox"}, # Text input
)
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"prompt",
type_hint=str,
description="Expanded prompt by the LLM",
metadata={"mellon": "text"}, # Text output
),
OutputParam(
"old_prompt",
type_hint=str,
description="Old prompt provided by the user",
# No metadata - we don't want to render this in UI
)
]
```
For full control over UI configuration, pass a `MellonParam` instance directly:
```python
from diffusers.modular_pipelines.mellon_node_utils import MellonParam
InputParam(
"mode",
type_hint=str,
default="balanced",
metadata={"mellon": MellonParam.Input.dropdown("mode", options=["fast", "balanced", "quality"])},
)
```
### Using `input_types` and `output_types` when Generating Config
If you're working with an existing pipeline or prefer to keep your block definitions clean, specify types when generating the config using the `input_types/output_types` argument:
```python
from diffusers.modular_pipelines.mellon_node_utils import MellonPipelineConfig
mellon_config = MellonPipelineConfig.from_custom_block(
blocks,
input_types={"prompt": "textbox"},
output_types={"prompt": "text"}
)
```
> [!NOTE]
> When both `metadata` and `input_types`/`output_types` are specified, the arguments overrides `metadata`.
## Generate and push the Mellon config
After adding metadata to your block, generate the default Mellon configuration template and push it to the Hub:
```python
from diffusers import ModularPipelineBlocks
from diffusers.modular_pipelines.mellon_node_utils import MellonPipelineConfig
# load your custom blocks from your local dir
blocks = ModularPipelineBlocks.from_pretrained("/path/local/folder", trust_remote_code=True)
# Generate the default config template
mellon_config = MellonPipelineConfig.from_custom_block(blocks)
# push the default template to `repo_id`, you will need to pass the same local folder path so that it will save the config locally first
mellon_config.save(
local_dir="/path/local/folder",
repo_id= repo_id,
push_to_hub=True
)
```
This creates a `mellon_pipeline_config.json` file in your repository.
## Review and adjust the config
The generated template is a starting point - you may want to adjust it for your needs. Let's walk through the generated config for the Gemini Prompt Expander:
```json
{
"label": "Gemini Prompt Expander",
"default_repo": "",
"default_dtype": "",
"node_params": {
"custom": {
"params": {
"prompt": {
"label": "Prompt",
"type": "string",
"display": "textarea",
"default": ""
},
"out_prompt": {
"label": "Prompt",
"type": "string",
"display": "output"
},
"old_prompt": {
"label": "Old Prompt",
"type": "custom",
"display": "output"
},
"doc": {
"label": "Doc",
"type": "string",
"display": "output"
}
},
"input_names": ["prompt"],
"model_input_names": [],
"output_names": ["out_prompt", "old_prompt", "doc"],
"block_name": "custom",
"node_type": "custom"
}
}
}
```
### Understanding the Structure
The `params` dict defines how each UI element renders. The `input_names`, `model_input_names`, and `output_names` lists map these UI elements to the underlying [`ModularPipelineBlocks`]'s I/O interface:
| Mellon Config | ModularPipelineBlocks |
|---------------|----------------------|
| `input_names` | `inputs` property |
| `model_input_names` | `expected_components` property |
| `output_names` | `intermediate_outputs` property |
In this example: `prompt` is the only input. There are no model components, and outputs include `out_prompt`, `old_prompt`, and `doc`.
Now let's look at the `params` dict:
- **`prompt`**: An input parameter with `display: "textarea"` (renders as a text input box), `label: "Prompt"` (shown in the UI), and `default: ""` (starts empty). The `type: "string"` field is important in Mellon because it determines which nodes can connect together - only matching types can be linked with "noodles".
- **`out_prompt`**: The expanded prompt output. The `out_` prefix was automatically added because the input and output share the same name (`prompt`), avoiding naming conflicts in the config. It has `display: "output"` which renders as an output socket.
- **`old_prompt`**: Has `type: "custom"` because we didn't specify metadata. This renders as a simple dot in the UI. Since we don't actually want to expose this in the UI, we can remove it.
- **`doc`**: The documentation output, automatically added to all custom blocks.
### Making Adjustments
Remove `old_prompt` from both `params` and `output_names` because you won't need to use it.
```json
{
"label": "Gemini Prompt Expander",
"default_repo": "",
"default_dtype": "",
"node_params": {
"custom": {
"params": {
"prompt": {
"label": "Prompt",
"type": "string",
"display": "textarea",
"default": ""
},
"out_prompt": {
"label": "Prompt",
"type": "string",
"display": "output"
},
"doc": {
"label": "Doc",
"type": "string",
"display": "output"
}
},
"input_names": ["prompt"],
"model_input_names": [],
"output_names": ["out_prompt", "doc"],
"block_name": "custom",
"node_type": "custom"
}
}
}
```
See the final config at [diffusers/gemini-prompt-expander-mellon](https://huggingface.co/diffusers/gemini-prompt-expander-mellon).

View File

@@ -33,9 +33,14 @@ The Modular Diffusers docs are organized as shown below.
- [SequentialPipelineBlocks](./sequential_pipeline_blocks) is a type of block that chains multiple blocks so they run one after another, passing data along the chain. This guide shows you how to create [`~modular_pipelines.SequentialPipelineBlocks`] and how they connect and work together.
- [LoopSequentialPipelineBlocks](./loop_sequential_pipeline_blocks) is a type of block that runs a series of blocks in a loop. This guide shows you how to create [`~modular_pipelines.LoopSequentialPipelineBlocks`].
- [AutoPipelineBlocks](./auto_pipeline_blocks) is a type of block that automatically chooses which blocks to run based on the input. This guide shows you how to create [`~modular_pipelines.AutoPipelineBlocks`].
- [Building Custom Blocks](./custom_blocks) shows you how to create your own custom blocks and share them on the Hub.
## ModularPipeline
- [ModularPipeline](./modular_pipeline) shows you how to create and convert pipeline blocks into an executable [`ModularPipeline`].
- [ComponentsManager](./components_manager) shows you how to manage and reuse components across multiple pipelines.
- [Guiders](./guiders) shows you how to use different guidance methods in the pipeline.
- [Guiders](./guiders) shows you how to use different guidance methods in the pipeline.
## Mellon Integration
- [Using Custom Blocks with Mellon](./mellon) shows you how to make your custom blocks work with [Mellon](https://github.com/cubiq/Mellon), a visual node-based interface for building workflows.

View File

@@ -66,7 +66,7 @@ from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConf
from torchao.quantization import Int4WeightOnlyConfig
pipeline_quant_config = PipelineQuantizationConfig(
quant_mapping={"transformer": TorchAoConfig(Int4WeightOnlyConfig(group_size=128)))}
quant_mapping={"transformer": TorchAoConfig(Int4WeightOnlyConfig(group_size=128))}
)
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",

View File

@@ -18,7 +18,7 @@ from typing import Optional, Union
from huggingface_hub.utils import validate_hf_hub_args
from ..configuration_utils import ConfigMixin
from ..utils import logging
from ..utils import DIFFUSERS_LOAD_ID_FIELDS, logging
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
@@ -220,4 +220,11 @@ class AutoModel(ConfigMixin):
raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")
kwargs = {**load_config_kwargs, **kwargs}
return model_cls.from_pretrained(pretrained_model_or_path, **kwargs)
model = model_cls.from_pretrained(pretrained_model_or_path, **kwargs)
load_id_kwargs = {"pretrained_model_name_or_path": pretrained_model_or_path, **kwargs}
parts = [load_id_kwargs.get(field, "null") for field in DIFFUSERS_LOAD_ID_FIELDS]
load_id = "|".join("null" if p is None else p for p in parts)
model._diffusers_load_id = load_id
return model

View File

@@ -324,6 +324,7 @@ class ComponentsManager:
"has_hook",
"execution_device",
"ip_adapter",
"quantization",
]
def __init__(self):
@@ -356,7 +357,9 @@ class ComponentsManager:
ids_by_name.add(component_id)
else:
ids_by_name = set(components.keys())
if collection:
if collection and collection not in self.collections:
return set()
elif collection and collection in self.collections:
ids_by_collection = set()
for component_id, component in components.items():
if component_id in self.collections[collection]:
@@ -423,7 +426,8 @@ class ComponentsManager:
# add component to components manager
self.components[component_id] = component
self.added_time[component_id] = time.time()
if is_new_component:
self.added_time[component_id] = time.time()
if collection:
if collection not in self.collections:
@@ -760,7 +764,6 @@ class ComponentsManager:
self.model_hooks = None
self._auto_offload_enabled = False
# YiYi TODO: (1) add quantization info
def get_model_info(
self,
component_id: str,
@@ -836,6 +839,17 @@ class ComponentsManager:
if scales:
info["ip_adapter"] = summarize_dict_by_value_and_parts(scales)
# Check for quantization
hf_quantizer = getattr(component, "hf_quantizer", None)
if hf_quantizer is not None:
quant_config = hf_quantizer.quantization_config
if hasattr(quant_config, "to_diff_dict"):
info["quantization"] = quant_config.to_diff_dict()
else:
info["quantization"] = quant_config.to_dict()
else:
info["quantization"] = None
# If fields specified, filter info
if fields is not None:
return {k: v for k, v in info.items() if k in fields}
@@ -966,12 +980,16 @@ class ComponentsManager:
output += "\nAdditional Component Info:\n" + "=" * 50 + "\n"
for name in self.components:
info = self.get_model_info(name)
if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")):
if info is not None and (
info.get("adapters") is not None or info.get("ip_adapter") or info.get("quantization")
):
output += f"\n{name}:\n"
if info.get("adapters") is not None:
output += f" Adapters: {info['adapters']}\n"
if info.get("ip_adapter"):
output += " IP-Adapter: Enabled\n"
if info.get("quantization"):
output += f" Quantization: {info['quantization']}\n"
return output

File diff suppressed because it is too large Load Diff

View File

@@ -34,6 +34,7 @@ from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
from .components_manager import ComponentsManager
from .modular_pipeline_utils import (
MODULAR_MODEL_CARD_TEMPLATE,
ComponentSpec,
ConfigSpec,
InputParam,
@@ -41,6 +42,7 @@ from .modular_pipeline_utils import (
OutputParam,
format_components,
format_configs,
generate_modular_model_card_content,
make_doc_string,
)
@@ -1753,9 +1755,19 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
# Generate modular pipeline card content
card_content = generate_modular_model_card_content(self.blocks)
# Create a new empty model card and eventually tag it
model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True)
model_card = populate_model_card(model_card)
model_card = load_or_create_model_card(
repo_id,
token=token,
is_pipeline=True,
model_description=MODULAR_MODEL_CARD_TEMPLATE.format(**card_content),
is_modular=True,
)
model_card = populate_model_card(model_card, tags=card_content["tags"])
model_card.save(os.path.join(save_directory, "README.md"))
# YiYi TODO: maybe order the json file to make it more readable: configs first, then components
@@ -2143,6 +2155,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
name
for name in self._component_specs.keys()
if self._component_specs[name].default_creation_method == "from_pretrained"
and self._component_specs[name].pretrained_model_name_or_path is not None
and getattr(self, name, None) is None
]
elif isinstance(names, str):
names = [names]

View File

@@ -15,7 +15,7 @@
import inspect
import re
from collections import OrderedDict
from dataclasses import dataclass, field, fields
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional, Type, Union
import PIL.Image
@@ -23,7 +23,7 @@ import torch
from ..configuration_utils import ConfigMixin, FrozenDict
from ..loaders.single_file_utils import _is_single_file_path_or_url
from ..utils import is_torch_available, logging
from ..utils import DIFFUSERS_LOAD_ID_FIELDS, is_torch_available, logging
if is_torch_available():
@@ -31,6 +31,30 @@ if is_torch_available():
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Template for modular pipeline model card description with placeholders
MODULAR_MODEL_CARD_TEMPLATE = """{model_description}
## Example Usage
[TODO]
## Pipeline Architecture
This modular pipeline is composed of the following blocks:
{blocks_description} {trigger_inputs_section}
## Model Components
{components_description} {configs_section}
## Input/Output Specification
### Inputs {inputs_description}
### Outputs {outputs_description}
"""
class InsertableDict(OrderedDict):
def insert(self, key, value, index):
@@ -186,7 +210,7 @@ class ComponentSpec:
"""
Return the names of all loadingrelated fields (i.e. those whose field.metadata["loading"] is True).
"""
return [f.name for f in fields(cls) if f.metadata.get("loading", False)]
return DIFFUSERS_LOAD_ID_FIELDS.copy()
@property
def load_id(self) -> str:
@@ -198,7 +222,7 @@ class ComponentSpec:
return "null"
parts = [getattr(self, k) for k in self.loading_fields()]
parts = ["null" if p is None else p for p in parts]
return "|".join(p for p in parts if p)
return "|".join(parts)
@classmethod
def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]:
@@ -520,6 +544,7 @@ class InputParam:
required: bool = False
description: str = ""
kwargs_type: str = None
metadata: Dict[str, Any] = None
def __repr__(self):
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
@@ -553,6 +578,7 @@ class OutputParam:
type_hint: Any = None
description: str = ""
kwargs_type: str = None
metadata: Dict[str, Any] = None
def __repr__(self):
return (
@@ -914,3 +940,178 @@ def make_doc_string(
output += format_output_params(outputs, indent_level=2)
return output
def generate_modular_model_card_content(blocks) -> Dict[str, Any]:
"""
Generate model card content for a modular pipeline.
This function creates a comprehensive model card with descriptions of the pipeline's architecture, components,
configurations, inputs, and outputs.
Args:
blocks: The pipeline's blocks object containing all pipeline specifications
Returns:
Dict[str, Any]: A dictionary containing formatted content sections:
- pipeline_name: Name of the pipeline
- model_description: Overall description with pipeline type
- blocks_description: Detailed architecture of blocks
- components_description: List of required components
- configs_section: Configuration parameters section
- inputs_description: Input parameters specification
- outputs_description: Output parameters specification
- trigger_inputs_section: Conditional execution information
- tags: List of relevant tags for the model card
"""
blocks_class_name = blocks.__class__.__name__
pipeline_name = blocks_class_name.replace("Blocks", " Pipeline")
description = getattr(blocks, "description", "A modular diffusion pipeline.")
# generate blocks architecture description
blocks_desc_parts = []
sub_blocks = getattr(blocks, "sub_blocks", None) or {}
if sub_blocks:
for i, (name, block) in enumerate(sub_blocks.items()):
block_class = block.__class__.__name__
block_desc = block.description.split("\n")[0] if getattr(block, "description", "") else ""
blocks_desc_parts.append(f"{i + 1}. **{name}** (`{block_class}`)")
if block_desc:
blocks_desc_parts.append(f" - {block_desc}")
# add sub-blocks if any
if hasattr(block, "sub_blocks") and block.sub_blocks:
for sub_name, sub_block in block.sub_blocks.items():
sub_class = sub_block.__class__.__name__
sub_desc = sub_block.description.split("\n")[0] if getattr(sub_block, "description", "") else ""
blocks_desc_parts.append(f" - *{sub_name}*: `{sub_class}`")
if sub_desc:
blocks_desc_parts.append(f" - {sub_desc}")
blocks_description = "\n".join(blocks_desc_parts) if blocks_desc_parts else "No blocks defined."
components = getattr(blocks, "expected_components", [])
if components:
components_str = format_components(components, indent_level=0, add_empty_lines=False)
# remove the "Components:" header since template has its own
components_description = components_str.replace("Components:\n", "").strip()
if components_description:
# Convert to enumerated list
lines = [line.strip() for line in components_description.split("\n") if line.strip()]
enumerated_lines = [f"{i + 1}. {line}" for i, line in enumerate(lines)]
components_description = "\n".join(enumerated_lines)
else:
components_description = "No specific components required."
else:
components_description = "No specific components required. Components can be loaded dynamically."
configs = getattr(blocks, "expected_configs", [])
configs_section = ""
if configs:
configs_str = format_configs(configs, indent_level=0, add_empty_lines=False)
configs_description = configs_str.replace("Configs:\n", "").strip()
if configs_description:
configs_section = f"\n\n## Configuration Parameters\n\n{configs_description}"
inputs = blocks.inputs
outputs = blocks.outputs
# format inputs as markdown list
inputs_parts = []
required_inputs = [inp for inp in inputs if inp.required]
optional_inputs = [inp for inp in inputs if not inp.required]
if required_inputs:
inputs_parts.append("**Required:**\n")
for inp in required_inputs:
if hasattr(inp.type_hint, "__name__"):
type_str = inp.type_hint.__name__
elif inp.type_hint is not None:
type_str = str(inp.type_hint).replace("typing.", "")
else:
type_str = "Any"
desc = inp.description or "No description provided"
inputs_parts.append(f"- `{inp.name}` (`{type_str}`): {desc}")
if optional_inputs:
if required_inputs:
inputs_parts.append("")
inputs_parts.append("**Optional:**\n")
for inp in optional_inputs:
if hasattr(inp.type_hint, "__name__"):
type_str = inp.type_hint.__name__
elif inp.type_hint is not None:
type_str = str(inp.type_hint).replace("typing.", "")
else:
type_str = "Any"
desc = inp.description or "No description provided"
default_str = f", default: `{inp.default}`" if inp.default is not None else ""
inputs_parts.append(f"- `{inp.name}` (`{type_str}`){default_str}: {desc}")
inputs_description = "\n".join(inputs_parts) if inputs_parts else "No specific inputs defined."
# format outputs as markdown list
outputs_parts = []
for out in outputs:
if hasattr(out.type_hint, "__name__"):
type_str = out.type_hint.__name__
elif out.type_hint is not None:
type_str = str(out.type_hint).replace("typing.", "")
else:
type_str = "Any"
desc = out.description or "No description provided"
outputs_parts.append(f"- `{out.name}` (`{type_str}`): {desc}")
outputs_description = "\n".join(outputs_parts) if outputs_parts else "Standard pipeline outputs."
trigger_inputs_section = ""
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
trigger_inputs_list = sorted([t for t in blocks.trigger_inputs if t is not None])
if trigger_inputs_list:
trigger_inputs_str = ", ".join(f"`{t}`" for t in trigger_inputs_list)
trigger_inputs_section = f"""
### Conditional Execution
This pipeline contains blocks that are selected at runtime based on inputs:
- **Trigger Inputs**: {trigger_inputs_str}
"""
# generate tags based on pipeline characteristics
tags = ["modular-diffusers", "diffusers"]
if hasattr(blocks, "model_name") and blocks.model_name:
tags.append(blocks.model_name)
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
triggers = blocks.trigger_inputs
if any(t in triggers for t in ["mask", "mask_image"]):
tags.append("inpainting")
if any(t in triggers for t in ["image", "image_latents"]):
tags.append("image-to-image")
if any(t in triggers for t in ["control_image", "controlnet_cond"]):
tags.append("controlnet")
if not any(t in triggers for t in ["image", "mask", "image_latents", "mask_image"]):
tags.append("text-to-image")
else:
tags.append("text-to-image")
block_count = len(blocks.sub_blocks)
model_description = f"""This is a modular diffusion pipeline built with 🧨 Diffusers' modular pipeline framework.
**Pipeline Type**: {blocks_class_name}
**Description**: {description}
This pipeline uses a {block_count}-block architecture that can be customized and extended."""
return {
"pipeline_name": pipeline_name,
"model_description": model_description,
"blocks_description": blocks_description,
"components_description": components_description,
"configs_section": configs_section,
"inputs_description": inputs_description,
"outputs_description": outputs_description,
"trigger_inputs_section": trigger_inputs_section,
"tags": tags,
}

View File

@@ -23,6 +23,7 @@ from .constants import (
DEFAULT_HF_PARALLEL_LOADING_WORKERS,
DEPRECATED_REVISION_ARGS,
DIFFUSERS_DYNAMIC_MODULE_NAME,
DIFFUSERS_LOAD_ID_FIELDS,
FLAX_WEIGHTS_NAME,
GGUF_FILE_EXTENSION,
HF_ENABLE_PARALLEL_LOADING,

View File

@@ -73,3 +73,11 @@ DECODE_ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoint
ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/"
ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/"
ENCODE_ENDPOINT_FLUX = "https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/"
DIFFUSERS_LOAD_ID_FIELDS = [
"pretrained_model_name_or_path",
"subfolder",
"variant",
"revision",
]

View File

@@ -107,6 +107,7 @@ def load_or_create_model_card(
license: Optional[str] = None,
widget: Optional[List[dict]] = None,
inference: Optional[bool] = None,
is_modular: bool = False,
) -> ModelCard:
"""
Loads or creates a model card.
@@ -131,6 +132,8 @@ def load_or_create_model_card(
widget (`List[dict]`, *optional*): Widget to accompany a gallery template.
inference: (`bool`, optional): Whether to turn on inference widget. Helpful when using
`load_or_create_model_card` from a training script.
is_modular: (`bool`, optional): Boolean flag to denote if the model card is for a modular pipeline.
When True, uses model_description as-is without additional template formatting.
"""
if not is_jinja_available():
raise ValueError(
@@ -159,10 +162,14 @@ def load_or_create_model_card(
)
else:
card_data = ModelCardData()
component = "pipeline" if is_pipeline else "model"
if model_description is None:
model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated."
model_card = ModelCard.from_template(card_data, model_description=model_description)
if is_modular and model_description is not None:
model_card = ModelCard(model_description)
model_card.data = card_data
else:
component = "pipeline" if is_pipeline else "model"
if model_description is None:
model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated."
model_card = ModelCard.from_template(card_data, model_description=model_description)
return model_card

View File

@@ -12,52 +12,57 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import WanTransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
GGUFCompileTesterMixin,
GGUFTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class WanTransformer3DTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return WanTransformer3DModel
class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = WanTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def output_shape(self) -> tuple[int, ...]:
return (4, 2, 16, 16)
def dummy_input(self):
batch_size = 1
num_channels = 4
num_frames = 2
height = 16
width = 16
text_encoder_embedding_dim = 16
sequence_length = 12
@property
def input_shape(self) -> tuple[int, ...]:
return (4, 2, 16, 16)
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool]:
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
}
@property
def input_shape(self):
return (4, 1, 16, 16)
@property
def output_shape(self):
return (4, 1, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": (1, 2, 2),
"num_attention_heads": 2,
"attention_head_dim": 12,
@@ -71,118 +76,16 @@ class WanTransformer3DTesterConfig(BaseModelTesterConfig):
"qk_norm": "rms_norm_across_heads",
"rope_max_seq_len": 32,
}
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
num_channels = 4
num_frames = 2
height = 16
width = 16
text_encoder_embedding_dim = 16
sequence_length = 12
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width),
generator=self.generator,
device=torch_device,
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_encoder_embedding_dim),
generator=self.generator,
device=torch_device,
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
}
class TestWanTransformer3D(WanTransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for Wan Transformer 3D."""
class TestWanTransformer3DMemory(WanTransformer3DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Wan Transformer 3D."""
class TestWanTransformer3DTraining(WanTransformer3DTesterConfig, TrainingTesterMixin):
"""Training tests for Wan Transformer 3D."""
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"WanTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestWanTransformer3DAttention(WanTransformer3DTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Wan Transformer 3D."""
class WanTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = WanTransformer3DModel
class TestWanTransformer3DCompile(WanTransformer3DTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for Wan Transformer 3D."""
class TestWanTransformer3DBitsAndBytes(WanTransformer3DTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for Wan Transformer 3D."""
class TestWanTransformer3DTorchAo(WanTransformer3DTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for Wan Transformer 3D."""
class TestWanTransformer3DGGUF(WanTransformer3DTesterConfig, GGUFTesterMixin):
"""GGUF quantization tests for Wan Transformer 3D."""
@property
def gguf_filename(self):
return "https://huggingface.co/QuantStack/Wan2.2-I2V-A14B-GGUF/blob/main/LowNoise/Wan2.2-I2V-A14B-LowNoise-Q2_K.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the real Wan I2V model dimensions.
Wan 2.2 I2V: in_channels=36, text_dim=4096, image_dim=1280
"""
return {
"hidden_states": randn_tensor(
(1, 36, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states_image": randn_tensor(
(1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
class TestWanTransformer3DGGUFCompile(WanTransformer3DTesterConfig, GGUFCompileTesterMixin):
"""GGUF + compile tests for Wan Transformer 3D."""
@property
def gguf_filename(self):
return "https://huggingface.co/QuantStack/Wan2.2-I2V-A14B-GGUF/blob/main/LowNoise/Wan2.2-I2V-A14B-LowNoise-Q2_K.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the real Wan I2V model dimensions.
Wan 2.2 I2V: in_channels=36, text_dim=4096, image_dim=1280
"""
return {
"hidden_states": randn_tensor(
(1, 36, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states_image": randn_tensor(
(1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
def prepare_init_args_and_inputs_for_common(self):
return WanTransformer3DTests().prepare_init_args_and_inputs_for_common()

View File

@@ -12,57 +12,76 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import WanAnimateTransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
GGUFCompileTesterMixin,
GGUFTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class WanAnimateTransformer3DTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return WanAnimateTransformer3DModel
class WanAnimateTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = WanAnimateTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def output_shape(self) -> tuple[int, ...]:
# Output has fewer channels than input (4 vs 12)
return (4, 21, 16, 16)
def dummy_input(self):
batch_size = 1
num_channels = 4
num_frames = 20 # To make the shapes work out; for complicated reasons we want 21 to divide num_frames + 1
height = 16
width = 16
text_encoder_embedding_dim = 16
sequence_length = 12
clip_seq_len = 12
clip_dim = 16
inference_segment_length = 77 # The inference segment length in the full Wan2.2-Animate-14B model
face_height = 16 # Should be square and match `motion_encoder_size` below
face_width = 16
hidden_states = torch.randn((batch_size, 2 * num_channels + 4, num_frames + 1, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
clip_ref_features = torch.randn((batch_size, clip_seq_len, clip_dim)).to(torch_device)
pose_latents = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
face_pixel_values = torch.randn((batch_size, 3, inference_segment_length, face_height, face_width)).to(
torch_device
)
return {
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"encoder_hidden_states_image": clip_ref_features,
"pose_hidden_states": pose_latents,
"face_pixel_values": face_pixel_values,
}
@property
def input_shape(self) -> tuple[int, ...]:
return (12, 21, 16, 16)
def input_shape(self):
return (12, 1, 16, 16)
@property
def main_input_name(self) -> str:
return "hidden_states"
def output_shape(self):
return (4, 1, 16, 16)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool | float | dict]:
def prepare_init_args_and_inputs_for_common(self):
# Use custom channel sizes since the default Wan Animate channel sizes will cause the motion encoder to
# contain the vast majority of the parameters in the test model
channel_sizes = {"4": 16, "8": 16, "16": 16}
return {
init_dict = {
"patch_size": (1, 2, 2),
"num_attention_heads": 2,
"attention_head_dim": 12,
@@ -86,158 +105,22 @@ class WanAnimateTransformer3DTesterConfig(BaseModelTesterConfig):
"face_encoder_num_heads": 2,
"inject_face_latents_blocks": 2,
}
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
num_channels = 4
num_frames = 20 # To make the shapes work out; for complicated reasons we want 21 to divide num_frames + 1
height = 16
width = 16
text_encoder_embedding_dim = 16
sequence_length = 12
clip_seq_len = 12
clip_dim = 16
inference_segment_length = 77 # The inference segment length in the full Wan2.2-Animate-14B model
face_height = 16 # Should be square and match `motion_encoder_size`
face_width = 16
return {
"hidden_states": randn_tensor(
(batch_size, 2 * num_channels + 4, num_frames + 1, height, width),
generator=self.generator,
device=torch_device,
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_encoder_embedding_dim),
generator=self.generator,
device=torch_device,
),
"encoder_hidden_states_image": randn_tensor(
(batch_size, clip_seq_len, clip_dim),
generator=self.generator,
device=torch_device,
),
"pose_hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width),
generator=self.generator,
device=torch_device,
),
"face_pixel_values": randn_tensor(
(batch_size, 3, inference_segment_length, face_height, face_width),
generator=self.generator,
device=torch_device,
),
}
class TestWanAnimateTransformer3D(WanAnimateTransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for Wan Animate Transformer 3D."""
def test_output(self):
# Override test_output because the transformer output is expected to have less channels
# than the main transformer input.
expected_output_shape = (1, 4, 21, 16, 16)
super().test_output(expected_output_shape=expected_output_shape)
class TestWanAnimateTransformer3DMemory(WanAnimateTransformer3DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Wan Animate Transformer 3D."""
class TestWanAnimateTransformer3DTraining(WanAnimateTransformer3DTesterConfig, TrainingTesterMixin):
"""Training tests for Wan Animate Transformer 3D."""
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"WanAnimateTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestWanAnimateTransformer3DAttention(WanAnimateTransformer3DTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Wan Animate Transformer 3D."""
# Override test_output because the transformer output is expected to have less channels than the main transformer
# input.
def test_output(self):
expected_output_shape = (1, 4, 21, 16, 16)
super().test_output(expected_output_shape=expected_output_shape)
class TestWanAnimateTransformer3DCompile(WanAnimateTransformer3DTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for Wan Animate Transformer 3D."""
class WanAnimateTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = WanAnimateTransformer3DModel
class TestWanAnimateTransformer3DBitsAndBytes(WanAnimateTransformer3DTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for Wan Animate Transformer 3D."""
class TestWanAnimateTransformer3DTorchAo(WanAnimateTransformer3DTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for Wan Animate Transformer 3D."""
class TestWanAnimateTransformer3DGGUF(WanAnimateTransformer3DTesterConfig, GGUFTesterMixin):
"""GGUF quantization tests for Wan Animate Transformer 3D."""
@property
def gguf_filename(self):
return "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q2_K.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the real Wan Animate model dimensions.
Wan 2.2 Animate: in_channels=36 (2*16+4), text_dim=4096, image_dim=1280
"""
return {
"hidden_states": randn_tensor(
(1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states_image": randn_tensor(
(1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"pose_hidden_states": randn_tensor(
(1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"face_pixel_values": randn_tensor(
(1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
class TestWanAnimateTransformer3DGGUFCompile(WanAnimateTransformer3DTesterConfig, GGUFCompileTesterMixin):
"""GGUF + compile tests for Wan Animate Transformer 3D."""
@property
def gguf_filename(self):
return "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q2_K.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the real Wan Animate model dimensions.
Wan 2.2 Animate: in_channels=36 (2*16+4), text_dim=4096, image_dim=1280
"""
return {
"hidden_states": randn_tensor(
(1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states_image": randn_tensor(
(1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"pose_hidden_states": randn_tensor(
(1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"face_pixel_values": randn_tensor(
(1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
def prepare_init_args_and_inputs_for_common(self):
return WanAnimateTransformer3DTests().prepare_init_args_and_inputs_for_common()

View File

@@ -1,198 +0,0 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from diffusers import WanVACETransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
GGUFCompileTesterMixin,
GGUFTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
enable_full_determinism()
class WanVACETransformer3DTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return WanVACETransformer3DModel
@property
def output_shape(self) -> tuple[int, ...]:
return (16, 2, 16, 16)
@property
def input_shape(self) -> tuple[int, ...]:
return (16, 2, 16, 16)
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool | None]:
return {
"patch_size": (1, 2, 2),
"num_attention_heads": 2,
"attention_head_dim": 12,
"in_channels": 16,
"out_channels": 16,
"text_dim": 32,
"freq_dim": 256,
"ffn_dim": 32,
"num_layers": 4,
"cross_attn_norm": True,
"qk_norm": "rms_norm_across_heads",
"rope_max_seq_len": 32,
"vace_layers": [0, 2],
"vace_in_channels": 48, # 3 * in_channels = 3 * 16 = 48
}
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
num_channels = 16
num_frames = 2
height = 16
width = 16
text_encoder_embedding_dim = 32
sequence_length = 12
# VACE requires control_hidden_states with vace_in_channels (3 * in_channels)
vace_in_channels = 48
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width),
generator=self.generator,
device=torch_device,
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_encoder_embedding_dim),
generator=self.generator,
device=torch_device,
),
"control_hidden_states": randn_tensor(
(batch_size, vace_in_channels, num_frames, height, width),
generator=self.generator,
device=torch_device,
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
}
class TestWanVACETransformer3D(WanVACETransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for Wan VACE Transformer 3D."""
class TestWanVACETransformer3DMemory(WanVACETransformer3DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Wan VACE Transformer 3D."""
class TestWanVACETransformer3DTraining(WanVACETransformer3DTesterConfig, TrainingTesterMixin):
"""Training tests for Wan VACE Transformer 3D."""
def test_gradient_checkpointing_is_applied(self):
expected_set = {"WanVACETransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestWanVACETransformer3DAttention(WanVACETransformer3DTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Wan VACE Transformer 3D."""
class TestWanVACETransformer3DCompile(WanVACETransformer3DTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for Wan VACE Transformer 3D."""
class TestWanVACETransformer3DBitsAndBytes(WanVACETransformer3DTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for Wan VACE Transformer 3D."""
class TestWanVACETransformer3DTorchAo(WanVACETransformer3DTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for Wan VACE Transformer 3D."""
class TestWanVACETransformer3DGGUF(WanVACETransformer3DTesterConfig, GGUFTesterMixin):
"""GGUF quantization tests for Wan VACE Transformer 3D."""
@property
def gguf_filename(self):
return "https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/blob/main/Wan2.1_14B_VACE-Q3_K_S.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the real Wan VACE model dimensions.
Wan 2.1 VACE: in_channels=16, text_dim=4096, vace_in_channels=96
"""
return {
"hidden_states": randn_tensor(
(1, 16, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"control_hidden_states": randn_tensor(
(1, 96, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
class TestWanVACETransformer3DGGUFCompile(WanVACETransformer3DTesterConfig, GGUFCompileTesterMixin):
"""GGUF + compile tests for Wan VACE Transformer 3D."""
@property
def gguf_filename(self):
return "https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/blob/main/Wan2.1_14B_VACE-Q3_K_S.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the real Wan VACE model dimensions.
Wan 2.1 VACE: in_channels=16, text_dim=4096, vace_in_channels=96
"""
return {
"hidden_states": randn_tensor(
(1, 16, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"control_hidden_states": randn_tensor(
(1, 96, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}

View File

@@ -8,6 +8,13 @@ import torch
import diffusers
from diffusers import ComponentsManager, ModularPipeline, ModularPipelineBlocks
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.modular_pipelines.modular_pipeline_utils import (
ComponentSpec,
ConfigSpec,
InputParam,
OutputParam,
generate_modular_model_card_content,
)
from diffusers.utils import logging
from ..testing_utils import backend_empty_cache, numpy_cosine_similarity_distance, require_accelerator, torch_device
@@ -335,3 +342,239 @@ class ModularGuiderTesterMixin:
assert out_cfg.shape == out_no_cfg.shape
max_diff = torch.abs(out_cfg - out_no_cfg).max()
assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"
class TestModularModelCardContent:
def create_mock_block(self, name="TestBlock", description="Test block description"):
class MockBlock:
def __init__(self, name, description):
self.__class__.__name__ = name
self.description = description
self.sub_blocks = {}
return MockBlock(name, description)
def create_mock_blocks(
self,
class_name="TestBlocks",
description="Test pipeline description",
num_blocks=2,
components=None,
configs=None,
inputs=None,
outputs=None,
trigger_inputs=None,
model_name=None,
):
class MockBlocks:
def __init__(self):
self.__class__.__name__ = class_name
self.description = description
self.sub_blocks = {}
self.expected_components = components or []
self.expected_configs = configs or []
self.inputs = inputs or []
self.outputs = outputs or []
self.trigger_inputs = trigger_inputs
self.model_name = model_name
blocks = MockBlocks()
# Add mock sub-blocks
for i in range(num_blocks):
block_name = f"block_{i}"
blocks.sub_blocks[block_name] = self.create_mock_block(f"Block{i}", f"Description for block {i}")
return blocks
def test_basic_model_card_content_structure(self):
"""Test that all expected keys are present in the output."""
blocks = self.create_mock_blocks()
content = generate_modular_model_card_content(blocks)
expected_keys = [
"pipeline_name",
"model_description",
"blocks_description",
"components_description",
"configs_section",
"inputs_description",
"outputs_description",
"trigger_inputs_section",
"tags",
]
for key in expected_keys:
assert key in content, f"Expected key '{key}' not found in model card content"
assert isinstance(content["tags"], list), "Tags should be a list"
def test_pipeline_name_generation(self):
"""Test that pipeline name is correctly generated from blocks class name."""
blocks = self.create_mock_blocks(class_name="StableDiffusionBlocks")
content = generate_modular_model_card_content(blocks)
assert content["pipeline_name"] == "StableDiffusion Pipeline"
def test_tags_generation_text_to_image(self):
"""Test that text-to-image tags are correctly generated."""
blocks = self.create_mock_blocks(trigger_inputs=None)
content = generate_modular_model_card_content(blocks)
assert "modular-diffusers" in content["tags"]
assert "diffusers" in content["tags"]
assert "text-to-image" in content["tags"]
def test_tags_generation_with_trigger_inputs(self):
"""Test that tags are correctly generated based on trigger inputs."""
# Test inpainting
blocks = self.create_mock_blocks(trigger_inputs=["mask", "prompt"])
content = generate_modular_model_card_content(blocks)
assert "inpainting" in content["tags"]
# Test image-to-image
blocks = self.create_mock_blocks(trigger_inputs=["image", "prompt"])
content = generate_modular_model_card_content(blocks)
assert "image-to-image" in content["tags"]
# Test controlnet
blocks = self.create_mock_blocks(trigger_inputs=["control_image", "prompt"])
content = generate_modular_model_card_content(blocks)
assert "controlnet" in content["tags"]
def test_tags_with_model_name(self):
"""Test that model name is included in tags when present."""
blocks = self.create_mock_blocks(model_name="stable-diffusion-xl")
content = generate_modular_model_card_content(blocks)
assert "stable-diffusion-xl" in content["tags"]
def test_components_description_formatting(self):
"""Test that components are correctly formatted."""
components = [
ComponentSpec(name="vae", description="VAE component"),
ComponentSpec(name="text_encoder", description="Text encoder component"),
]
blocks = self.create_mock_blocks(components=components)
content = generate_modular_model_card_content(blocks)
assert "vae" in content["components_description"]
assert "text_encoder" in content["components_description"]
# Should be enumerated
assert "1." in content["components_description"]
def test_components_description_empty(self):
"""Test handling of pipelines without components."""
blocks = self.create_mock_blocks(components=None)
content = generate_modular_model_card_content(blocks)
assert "No specific components required" in content["components_description"]
def test_configs_section_with_configs(self):
"""Test that configs section is generated when configs are present."""
configs = [
ConfigSpec(name="num_train_timesteps", default=1000, description="Number of training timesteps"),
]
blocks = self.create_mock_blocks(configs=configs)
content = generate_modular_model_card_content(blocks)
assert "## Configuration Parameters" in content["configs_section"]
def test_configs_section_empty(self):
"""Test that configs section is empty when no configs are present."""
blocks = self.create_mock_blocks(configs=None)
content = generate_modular_model_card_content(blocks)
assert content["configs_section"] == ""
def test_inputs_description_required_and_optional(self):
"""Test that required and optional inputs are correctly formatted."""
inputs = [
InputParam(name="prompt", type_hint=str, required=True, description="The input prompt"),
InputParam(name="num_steps", type_hint=int, required=False, default=50, description="Number of steps"),
]
blocks = self.create_mock_blocks(inputs=inputs)
content = generate_modular_model_card_content(blocks)
assert "**Required:**" in content["inputs_description"]
assert "**Optional:**" in content["inputs_description"]
assert "prompt" in content["inputs_description"]
assert "num_steps" in content["inputs_description"]
assert "default: `50`" in content["inputs_description"]
def test_inputs_description_empty(self):
"""Test handling of pipelines without specific inputs."""
blocks = self.create_mock_blocks(inputs=[])
content = generate_modular_model_card_content(blocks)
assert "No specific inputs defined" in content["inputs_description"]
def test_outputs_description_formatting(self):
"""Test that outputs are correctly formatted."""
outputs = [
OutputParam(name="images", type_hint=torch.Tensor, description="Generated images"),
]
blocks = self.create_mock_blocks(outputs=outputs)
content = generate_modular_model_card_content(blocks)
assert "images" in content["outputs_description"]
assert "Generated images" in content["outputs_description"]
def test_outputs_description_empty(self):
"""Test handling of pipelines without specific outputs."""
blocks = self.create_mock_blocks(outputs=[])
content = generate_modular_model_card_content(blocks)
assert "Standard pipeline outputs" in content["outputs_description"]
def test_trigger_inputs_section_with_triggers(self):
"""Test that trigger inputs section is generated when present."""
blocks = self.create_mock_blocks(trigger_inputs=["mask", "image"])
content = generate_modular_model_card_content(blocks)
assert "### Conditional Execution" in content["trigger_inputs_section"]
assert "`mask`" in content["trigger_inputs_section"]
assert "`image`" in content["trigger_inputs_section"]
def test_trigger_inputs_section_empty(self):
"""Test that trigger inputs section is empty when not present."""
blocks = self.create_mock_blocks(trigger_inputs=None)
content = generate_modular_model_card_content(blocks)
assert content["trigger_inputs_section"] == ""
def test_blocks_description_with_sub_blocks(self):
"""Test that blocks with sub-blocks are correctly described."""
class MockBlockWithSubBlocks:
def __init__(self):
self.__class__.__name__ = "ParentBlock"
self.description = "Parent block"
self.sub_blocks = {
"child1": self.create_child_block("ChildBlock1", "Child 1 description"),
"child2": self.create_child_block("ChildBlock2", "Child 2 description"),
}
def create_child_block(self, name, desc):
class ChildBlock:
def __init__(self):
self.__class__.__name__ = name
self.description = desc
return ChildBlock()
blocks = self.create_mock_blocks()
blocks.sub_blocks["parent"] = MockBlockWithSubBlocks()
content = generate_modular_model_card_content(blocks)
assert "parent" in content["blocks_description"]
assert "child1" in content["blocks_description"]
assert "child2" in content["blocks_description"]
def test_model_description_includes_block_count(self):
"""Test that model description includes the number of blocks."""
blocks = self.create_mock_blocks(num_blocks=5)
content = generate_modular_model_card_content(blocks)
assert "5-block architecture" in content["model_description"]