Mirror of https://github.com/huggingface/diffusers.git. Compare commits: main...flux-conti (2 commits: 82aaa3665a, 23fdf38fdf).
@@ -114,8 +114,6 @@
    title: Guiders
  - local: modular_diffusers/custom_blocks
    title: Building Custom Blocks
  - local: modular_diffusers/mellon
    title: Using Custom Blocks with Mellon
  title: Modular Diffusers
- isExpanded: false
  sections:

@@ -16,7 +16,7 @@ specific language governing permissions and limitations under the License.

[ModularPipelineBlocks](./pipeline_block) are the fundamental building blocks of a [`ModularPipeline`]. You can create custom blocks by defining their inputs, outputs, and computation logic. This guide demonstrates how to create and use a custom block.

> [!TIP]
> Explore the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for official custom blocks.
> Explore the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for official custom modular blocks like Nano Banana.

## Project Structure

@@ -31,58 +31,18 @@ Your custom block project should use the following structure:

- `block.py` contains the custom block implementation
- `modular_config.json` contains the metadata needed to load the block

## Quick Start with Template
## Example: Florence 2 Inpainting Block

The fastest way to create a custom block is to start from our template. The template provides a pre-configured project structure with `block.py` and `modular_config.json` files, plus commented examples showing how to define components, inputs, outputs, and the `__call__` method, so you can focus on your custom logic instead of boilerplate setup.
In this example we will create a custom block that uses the [Florence 2](https://huggingface.co/docs/transformers/model_doc/florence2) model to process an input image and generate a mask for inpainting.

### Download the template

The first step is to define the components that the block will use. In this case, we need the `Florence2ForConditionalGeneration` model and its corresponding processor, `AutoProcessor`. When defining components, we must specify the name of the component within our pipeline, the model class via `type_hint`, and a `pretrained_model_name_or_path` if we intend to load the model weights from a specific repository on the Hub.

```python
|
||||
from diffusers import ModularPipelineBlocks
|
||||
|
||||
model_id = "diffusers/custom-block-template"
|
||||
local_dir = model_id.split("/")[-1]
|
||||
|
||||
blocks = ModularPipelineBlocks.from_pretrained(
|
||||
model_id,
|
||||
trust_remote_code=True,
|
||||
local_dir=local_dir
|
||||
)
|
||||
```
|
||||
|
||||
This saves the template files to `custom-block-template/` locally. You can pass a different `local_dir` to save them to another location.
|
||||
|
||||
### Edit locally
|
||||
|
||||
Open `block.py` and implement your custom block. The template includes commented examples showing how to define each property. See the [Florence-2 example](#example-florence-2-image-annotator) below for a complete implementation.
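
If you want a sense of the overall shape before diving into the full example, a filled-in `block.py` typically defines `inputs`, `intermediate_outputs`, and a `__call__` method (plus `expected_components` if the block loads models). Below is a minimal sketch with illustrative names only; the Florence-2 example later in this guide shows a complete implementation.

```py
# Minimal sketch of a block.py (illustrative names; see the Florence-2 example for a full block)
from typing import List

from diffusers.modular_pipelines import (
    InputParam,
    ModularPipelineBlocks,
    OutputParam,
    PipelineState,
)


class MyCustomBlock(ModularPipelineBlocks):
    @property
    def inputs(self) -> List[InputParam]:
        return [InputParam("image", required=True, description="Image to process")]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [OutputParam("image", description="Processed image")]

    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.image = block_state.image  # replace with your custom logic
        self.set_block_state(state, block_state)
        return components, state
```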
|
||||
|
||||
### Test your block
|
||||
|
||||
```python
|
||||
from diffusers import ModularPipelineBlocks
|
||||
|
||||
blocks = ModularPipelineBlocks.from_pretrained(local_dir, trust_remote_code=True)
|
||||
pipeline = blocks.init_pipeline()
|
||||
output = pipeline(...) # your inputs here
|
||||
```
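
For the Florence-2 block built in the next section, such a test call might look like the following (illustrative inputs; the block's parameters are defined below):

```py
from diffusers.utils import load_image

# Example inputs for the Florence-2 annotator block described later in this guide
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg")
output = pipeline(
    image=image,
    annotation_prompt=["the car"],
    annotation_task="<REFERRING_EXPRESSION_SEGMENTATION>",
    annotation_output_type="mask_image",
)
```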
|
||||
|
||||
### Upload to the Hub
|
||||
|
||||
```python
|
||||
pipeline.save_pretrained(local_dir, repo_id="your-username/your-block-name", push_to_hub=True)
|
||||
```
|
||||
|
||||
## Example: Florence-2 Image Annotator
|
||||
|
||||
This example creates a custom block with [Florence-2](https://huggingface.co/docs/transformers/model_doc/florence2) to process an input image and generate a mask for inpainting.
|
||||
|
||||
### Define components
|
||||
|
||||
Define the components the block needs: `Florence2ForConditionalGeneration` and its processor. When defining components, specify the `name` (how you'll access it in code), `type_hint` (the model class), and `pretrained_model_name_or_path` (where to load weights from).
|
||||
|
||||
```python
|
||||
```py
|
||||
# Inside block.py
|
||||
from diffusers.modular_pipelines import ModularPipelineBlocks, ComponentSpec
|
||||
from diffusers.modular_pipelines import (
|
||||
ModularPipelineBlocks,
|
||||
ComponentSpec,
|
||||
)
|
||||
from transformers import AutoProcessor, Florence2ForConditionalGeneration
|
||||
|
||||
|
||||
@@ -104,19 +64,40 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
|
||||
]
|
||||
```
|
||||
|
||||
### Define inputs and outputs
|
||||
Next, we define the inputs and outputs of the block. The inputs include the image to be annotated, the annotation task, and the annotation prompt. The outputs include the generated mask image and annotations.
|
||||
|
||||
Inputs include the image, annotation task, and prompt. Outputs include the generated mask and annotations.
|
||||
|
||||
```python
|
||||
```py
|
||||
from typing import List, Union
|
||||
from PIL import Image
|
||||
from diffusers.modular_pipelines import InputParam, OutputParam
|
||||
from PIL import Image, ImageDraw
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from diffusers.modular_pipelines import (
|
||||
PipelineState,
|
||||
ModularPipelineBlocks,
|
||||
InputParam,
|
||||
ComponentSpec,
|
||||
OutputParam,
|
||||
)
|
||||
from transformers import AutoProcessor, Florence2ForConditionalGeneration
|
||||
|
||||
|
||||
class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
|
||||
|
||||
# ... expected_components from above ...
|
||||
@property
|
||||
def expected_components(self):
|
||||
return [
|
||||
ComponentSpec(
|
||||
name="image_annotator",
|
||||
type_hint=Florence2ForConditionalGeneration,
|
||||
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
|
||||
),
|
||||
ComponentSpec(
|
||||
name="image_annotator_processor",
|
||||
type_hint=AutoProcessor,
|
||||
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
@@ -129,21 +110,51 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
|
||||
),
|
||||
InputParam(
|
||||
"annotation_task",
|
||||
type_hint=str,
|
||||
type_hint=Union[str, List[str]],
|
||||
required=True,
|
||||
default="<REFERRING_EXPRESSION_SEGMENTATION>",
|
||||
description="Annotation task to perform (e.g., <OD>, <CAPTION>, <REFERRING_EXPRESSION_SEGMENTATION>)",
|
||||
description="""Annotation Task to perform on the image.
|
||||
Supported Tasks:
|
||||
|
||||
<OD>
|
||||
<REFERRING_EXPRESSION_SEGMENTATION>
|
||||
<CAPTION>
|
||||
<DETAILED_CAPTION>
|
||||
<MORE_DETAILED_CAPTION>
|
||||
<DENSE_REGION_CAPTION>
|
||||
<CAPTION_TO_PHRASE_GROUNDING>
|
||||
<OPEN_VOCABULARY_DETECTION>
|
||||
|
||||
""",
|
||||
),
|
||||
InputParam(
|
||||
"annotation_prompt",
|
||||
type_hint=str,
|
||||
type_hint=Union[str, List[str]],
|
||||
required=True,
|
||||
description="Prompt to provide context for the annotation task",
|
||||
description="""Annotation Prompt to provide more context to the task.
|
||||
Can be used to detect or segment out specific elements in the image
|
||||
""",
|
||||
),
|
||||
InputParam(
|
||||
"annotation_output_type",
|
||||
type_hint=str,
|
||||
required=True,
|
||||
default="mask_image",
|
||||
description="Output type: 'mask_image', 'mask_overlay', or 'bounding_box'",
|
||||
description="""Output type from annotation predictions. Available options are
|
||||
mask_image:
|
||||
-black and white mask image for the given image based on the task type
|
||||
mask_overlay:
|
||||
- mask overlayed on the original image
|
||||
bounding_box:
|
||||
- bounding boxes drawn on the original image
|
||||
""",
|
||||
),
|
||||
InputParam(
|
||||
"annotation_overlay",
|
||||
type_hint=bool,
|
||||
required=True,
|
||||
default=False,
|
||||
description="",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -152,45 +163,225 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
|
||||
return [
|
||||
OutputParam(
|
||||
"mask_image",
|
||||
type_hint=Image.Image,
|
||||
description="Inpainting mask for the input image",
|
||||
type_hint=Image,
|
||||
description="Inpainting Mask for input Image(s)",
|
||||
),
|
||||
OutputParam(
|
||||
"annotations",
|
||||
type_hint=dict,
|
||||
description="Raw annotation predictions",
|
||||
description="Annotations Predictions for input Image(s)",
|
||||
),
|
||||
OutputParam(
|
||||
"image",
|
||||
type_hint=Image.Image,
|
||||
description="Annotated image",
|
||||
type_hint=Image,
|
||||
description="Annotated input Image(s)",
|
||||
),
|
||||
]
|
||||
|
||||
```
|
||||
|
||||
### Implement the `__call__` method
|
||||
Now we implement the `__call__` method, which contains the logic for processing the input image and generating the mask.
|
||||
|
||||
The `__call__` method contains the block's logic. Access inputs via `block_state`, run your computation, and set outputs back to `block_state`.
|
||||
|
||||
```python
|
||||
```py
|
||||
from typing import List, Union
|
||||
from PIL import Image, ImageDraw
|
||||
import torch
|
||||
from diffusers.modular_pipelines import PipelineState
|
||||
import numpy as np
|
||||
|
||||
from diffusers.modular_pipelines import (
|
||||
PipelineState,
|
||||
ModularPipelineBlocks,
|
||||
InputParam,
|
||||
ComponentSpec,
|
||||
OutputParam,
|
||||
)
|
||||
from transformers import AutoProcessor, Florence2ForConditionalGeneration
|
||||
|
||||
|
||||
class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
|
||||
|
||||
# ... expected_components, inputs, intermediate_outputs from above ...
|
||||
@property
|
||||
def expected_components(self):
|
||||
return [
|
||||
ComponentSpec(
|
||||
name="image_annotator",
|
||||
type_hint=Florence2ForConditionalGeneration,
|
||||
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
|
||||
),
|
||||
ComponentSpec(
|
||||
name="image_annotator_processor",
|
||||
type_hint=AutoProcessor,
|
||||
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"image",
|
||||
type_hint=Union[Image.Image, List[Image.Image]],
|
||||
required=True,
|
||||
description="Image(s) to annotate",
|
||||
),
|
||||
InputParam(
|
||||
"annotation_task",
|
||||
type_hint=Union[str, List[str]],
|
||||
required=True,
|
||||
default="<REFERRING_EXPRESSION_SEGMENTATION>",
|
||||
description="""Annotation Task to perform on the image.
|
||||
Supported Tasks:
|
||||
|
||||
<OD>
|
||||
<REFERRING_EXPRESSION_SEGMENTATION>
|
||||
<CAPTION>
|
||||
<DETAILED_CAPTION>
|
||||
<MORE_DETAILED_CAPTION>
|
||||
<DENSE_REGION_CAPTION>
|
||||
<CAPTION_TO_PHRASE_GROUNDING>
|
||||
<OPEN_VOCABULARY_DETECTION>
|
||||
|
||||
""",
|
||||
),
|
||||
InputParam(
|
||||
"annotation_prompt",
|
||||
type_hint=Union[str, List[str]],
|
||||
required=True,
|
||||
description="""Annotation Prompt to provide more context to the task.
|
||||
Can be used to detect or segment out specific elements in the image
|
||||
""",
|
||||
),
|
||||
InputParam(
|
||||
"annotation_output_type",
|
||||
type_hint=str,
|
||||
required=True,
|
||||
default="mask_image",
|
||||
description="""Output type from annotation predictions. Available options are
|
||||
mask_image:
|
||||
-black and white mask image for the given image based on the task type
|
||||
mask_overlay:
|
||||
- mask overlayed on the original image
|
||||
bounding_box:
|
||||
- bounding boxes drawn on the original image
|
||||
""",
|
||||
),
|
||||
InputParam(
|
||||
"annotation_overlay",
|
||||
type_hint=bool,
|
||||
required=True,
|
||||
default=False,
|
||||
description="",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"mask_image",
|
||||
type_hint=Image,
|
||||
description="Inpainting Mask for input Image(s)",
|
||||
),
|
||||
OutputParam(
|
||||
"annotations",
|
||||
type_hint=dict,
|
||||
description="Annotations Predictions for input Image(s)",
|
||||
),
|
||||
OutputParam(
|
||||
"image",
|
||||
type_hint=Image,
|
||||
description="Annotated input Image(s)",
|
||||
),
|
||||
]
|
||||
|
||||
def get_annotations(self, components, images, prompts, task):
|
||||
task_prompts = [task + prompt for prompt in prompts]
|
||||
|
||||
inputs = components.image_annotator_processor(
|
||||
text=task_prompts, images=images, return_tensors="pt"
|
||||
).to(components.image_annotator.device, components.image_annotator.dtype)
|
||||
|
||||
generated_ids = components.image_annotator.generate(
|
||||
input_ids=inputs["input_ids"],
|
||||
pixel_values=inputs["pixel_values"],
|
||||
max_new_tokens=1024,
|
||||
early_stopping=False,
|
||||
do_sample=False,
|
||||
num_beams=3,
|
||||
)
|
||||
annotations = components.image_annotator_processor.batch_decode(
|
||||
generated_ids, skip_special_tokens=False
|
||||
)
|
||||
outputs = []
|
||||
for image, annotation in zip(images, annotations):
|
||||
outputs.append(
|
||||
components.image_annotator_processor.post_process_generation(
|
||||
annotation, task=task, image_size=(image.width, image.height)
|
||||
)
|
||||
)
|
||||
return outputs
|
||||
|
||||
def prepare_mask(self, images, annotations, overlay=False, fill="white"):
|
||||
masks = []
|
||||
for image, annotation in zip(images, annotations):
|
||||
mask_image = image.copy() if overlay else Image.new("L", image.size, 0)
|
||||
draw = ImageDraw.Draw(mask_image)
|
||||
|
||||
for _, _annotation in annotation.items():
|
||||
if "polygons" in _annotation:
|
||||
for polygon in _annotation["polygons"]:
|
||||
polygon = np.array(polygon).reshape(-1, 2)
|
||||
if len(polygon) < 3:
|
||||
continue
|
||||
polygon = polygon.reshape(-1).tolist()
|
||||
draw.polygon(polygon, fill=fill)
|
||||
|
||||
elif "bbox" in _annotation:
|
||||
bbox = _annotation["bbox"]
|
||||
draw.rectangle(bbox, fill="white")
|
||||
|
||||
masks.append(mask_image)
|
||||
|
||||
return masks
|
||||
|
||||
def prepare_bounding_boxes(self, images, annotations):
|
||||
outputs = []
|
||||
for image, annotation in zip(images, annotations):
|
||||
image_copy = image.copy()
|
||||
draw = ImageDraw.Draw(image_copy)
|
||||
for _, _annotation in annotation.items():
|
||||
bbox = _annotation["bbox"]
|
||||
label = _annotation["label"]
|
||||
|
||||
draw.rectangle(bbox, outline="red", width=3)
|
||||
draw.text((bbox[0], bbox[1] - 20), label, fill="red")
|
||||
|
||||
outputs.append(image_copy)
|
||||
|
||||
return outputs
|
||||
|
||||
def prepare_inputs(self, images, prompts):
|
||||
prompts = prompts or ""
|
||||
|
||||
if isinstance(images, Image.Image):
|
||||
images = [images]
|
||||
if isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
|
||||
if len(images) != len(prompts):
|
||||
raise ValueError("Number of images and annotation prompts must match.")
|
||||
|
||||
return images, prompts
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
images, annotation_task_prompt = self.prepare_inputs(
|
||||
block_state.image, block_state.annotation_prompt
|
||||
)
|
||||
task = block_state.annotation_task
|
||||
fill = block_state.fill
|
||||
|
||||
|
||||
annotations = self.get_annotations(
|
||||
components, images, annotation_task_prompt, task
|
||||
)
|
||||
@@ -209,69 +400,67 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
# Helper methods for mask/bounding box generation...
|
||||
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> See the complete implementation at [diffusers/Florence2-image-Annotator](https://huggingface.co/diffusers/Florence2-image-Annotator).
|
||||
Once we have defined our custom block, we can save it to the Hub using either the CLI or the [`push_to_hub`] method. This makes it easy to share and reuse the custom block with other pipelines.
|
||||
|
||||
<hfoptions id="share">
|
||||
<hfoption id="hf CLI">
|
||||
|
||||
```shell
|
||||
# In the folder with the `block.py` file, run:
|
||||
diffusers-cli custom_block
|
||||
```
|
||||
|
||||
Then upload the block to the Hub:
|
||||
|
||||
```shell
|
||||
hf upload <your repo id> . .
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="push_to_hub">
|
||||
|
||||
```py
|
||||
from block import Florence2ImageAnnotatorBlock
|
||||
block = Florence2ImageAnnotatorBlock()
|
||||
block.push_to_hub("<your repo id>")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Using Custom Blocks
|
||||
|
||||
Load a custom block with [`~ModularPipeline.from_pretrained`] and set `trust_remote_code=True`.
|
||||
Load the custom block with [`~ModularPipelineBlocks.from_pretrained`] and set `trust_remote_code=True`.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import ModularPipeline
|
||||
from diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks
|
||||
from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS
|
||||
from diffusers.utils import load_image
|
||||
|
||||
# Load the Florence-2 annotator pipeline
|
||||
image_annotator = ModularPipeline.from_pretrained(
|
||||
"diffusers/Florence2-image-Annotator",
|
||||
trust_remote_code=True
|
||||
)
|
||||
# Fetch the Florence2 image annotator block that will create our mask
|
||||
image_annotator_block = ModularPipelineBlocks.from_pretrained("diffusers/florence-2-custom-block", trust_remote_code=True)
|
||||
|
||||
# Check the docstring to see inputs/outputs
|
||||
print(image_annotator.blocks.doc)
|
||||
```
|
||||
my_blocks = INPAINT_BLOCKS.copy()
|
||||
# insert the annotation block before the image encoding step
|
||||
my_blocks.insert("image_annotator", image_annotator_block, 1)
|
||||
|
||||
Use the block to generate a mask:
|
||||
# Create our initial set of inpainting blocks
|
||||
blocks = SequentialPipelineBlocks.from_blocks_dict(my_blocks)
|
||||
|
||||
```python
|
||||
image_annotator.load_components(torch_dtype=torch.bfloat16)
|
||||
image_annotator.to("cuda")
|
||||
repo_id = "diffusers/modular-stable-diffusion-xl-base-1.0"
|
||||
pipe = blocks.init_pipeline(repo_id)
|
||||
pipe.load_components(torch_dtype=torch.float16, device_map="cuda", trust_remote_code=True)
|
||||
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg")
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true")
|
||||
image = image.resize((1024, 1024))
|
||||
|
||||
prompt = ["A red car"]
|
||||
annotation_task = "<REFERRING_EXPRESSION_SEGMENTATION>"
|
||||
annotation_prompt = ["the car"]
|
||||
|
||||
mask_image = image_annotator_node(
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
annotation_task=annotation_task,
|
||||
annotation_prompt=annotation_prompt,
|
||||
annotation_output_type="mask_image",
|
||||
).images
|
||||
mask_image[0].save("car-mask.png")
|
||||
```
|
||||
|
||||
Compose it with other blocks to create a new pipeline:
|
||||
|
||||
```python
|
||||
# Get the annotator block
|
||||
annotator_block = image_annotator.blocks
|
||||
|
||||
# Get an inpainting workflow and insert the annotator at the beginning
|
||||
inpaint_blocks = ModularPipeline.from_pretrained("Qwen/Qwen-Image").blocks.get_workflow("inpainting")
|
||||
inpaint_blocks.sub_blocks.insert("image_annotator", annotator_block, 0)
|
||||
|
||||
# Initialize the combined pipeline
|
||||
pipe = inpaint_blocks.init_pipeline()
|
||||
pipe.load_components(torch_dtype=torch.float16, device="cuda")
|
||||
|
||||
# Now the pipeline automatically generates masks from prompts
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
@@ -286,50 +475,18 @@ output = pipe(
|
||||
output[0].save("florence-inpainting.png")
|
||||
```
|
||||
|
||||
## Editing custom blocks
|
||||
## Editing Custom Blocks
|
||||
|
||||
Edit custom blocks by downloading them locally. This is the same workflow as the [Quick Start with Template](#quick-start-with-template), but starting from an existing block instead of the template.
|
||||
By default, custom blocks are saved in your cache directory. Use the `local_dir` argument to download and edit a custom block in a specific folder.
|
||||
|
||||
Use the `local_dir` argument to download a custom block to a specific folder:
|
||||
```py
|
||||
import torch
|
||||
from diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks
|
||||
from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS
|
||||
from diffusers.utils import load_image
|
||||
|
||||
```python
|
||||
from diffusers import ModularPipelineBlocks
|
||||
|
||||
# Download to a local folder for editing
|
||||
annotator_block = ModularPipelineBlocks.from_pretrained(
|
||||
"diffusers/Florence2-image-Annotator",
|
||||
trust_remote_code=True,
|
||||
local_dir="./my-florence-block"
|
||||
)
|
||||
# Fetch the Florence2 image annotator block that will create our mask
|
||||
image_annotator_block = ModularPipelineBlocks.from_pretrained("diffusers/florence-2-custom-block", trust_remote_code=True, local_dir="/my-local-folder")
|
||||
```
|
||||
|
||||
Any changes made to the block files in this folder will be reflected when you load the block again. When you're ready to share your changes, upload to a new repository:
|
||||
|
||||
```python
|
||||
pipeline = annotator_block.init_pipeline()
|
||||
pipeline.save_pretrained("./my-florence-block", repo_id="your-username/my-custom-florence", push_to_hub=True)
|
||||
```
|
||||
|
||||
## Next Steps
|
||||
|
||||
<hfoptions id="next">
|
||||
<hfoption id="Learn block types">
|
||||
|
||||
This guide covered creating a single custom block. Learn how to compose multiple blocks together:
|
||||
|
||||
- [SequentialPipelineBlocks](./sequential_pipeline_blocks): Chain blocks to execute in sequence
|
||||
- [ConditionalPipelineBlocks](./auto_pipeline_blocks): Create conditional blocks that select different execution paths
|
||||
- [LoopSequentialPipelineBlocks](./loop_sequential_pipeline_blocks): Define iterative workflows like the denoising loop
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Use in Mellon">
|
||||
|
||||
Make your custom block work with Mellon's visual interface. See the [Mellon Custom Blocks](./mellon) guide.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Explore existing blocks">
|
||||
|
||||
Browse the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for inspiration and ready-to-use blocks.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
Any changes made to the block files in this folder will be reflected when you load the block again.
|
||||
|
||||
@@ -1,270 +0,0 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
|
||||
## Using Custom Blocks with Mellon
|
||||
|
||||
[Mellon](https://github.com/cubiq/Mellon) is a visual workflow interface that integrates with Modular Diffusers and is designed for node-based workflows.
|
||||
|
||||
> [!WARNING]
|
||||
> Mellon is in early development and not ready for production use yet. Consider this a sneak peek of how the integration works!
|
||||
|
||||
|
||||
Custom blocks work in Mellon out of the box; you just need to add a `mellon_pipeline_config.json` to your repository. This config file tells Mellon how to render your block's parameters as UI components.
|
||||
|
||||
Here's what it looks like in action with the [Gemini Prompt Expander](https://huggingface.co/diffusers/gemini-prompt-expander-mellon) block:
|
||||
|
||||

|
||||
|
||||
To use a modular diffusers custom block in Mellon:
|
||||
1. Drag a **Dynamic Block Node** from the ModularDiffusers section
|
||||
2. Enter the `repo_id` (e.g., `diffusers/gemini-prompt-expander-mellon`)
|
||||
3. Click **Load Custom Block**
|
||||
4. The node transforms to show your block's inputs and outputs
|
||||
|
||||
Now let's walk through how to create this config for your own custom block.
|
||||
|
||||
## Steps to create a Mellon config
|
||||
|
||||
1. **Specify Mellon types for your parameters** - Each `InputParam`/`OutputParam` needs a type that tells Mellon what UI component to render (e.g., `"textbox"`, `"dropdown"`, `"image"`).
|
||||
2. **Generate `mellon_pipeline_config.json`** - Use our utility to generate a config template and push it to your Hub repository.
|
||||
3. **(Optional) Manually adjust the config** - Fine-tune the generated config for your specific needs.
|
||||
|
||||
## Specify Mellon types for parameters
|
||||
|
||||
Mellon types determine how each parameter renders in the UI. If you don't specify a type for a parameter, it will default to `"custom"`, which renders as a simple connection dot. You can always adjust this later in the generated config.
|
||||
|
||||
|
||||
| Type | Input/Output | Description |
|
||||
|------|--------------|-------------|
|
||||
| `image` | Both | Image (PIL Image) |
|
||||
| `video` | Both | Video |
|
||||
| `text` | Both | Text display |
|
||||
| `textbox` | Input | Text input |
|
||||
| `dropdown` | Input | Dropdown selection menu |
|
||||
| `slider` | Input | Slider for numeric values |
|
||||
| `number` | Input | Numeric input |
|
||||
| `checkbox` | Input | Boolean toggle |
|
||||
|
||||
For parameters that need more configuration (like dropdowns with options, or sliders with min/max values), pass a `MellonParam` instance directly instead of a string. You can use one of the class methods below, or create a fully custom one with `MellonParam(name, label, type, ...)`.
|
||||
|
||||
| Method | Description |
|
||||
|--------|-------------|
|
||||
| `MellonParam.Input.image(name)` | Image input |
|
||||
| `MellonParam.Input.textbox(name, default)` | Text input as textarea |
|
||||
| `MellonParam.Input.dropdown(name, options, default)` | Dropdown selection |
|
||||
| `MellonParam.Input.slider(name, default, min, max, step)` | Slider for numeric values |
|
||||
| `MellonParam.Input.number(name, default, min, max, step)` | Numeric input (no slider) |
|
||||
| `MellonParam.Input.seed(name, default)` | Seed input with randomize button |
|
||||
| `MellonParam.Input.checkbox(name, default)` | Boolean checkbox |
|
||||
| `MellonParam.Input.model(name)` | Model input for diffusers components |
|
||||
| `MellonParam.Output.image(name)` | Image output |
|
||||
| `MellonParam.Output.video(name)` | Video output |
|
||||
| `MellonParam.Output.text(name)` | Text output |
|
||||
| `MellonParam.Output.model(name)` | Model output for diffusers components |
|
||||
|
||||
Choose one of the methods below to specify a Mellon type.
|
||||
|
||||
### Using `metadata` in block definitions
|
||||
|
||||
If you're defining a custom block from scratch, add `metadata={"mellon": "<type>"}` directly to your `InputParam` and `OutputParam` definitions. If you're editing an existing custom block from the Hub, see [Editing custom blocks](./custom_blocks#editing-custom-blocks) for how to download it locally.
|
||||
|
||||
```python
|
||||
class GeminiPromptExpander(ModularPipelineBlocks):
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"prompt",
|
||||
type_hint=str,
|
||||
required=True,
|
||||
description="Prompt to use",
|
||||
metadata={"mellon": "textbox"}, # Text input
|
||||
)
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"prompt",
|
||||
type_hint=str,
|
||||
description="Expanded prompt by the LLM",
|
||||
metadata={"mellon": "text"}, # Text output
|
||||
),
|
||||
OutputParam(
|
||||
"old_prompt",
|
||||
type_hint=str,
|
||||
description="Old prompt provided by the user",
|
||||
# No metadata - we don't want to render this in UI
|
||||
)
|
||||
]
|
||||
```
|
||||
|
||||
For full control over UI configuration, pass a `MellonParam` instance directly:
|
||||
```python
|
||||
from diffusers.modular_pipelines.mellon_node_utils import MellonParam
|
||||
|
||||
InputParam(
|
||||
"mode",
|
||||
type_hint=str,
|
||||
default="balanced",
|
||||
metadata={"mellon": MellonParam.Input.dropdown("mode", options=["fast", "balanced", "quality"])},
|
||||
)
|
||||
```
|
||||
|
||||
### Using `input_types` and `output_types` when Generating Config
|
||||
|
||||
If you're working with an existing pipeline or prefer to keep your block definitions clean, specify types when generating the config using the `input_types`/`output_types` arguments:
|
||||
```python
|
||||
from diffusers.modular_pipelines.mellon_node_utils import MellonPipelineConfig
|
||||
|
||||
mellon_config = MellonPipelineConfig.from_custom_block(
|
||||
blocks,
|
||||
input_types={"prompt": "textbox"},
|
||||
output_types={"prompt": "text"}
|
||||
)
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> When both `metadata` and `input_types`/`output_types` are specified, the arguments override `metadata`.
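
For example, assuming `blocks` defines its `prompt` input with `metadata={"mellon": "text"}`, passing `input_types` at generation time takes precedence (a minimal sketch reusing the API shown above):

```python
from diffusers.modular_pipelines.mellon_node_utils import MellonPipelineConfig

# `blocks` is assumed to define `prompt` with metadata={"mellon": "text"};
# the "textbox" passed here takes precedence for that parameter.
mellon_config = MellonPipelineConfig.from_custom_block(
    blocks,
    input_types={"prompt": "textbox"},
)
```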
|
||||
|
||||
## Generate and push the Mellon config
|
||||
|
||||
After adding metadata to your block, generate the default Mellon configuration template and push it to the Hub:
|
||||
|
||||
```python
|
||||
from diffusers import ModularPipelineBlocks
|
||||
from diffusers.modular_pipelines.mellon_node_utils import MellonPipelineConfig
|
||||
|
||||
# load your custom blocks from your local dir
|
||||
blocks = ModularPipelineBlocks.from_pretrained("/path/local/folder", trust_remote_code=True)
|
||||
|
||||
# Generate the default config template
|
||||
mellon_config = MellonPipelineConfig.from_custom_block(blocks)
|
||||
# Push the default template to `repo_id`; pass the same local folder path so the config is saved locally first
|
||||
mellon_config.save(
|
||||
local_dir="/path/local/folder",
|
||||
repo_id=repo_id,
|
||||
push_to_hub=True
|
||||
)
|
||||
```
|
||||
|
||||
This creates a `mellon_pipeline_config.json` file in your repository.
|
||||
|
||||
## Review and adjust the config
|
||||
|
||||
The generated template is a starting point - you may want to adjust it for your needs. Let's walk through the generated config for the Gemini Prompt Expander:
|
||||
|
||||
```json
|
||||
{
|
||||
"label": "Gemini Prompt Expander",
|
||||
"default_repo": "",
|
||||
"default_dtype": "",
|
||||
"node_params": {
|
||||
"custom": {
|
||||
"params": {
|
||||
"prompt": {
|
||||
"label": "Prompt",
|
||||
"type": "string",
|
||||
"display": "textarea",
|
||||
"default": ""
|
||||
},
|
||||
"out_prompt": {
|
||||
"label": "Prompt",
|
||||
"type": "string",
|
||||
"display": "output"
|
||||
},
|
||||
"old_prompt": {
|
||||
"label": "Old Prompt",
|
||||
"type": "custom",
|
||||
"display": "output"
|
||||
},
|
||||
"doc": {
|
||||
"label": "Doc",
|
||||
"type": "string",
|
||||
"display": "output"
|
||||
}
|
||||
},
|
||||
"input_names": ["prompt"],
|
||||
"model_input_names": [],
|
||||
"output_names": ["out_prompt", "old_prompt", "doc"],
|
||||
"block_name": "custom",
|
||||
"node_type": "custom"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Understanding the Structure
|
||||
|
||||
The `params` dict defines how each UI element renders. The `input_names`, `model_input_names`, and `output_names` lists map these UI elements to the underlying [`ModularPipelineBlocks`]'s I/O interface:
|
||||
|
||||
| Mellon Config | ModularPipelineBlocks |
|
||||
|---------------|----------------------|
|
||||
| `input_names` | `inputs` property |
|
||||
| `model_input_names` | `expected_components` property |
|
||||
| `output_names` | `intermediate_outputs` property |
|
||||
|
||||
In this example: `prompt` is the only input. There are no model components, and outputs include `out_prompt`, `old_prompt`, and `doc`.
|
||||
|
||||
Now let's look at the `params` dict:
|
||||
|
||||
- **`prompt`**: An input parameter with `display: "textarea"` (renders as a text input box), `label: "Prompt"` (shown in the UI), and `default: ""` (starts empty). The `type: "string"` field is important in Mellon because it determines which nodes can connect together - only matching types can be linked with "noodles".
|
||||
|
||||
- **`out_prompt`**: The expanded prompt output. The `out_` prefix was automatically added because the input and output share the same name (`prompt`), avoiding naming conflicts in the config. It has `display: "output"` which renders as an output socket.
|
||||
|
||||
- **`old_prompt`**: Has `type: "custom"` because we didn't specify metadata. This renders as a simple dot in the UI. Since we don't actually want to expose this in the UI, we can remove it.
|
||||
|
||||
- **`doc`**: The documentation output, automatically added to all custom blocks.
|
||||
|
||||
### Making Adjustments
|
||||
|
||||
Remove `old_prompt` from both `params` and `output_names` because you won't need to use it.
|
||||
|
||||
```json
|
||||
{
|
||||
"label": "Gemini Prompt Expander",
|
||||
"default_repo": "",
|
||||
"default_dtype": "",
|
||||
"node_params": {
|
||||
"custom": {
|
||||
"params": {
|
||||
"prompt": {
|
||||
"label": "Prompt",
|
||||
"type": "string",
|
||||
"display": "textarea",
|
||||
"default": ""
|
||||
},
|
||||
"out_prompt": {
|
||||
"label": "Prompt",
|
||||
"type": "string",
|
||||
"display": "output"
|
||||
},
|
||||
"doc": {
|
||||
"label": "Doc",
|
||||
"type": "string",
|
||||
"display": "output"
|
||||
}
|
||||
},
|
||||
"input_names": ["prompt"],
|
||||
"model_input_names": [],
|
||||
"output_names": ["out_prompt", "doc"],
|
||||
"block_name": "custom",
|
||||
"node_type": "custom"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
See the final config at [diffusers/gemini-prompt-expander-mellon](https://huggingface.co/diffusers/gemini-prompt-expander-mellon).
|
||||
@@ -33,14 +33,9 @@ The Modular Diffusers docs are organized as shown below.
|
||||
- [SequentialPipelineBlocks](./sequential_pipeline_blocks) is a type of block that chains multiple blocks so they run one after another, passing data along the chain. This guide shows you how to create [`~modular_pipelines.SequentialPipelineBlocks`] and how they connect and work together.
|
||||
- [LoopSequentialPipelineBlocks](./loop_sequential_pipeline_blocks) is a type of block that runs a series of blocks in a loop. This guide shows you how to create [`~modular_pipelines.LoopSequentialPipelineBlocks`].
|
||||
- [AutoPipelineBlocks](./auto_pipeline_blocks) is a type of block that automatically chooses which blocks to run based on the input. This guide shows you how to create [`~modular_pipelines.AutoPipelineBlocks`].
|
||||
- [Building Custom Blocks](./custom_blocks) shows you how to create your own custom blocks and share them on the Hub.
|
||||
|
||||
## ModularPipeline
|
||||
|
||||
- [ModularPipeline](./modular_pipeline) shows you how to create and convert pipeline blocks into an executable [`ModularPipeline`].
|
||||
- [ComponentsManager](./components_manager) shows you how to manage and reuse components across multiple pipelines.
|
||||
- [Guiders](./guiders) shows you how to use different guidance methods in the pipeline.
|
||||
|
||||
## Mellon Integration
|
||||
|
||||
- [Using Custom Blocks with Mellon](./mellon) shows you how to make your custom blocks work with [Mellon](https://github.com/cubiq/Mellon), a visual node-based interface for building workflows.
|
||||
- [Guiders](./guiders) shows you how to use different guidance methods in the pipeline.
|
||||
@@ -111,57 +111,3 @@ config = TaylorSeerCacheConfig(
|
||||
)
|
||||
pipe.transformer.enable_cache(config)
|
||||
```
|
||||
|
||||
## MagCache
|
||||
|
||||
[MagCache](https://github.com/Zehong-Ma/MagCache) accelerates inference by skipping transformer blocks based on the magnitude of the residual update. It observes that the magnitude of updates (Output - Input) decays predictably over the diffusion process. By accumulating an "error budget" based on pre-computed magnitude ratios, it dynamically decides when to skip computation and reuse the previous residual.
|
||||
|
||||
MagCache relies on **Magnitude Ratios** (`mag_ratios`), which describe this decay curve. These ratios are specific to the model checkpoint and scheduler.
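
Conceptually, at each denoising step the hook accumulates these ratios into an error budget and skips the transformer blocks while the budget stays under a threshold. Here is a simplified sketch of that decision; the variable and function names are illustrative, not the actual hook API:

```python
# Simplified sketch of the MagCache skip decision (illustrative; not the actual hook code)
def should_skip_step(step, mag_ratios, state, threshold=0.06, max_skip_steps=3, retention_steps=6):
    if step < retention_steps:                   # never skip the earliest steps
        return False
    state["accumulated_ratio"] *= mag_ratios[step]
    state["accumulated_steps"] += 1
    state["accumulated_err"] += abs(1.0 - state["accumulated_ratio"])
    if state["accumulated_err"] <= threshold and state["accumulated_steps"] <= max_skip_steps:
        return True                              # reuse the cached residual: output = input + previous residual
    # budget exhausted: recompute this step and reset the accumulators
    state.update(accumulated_ratio=1.0, accumulated_steps=0, accumulated_err=0.0)
    return False
```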
|
||||
|
||||
### Usage
|
||||
|
||||
To use MagCache, you typically follow a two-step process: **Calibration** and **Inference**.
|
||||
|
||||
1. **Calibration**: Run inference once with `calibrate=True`. The hook will measure the residual magnitudes and print the calculated ratios to the console.
|
||||
2. **Inference**: Pass these ratios to `MagCacheConfig` to enable acceleration.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxPipeline, MagCacheConfig
|
||||
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-schnell",
|
||||
torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
|
||||
# 1. Calibration Step
|
||||
# Run full inference to measure model behavior.
|
||||
calib_config = MagCacheConfig(calibrate=True, num_inference_steps=4)
|
||||
pipe.transformer.enable_cache(calib_config)
|
||||
|
||||
# Run a prompt to trigger calibration
|
||||
pipe("A cat playing chess", num_inference_steps=4)
|
||||
# Logs will print something like: "MagCache Calibration Results: [1.0, 1.37, 0.97, 0.87]"
|
||||
|
||||
# 2. Inference Step
|
||||
# Apply the specific ratios obtained from calibration for optimized speed.
|
||||
# Note: For Flux models, you can also import defaults:
|
||||
# from diffusers.hooks.mag_cache import FLUX_MAG_RATIOS
|
||||
mag_config = MagCacheConfig(
|
||||
mag_ratios=[1.0, 1.37, 0.97, 0.87],
|
||||
num_inference_steps=4
|
||||
)
|
||||
|
||||
pipe.transformer.enable_cache(mag_config)
|
||||
|
||||
image = pipe("A cat playing chess", num_inference_steps=4).images[0]
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> `mag_ratios` represent the model's intrinsic magnitude decay curve. Ratios calibrated for a high number of steps (e.g., 50) can be reused for lower step counts (e.g., 20). The implementation uses interpolation to map the curve to the current number of inference steps.
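
For example, ratios printed by a 50-step calibration run can be passed unchanged to a 20-step config; `MagCacheConfig` interpolates them to the requested step count. A short sketch, assuming `calibrated_50_step_ratios` holds the values printed by the calibration run and `pipe` is the pipeline from the example above:

```python
# Reuse ratios calibrated at 50 steps for a 20-step run; the config interpolates them internally
config = MagCacheConfig(
    mag_ratios=calibrated_50_step_ratios,  # list of 50 floats printed by a previous calibration run
    num_inference_steps=20,
)
pipe.transformer.enable_cache(config)
image = pipe("A cat playing chess", num_inference_steps=20).images[0]
```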
|
||||
|
||||
> [!TIP]
|
||||
> For pipelines that run Classifier-Free Guidance sequentially (like Kandinsky 5.0), the calibration log might print two arrays: one for the Conditional pass and one for the Unconditional pass. In most cases, you should use the first array (Conditional).
|
||||
|
||||
> [!TIP]
|
||||
> For pipelines that run Classifier-Free Guidance in a **batched** manner (like SDXL or Flux), the `hidden_states` processed by the model contain both conditional and unconditional branches concatenated together. The calibration process automatically accounts for this, producing a single array of ratios that represents the joint behavior. You can use this resulting array directly without modification.
|
||||
|
||||
@@ -168,14 +168,12 @@ else:
|
||||
"FirstBlockCacheConfig",
|
||||
"HookRegistry",
|
||||
"LayerSkipConfig",
|
||||
"MagCacheConfig",
|
||||
"PyramidAttentionBroadcastConfig",
|
||||
"SmoothedEnergyGuidanceConfig",
|
||||
"TaylorSeerCacheConfig",
|
||||
"apply_faster_cache",
|
||||
"apply_first_block_cache",
|
||||
"apply_layer_skip",
|
||||
"apply_mag_cache",
|
||||
"apply_pyramid_attention_broadcast",
|
||||
"apply_taylorseer_cache",
|
||||
]
|
||||
@@ -934,14 +932,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FirstBlockCacheConfig,
|
||||
HookRegistry,
|
||||
LayerSkipConfig,
|
||||
MagCacheConfig,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
SmoothedEnergyGuidanceConfig,
|
||||
TaylorSeerCacheConfig,
|
||||
apply_faster_cache,
|
||||
apply_first_block_cache,
|
||||
apply_layer_skip,
|
||||
apply_mag_cache,
|
||||
apply_pyramid_attention_broadcast,
|
||||
apply_taylorseer_cache,
|
||||
)
|
||||
|
||||
@@ -23,7 +23,6 @@ if is_torch_available():
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
from .layer_skip import LayerSkipConfig, apply_layer_skip
|
||||
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
|
||||
from .mag_cache import MagCacheConfig, apply_mag_cache
|
||||
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
|
||||
from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
|
||||
from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache
|
||||
|
||||
@@ -23,13 +23,7 @@ from ..models.attention_processor import Attention, MochiAttention
|
||||
_ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin)
|
||||
_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)
|
||||
|
||||
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = (
|
||||
"blocks",
|
||||
"transformer_blocks",
|
||||
"single_transformer_blocks",
|
||||
"layers",
|
||||
"visual_transformer_blocks",
|
||||
)
|
||||
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
|
||||
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
|
||||
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
|
||||
|
||||
|
||||
@@ -26,7 +26,6 @@ class AttentionProcessorMetadata:
|
||||
class TransformerBlockMetadata:
|
||||
return_hidden_states_index: int = None
|
||||
return_encoder_hidden_states_index: int = None
|
||||
hidden_states_argument_name: str = "hidden_states"
|
||||
|
||||
_cls: Type = None
|
||||
_cached_parameter_indices: Dict[str, int] = None
|
||||
@@ -170,7 +169,7 @@ def _register_attention_processors_metadata():
|
||||
|
||||
|
||||
def _register_transformer_blocks_metadata():
|
||||
from ..models.attention import BasicTransformerBlock, JointTransformerBlock
|
||||
from ..models.attention import BasicTransformerBlock
|
||||
from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
|
||||
from ..models.transformers.transformer_bria import BriaTransformerBlock
|
||||
from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
|
||||
@@ -185,7 +184,6 @@ def _register_transformer_blocks_metadata():
|
||||
HunyuanImageSingleTransformerBlock,
|
||||
HunyuanImageTransformerBlock,
|
||||
)
|
||||
from ..models.transformers.transformer_kandinsky import Kandinsky5TransformerDecoderBlock
|
||||
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
|
||||
from ..models.transformers.transformer_mochi import MochiTransformerBlock
|
||||
from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
|
||||
@@ -333,24 +331,6 @@ def _register_transformer_blocks_metadata():
|
||||
),
|
||||
)
|
||||
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=JointTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=1,
|
||||
return_encoder_hidden_states_index=0,
|
||||
),
|
||||
)
|
||||
|
||||
# Kandinsky 5.0 (Kandinsky5TransformerDecoderBlock)
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=Kandinsky5TransformerDecoderBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=None,
|
||||
hidden_states_argument_name="visual_embed",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# fmt: off
|
||||
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
|
||||
|
||||
@@ -1,468 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import get_logger
|
||||
from ..utils.torch_utils import unwrap_module
|
||||
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
|
||||
from ._helpers import TransformerBlockRegistry
|
||||
from .hooks import BaseState, HookRegistry, ModelHook, StateManager
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
_MAG_CACHE_LEADER_BLOCK_HOOK = "mag_cache_leader_block_hook"
|
||||
_MAG_CACHE_BLOCK_HOOK = "mag_cache_block_hook"
|
||||
|
||||
# Default Mag Ratios for Flux models (Dev/Schnell) are provided for convenience.
|
||||
# Users must explicitly pass these to the config if using Flux.
|
||||
# Reference: https://github.com/Zehong-Ma/MagCache
|
||||
FLUX_MAG_RATIOS = torch.tensor(
|
||||
[1.0]
|
||||
+ [
|
||||
1.21094,
|
||||
1.11719,
|
||||
1.07812,
|
||||
1.0625,
|
||||
1.03906,
|
||||
1.03125,
|
||||
1.03906,
|
||||
1.02344,
|
||||
1.03125,
|
||||
1.02344,
|
||||
0.98047,
|
||||
1.01562,
|
||||
1.00781,
|
||||
1.0,
|
||||
1.00781,
|
||||
1.0,
|
||||
1.00781,
|
||||
1.0,
|
||||
1.0,
|
||||
0.99609,
|
||||
0.99609,
|
||||
0.98047,
|
||||
0.98828,
|
||||
0.96484,
|
||||
0.95703,
|
||||
0.93359,
|
||||
0.89062,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def nearest_interp(src_array: torch.Tensor, target_length: int) -> torch.Tensor:
|
||||
"""
|
||||
Interpolate the source array to the target length using nearest neighbor interpolation.
|
||||
"""
|
||||
src_length = len(src_array)
|
||||
if target_length == 1:
|
||||
return src_array[-1:]
|
||||
|
||||
scale = (src_length - 1) / (target_length - 1)
|
||||
grid = torch.arange(target_length, device=src_array.device, dtype=torch.float32)
|
||||
mapped_indices = torch.round(grid * scale).long()
|
||||
return src_array[mapped_indices]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MagCacheConfig:
|
||||
r"""
|
||||
Configuration for [MagCache](https://github.com/Zehong-Ma/MagCache).
|
||||
|
||||
Args:
|
||||
threshold (`float`, defaults to `0.06`):
|
||||
The threshold for the accumulated error. If the accumulated error is below this threshold, the block
|
||||
computation is skipped. A higher threshold allows for more aggressive skipping (faster) but may degrade
|
||||
quality.
|
||||
max_skip_steps (`int`, defaults to `3`):
|
||||
The maximum number of consecutive steps that can be skipped (K in the paper).
|
||||
retention_ratio (`float`, defaults to `0.2`):
|
||||
The fraction of initial steps during which skipping is disabled to ensure stability. For example, if
|
||||
`num_inference_steps` is 28 and `retention_ratio` is 0.2, the first 6 steps will never be skipped.
|
||||
num_inference_steps (`int`, defaults to `28`):
|
||||
The number of inference steps used in the pipeline. This is required to interpolate `mag_ratios` correctly.
|
||||
mag_ratios (`torch.Tensor`, *optional*):
|
||||
The pre-computed magnitude ratios for the model. These are checkpoint-dependent. If not provided, you must
|
||||
set `calibrate=True` to calculate them for your specific model. For Flux models, you can use
|
||||
`diffusers.hooks.mag_cache.FLUX_MAG_RATIOS`.
|
||||
calibrate (`bool`, defaults to `False`):
|
||||
If True, enables calibration mode. In this mode, no blocks are skipped. Instead, the hook calculates the
|
||||
magnitude ratios for the current run and logs them at the end. Use this to obtain `mag_ratios` for new
|
||||
models or schedulers.
|
||||
"""
|
||||
|
||||
threshold: float = 0.06
|
||||
max_skip_steps: int = 3
|
||||
retention_ratio: float = 0.2
|
||||
num_inference_steps: int = 28
|
||||
mag_ratios: Optional[Union[torch.Tensor, List[float]]] = None
|
||||
calibrate: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# User MUST provide ratios OR enable calibration.
|
||||
if self.mag_ratios is None and not self.calibrate:
|
||||
raise ValueError(
|
||||
" `mag_ratios` must be provided for MagCache inference because these ratios are model-dependent.\n"
|
||||
"To get them for your model:\n"
|
||||
"1. Initialize `MagCacheConfig(calibrate=True, ...)`\n"
|
||||
"2. Run inference on your model once.\n"
|
||||
"3. Copy the printed ratios array and pass it to `mag_ratios` in the config.\n"
|
||||
"For Flux models, you can import `FLUX_MAG_RATIOS` from `diffusers.hooks.mag_cache`."
|
||||
)
|
||||
|
||||
if not self.calibrate and self.mag_ratios is not None:
|
||||
if not torch.is_tensor(self.mag_ratios):
|
||||
self.mag_ratios = torch.tensor(self.mag_ratios)
|
||||
|
||||
if len(self.mag_ratios) != self.num_inference_steps:
|
||||
logger.debug(
|
||||
f"Interpolating mag_ratios from length {len(self.mag_ratios)} to {self.num_inference_steps}"
|
||||
)
|
||||
self.mag_ratios = nearest_interp(self.mag_ratios, self.num_inference_steps)
|
||||
|
||||
|
||||
class MagCacheState(BaseState):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# Cache for the residual (output - input) from the *previous* timestep
|
||||
self.previous_residual: torch.Tensor = None
|
||||
|
||||
# State inputs/outputs for the current forward pass
|
||||
self.head_block_input: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
|
||||
self.should_compute: bool = True
|
||||
|
||||
# MagCache accumulators
|
||||
self.accumulated_ratio: float = 1.0
|
||||
self.accumulated_err: float = 0.0
|
||||
self.accumulated_steps: int = 0
|
||||
|
||||
# Current step counter (timestep index)
|
||||
self.step_index: int = 0
|
||||
|
||||
# Calibration storage
|
||||
self.calibration_ratios: List[float] = []
|
||||
|
||||
def reset(self):
|
||||
self.previous_residual = None
|
||||
self.should_compute = True
|
||||
self.accumulated_ratio = 1.0
|
||||
self.accumulated_err = 0.0
|
||||
self.accumulated_steps = 0
|
||||
self.step_index = 0
|
||||
self.calibration_ratios = []
|
||||
|
||||
|
||||
class MagCacheHeadHook(ModelHook):
|
||||
_is_stateful = True
|
||||
|
||||
def __init__(self, state_manager: StateManager, config: MagCacheConfig):
|
||||
self.state_manager = state_manager
|
||||
self.config = config
|
||||
self._metadata = None
|
||||
|
||||
def initialize_hook(self, module):
|
||||
unwrapped_module = unwrap_module(module)
|
||||
self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
|
||||
return module
|
||||
|
||||
@torch.compiler.disable
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
if self.state_manager._current_context is None:
|
||||
self.state_manager.set_context("inference")
|
||||
|
||||
arg_name = self._metadata.hidden_states_argument_name
|
||||
hidden_states = self._metadata._get_parameter_from_args_kwargs(arg_name, args, kwargs)
|
||||
|
||||
state: MagCacheState = self.state_manager.get_state()
|
||||
state.head_block_input = hidden_states
|
||||
|
||||
should_compute = True
|
||||
|
||||
if self.config.calibrate:
|
||||
# Never skip during calibration
|
||||
should_compute = True
|
||||
else:
|
||||
# MagCache Logic
|
||||
current_step = state.step_index
|
||||
if current_step >= len(self.config.mag_ratios):
|
||||
current_scale = 1.0
|
||||
else:
|
||||
current_scale = self.config.mag_ratios[current_step]
|
||||
|
||||
retention_step = int(self.config.retention_ratio * self.config.num_inference_steps + 0.5)
|
||||
|
||||
if current_step >= retention_step:
|
||||
state.accumulated_ratio *= current_scale
|
||||
state.accumulated_steps += 1
|
||||
state.accumulated_err += abs(1.0 - state.accumulated_ratio)
|
||||
|
||||
if (
|
||||
state.previous_residual is not None
|
||||
and state.accumulated_err <= self.config.threshold
|
||||
and state.accumulated_steps <= self.config.max_skip_steps
|
||||
):
|
||||
should_compute = False
|
||||
else:
|
||||
state.accumulated_ratio = 1.0
|
||||
state.accumulated_steps = 0
|
||||
state.accumulated_err = 0.0
|
||||
|
||||
state.should_compute = should_compute
|
||||
|
||||
if not should_compute:
|
||||
logger.debug(f"MagCache: Skipping step {state.step_index}")
|
||||
# Apply MagCache: Output = Input + Previous Residual
|
||||
|
||||
output = hidden_states
|
||||
res = state.previous_residual
|
||||
|
||||
if res.device != output.device:
|
||||
res = res.to(output.device)
|
||||
|
||||
# Attempt to apply the residual, handling shape mismatches (e.g., text+image tokens vs image tokens only)
|
||||
if res.shape == output.shape:
|
||||
output = output + res
|
||||
elif (
|
||||
output.ndim == 3
|
||||
and res.ndim == 3
|
||||
and output.shape[0] == res.shape[0]
|
||||
and output.shape[2] == res.shape[2]
|
||||
):
|
||||
# Assuming concatenation where image part is at the end (standard in Flux/SD3)
|
||||
diff = output.shape[1] - res.shape[1]
|
||||
if diff > 0:
|
||||
output = output.clone()
|
||||
output[:, diff:, :] = output[:, diff:, :] + res
|
||||
else:
|
||||
logger.warning(
|
||||
f"MagCache: Dimension mismatch. Input {output.shape}, Residual {res.shape}. "
|
||||
"Cannot apply residual safely. Returning input without residual."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"MagCache: Dimension mismatch. Input {output.shape}, Residual {res.shape}. "
|
||||
"Cannot apply residual safely. Returning input without residual."
|
||||
)
|
||||
|
||||
if self._metadata.return_encoder_hidden_states_index is not None:
|
||||
original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
|
||||
"encoder_hidden_states", args, kwargs
|
||||
)
|
||||
max_idx = max(
|
||||
self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index
|
||||
)
|
||||
ret_list = [None] * (max_idx + 1)
|
||||
ret_list[self._metadata.return_hidden_states_index] = output
|
||||
ret_list[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states
|
||||
return tuple(ret_list)
|
||||
else:
|
||||
return output
|
||||
|
||||
else:
|
||||
# Compute original forward
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
return output
|
||||
|
||||
def reset_state(self, module):
|
||||
self.state_manager.reset()
|
||||
return module
|
||||
|
||||
|
||||
class MagCacheBlockHook(ModelHook):
|
||||
def __init__(self, state_manager: StateManager, is_tail: bool = False, config: MagCacheConfig = None):
|
||||
super().__init__()
|
||||
self.state_manager = state_manager
|
||||
self.is_tail = is_tail
|
||||
self.config = config
|
||||
self._metadata = None
|
||||
|
||||
def initialize_hook(self, module):
|
||||
unwrapped_module = unwrap_module(module)
|
||||
self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
|
||||
return module
|
||||
|
||||
@torch.compiler.disable
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
if self.state_manager._current_context is None:
|
||||
self.state_manager.set_context("inference")
|
||||
state: MagCacheState = self.state_manager.get_state()
|
||||
|
||||
if not state.should_compute:
|
||||
arg_name = self._metadata.hidden_states_argument_name
|
||||
hidden_states = self._metadata._get_parameter_from_args_kwargs(arg_name, args, kwargs)
|
||||
|
||||
if self.is_tail:
|
||||
# Still need to advance step index even if we skip
|
||||
self._advance_step(state)
|
||||
|
||||
if self._metadata.return_encoder_hidden_states_index is not None:
|
||||
encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
|
||||
"encoder_hidden_states", args, kwargs
|
||||
)
|
||||
max_idx = max(
|
||||
self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index
|
||||
)
|
||||
ret_list = [None] * (max_idx + 1)
|
||||
ret_list[self._metadata.return_hidden_states_index] = hidden_states
|
||||
ret_list[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states
|
||||
return tuple(ret_list)
|
||||
|
||||
return hidden_states
|
||||
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
|
||||
if self.is_tail:
|
||||
# Calculate residual for next steps
|
||||
if isinstance(output, tuple):
|
||||
out_hidden = output[self._metadata.return_hidden_states_index]
|
||||
else:
|
||||
out_hidden = output
|
||||
|
||||
in_hidden = state.head_block_input
|
||||
|
||||
if in_hidden is None:
|
||||
return output
|
||||
|
||||
# Determine residual
|
||||
if out_hidden.shape == in_hidden.shape:
|
||||
residual = out_hidden - in_hidden
|
||||
elif out_hidden.ndim == 3 and in_hidden.ndim == 3 and out_hidden.shape[2] == in_hidden.shape[2]:
|
||||
diff = in_hidden.shape[1] - out_hidden.shape[1]
|
||||
if diff == 0:
|
||||
residual = out_hidden - in_hidden
|
||||
else:
|
||||
residual = out_hidden - in_hidden # Fallback to matching tail
|
||||
else:
|
||||
# Fallback for completely mismatched shapes
|
||||
residual = out_hidden
|
||||
|
||||
if self.config.calibrate:
|
||||
self._perform_calibration_step(state, residual)
|
||||
|
||||
state.previous_residual = residual
|
||||
self._advance_step(state)
|
||||
|
||||
return output
|
||||
|
||||
def _perform_calibration_step(self, state: MagCacheState, current_residual: torch.Tensor):
|
||||
if state.previous_residual is None:
|
||||
# First step has no previous residual to compare against.
|
||||
# log 1.0 as a neutral starting point.
|
||||
ratio = 1.0
|
||||
else:
|
||||
# MagCache Calibration Formula: mean(norm(curr) / norm(prev))
|
||||
# norm(dim=-1) gives magnitude of each token vector
|
||||
curr_norm = torch.linalg.norm(current_residual.float(), dim=-1)
|
||||
prev_norm = torch.linalg.norm(state.previous_residual.float(), dim=-1)
|
||||
|
||||
# Avoid division by zero
|
||||
ratio = (curr_norm / (prev_norm + 1e-8)).mean().item()
|
||||
|
||||
state.calibration_ratios.append(ratio)
|
||||
|
||||
def _advance_step(self, state: MagCacheState):
|
||||
state.step_index += 1
|
||||
if state.step_index >= self.config.num_inference_steps:
|
||||
# End of inference loop
|
||||
if self.config.calibrate:
|
||||
print("\n[MagCache] Calibration Complete. Copy these values to MagCacheConfig(mag_ratios=...):")
|
||||
print(f"{state.calibration_ratios}\n")
|
||||
logger.info(f"MagCache Calibration Results: {state.calibration_ratios}")
|
||||
|
||||
# Reset state
|
||||
state.step_index = 0
|
||||
state.accumulated_ratio = 1.0
|
||||
state.accumulated_steps = 0
|
||||
state.accumulated_err = 0.0
|
||||
state.previous_residual = None
|
||||
state.calibration_ratios = []
|
||||
|
||||
|
||||
def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None:
|
||||
"""
|
||||
Applies MagCache to a given module (typically a Transformer).
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module to apply MagCache to.
|
||||
config (`MagCacheConfig`):
|
||||
The configuration for MagCache.
|
||||
"""
|
||||
# Initialize registry on the root module so the Pipeline can set context.
|
||||
HookRegistry.check_if_exists_or_initialize(module)
|
||||
|
||||
state_manager = StateManager(MagCacheState, (), {})
|
||||
remaining_blocks = []
|
||||
|
||||
for name, submodule in module.named_children():
|
||||
if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
|
||||
continue
|
||||
for index, block in enumerate(submodule):
|
||||
remaining_blocks.append((f"{name}.{index}", block))
|
||||
|
||||
if not remaining_blocks:
|
||||
logger.warning("MagCache: No transformer blocks found to apply hooks.")
|
||||
return
|
||||
|
||||
# Handle single-block models
|
||||
if len(remaining_blocks) == 1:
|
||||
name, block = remaining_blocks[0]
|
||||
logger.info(f"MagCache: Applying Head+Tail Hooks to single block '{name}'")
|
||||
_apply_mag_cache_block_hook(block, state_manager, config, is_tail=True)
|
||||
_apply_mag_cache_head_hook(block, state_manager, config)
|
||||
return
|
||||
|
||||
head_block_name, head_block = remaining_blocks.pop(0)
|
||||
tail_block_name, tail_block = remaining_blocks.pop(-1)
|
||||
|
||||
logger.info(f"MagCache: Applying Head Hook to {head_block_name}")
|
||||
_apply_mag_cache_head_hook(head_block, state_manager, config)
|
||||
|
||||
for name, block in remaining_blocks:
|
||||
_apply_mag_cache_block_hook(block, state_manager, config)
|
||||
|
||||
logger.info(f"MagCache: Applying Tail Hook to {tail_block_name}")
|
||||
_apply_mag_cache_block_hook(tail_block, state_manager, config, is_tail=True)
|
||||
|
||||
|
||||
def _apply_mag_cache_head_hook(block: torch.nn.Module, state_manager: StateManager, config: MagCacheConfig) -> None:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(block)
|
||||
|
||||
# Automatically remove existing hook to allow re-application (e.g. switching modes)
|
||||
if registry.get_hook(_MAG_CACHE_LEADER_BLOCK_HOOK) is not None:
|
||||
registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK)
|
||||
|
||||
hook = MagCacheHeadHook(state_manager, config)
|
||||
registry.register_hook(hook, _MAG_CACHE_LEADER_BLOCK_HOOK)
|
||||
|
||||
|
||||
def _apply_mag_cache_block_hook(
|
||||
block: torch.nn.Module,
|
||||
state_manager: StateManager,
|
||||
config: MagCacheConfig,
|
||||
is_tail: bool = False,
|
||||
) -> None:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(block)
|
||||
|
||||
# Automatically remove existing hook to allow re-application
|
||||
if registry.get_hook(_MAG_CACHE_BLOCK_HOOK) is not None:
|
||||
registry.remove_hook(_MAG_CACHE_BLOCK_HOOK)
|
||||
|
||||
hook = MagCacheBlockHook(state_manager, is_tail, config)
|
||||
registry.register_hook(hook, _MAG_CACHE_BLOCK_HOOK)
|
||||
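A minimal usage sketch for the hooks above (the checkpoint id, prompt, and step count are illustrative assumptions, not part of this diff): run one calibration pass so `_perform_calibration_step` records the residual-norm ratios, then re-apply MagCache with the recorded `mag_ratios` so later steps can be skipped.

```python
import torch
from diffusers import FluxPipeline
from diffusers.hooks import MagCacheConfig, apply_mag_cache

# Illustrative checkpoint; any supported transformer-based pipeline would look similar.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

# 1) Calibration pass: prints the per-step ratios to copy into MagCacheConfig(mag_ratios=...).
apply_mag_cache(pipe.transformer, MagCacheConfig(num_inference_steps=28, calibrate=True))
pipe("a photo of a cat", num_inference_steps=28)

# 2) Inference pass with skipping enabled, reusing the calibrated ratios.
config = MagCacheConfig(
    num_inference_steps=28,
    mag_ratios=torch.ones(28),  # replace with the values printed during calibration
    threshold=0.06,
    retention_ratio=0.2,
    max_skip_steps=3,
)
apply_mag_cache(pipe.transformer, config)
image = pipe("a photo of a cat", num_inference_steps=28).images[0]
```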
@@ -18,7 +18,7 @@ from typing import Optional, Union

from huggingface_hub.utils import validate_hf_hub_args

from ..configuration_utils import ConfigMixin
from ..utils import DIFFUSERS_LOAD_ID_FIELDS, logging
from ..utils import logging
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code


@@ -220,11 +220,4 @@ class AutoModel(ConfigMixin):
            raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")

        kwargs = {**load_config_kwargs, **kwargs}
        model = model_cls.from_pretrained(pretrained_model_or_path, **kwargs)

        load_id_kwargs = {"pretrained_model_name_or_path": pretrained_model_or_path, **kwargs}
        parts = [load_id_kwargs.get(field, "null") for field in DIFFUSERS_LOAD_ID_FIELDS]
        load_id = "|".join("null" if p is None else p for p in parts)
        model._diffusers_load_id = load_id

        return model
        return model_cls.from_pretrained(pretrained_model_or_path, **kwargs)

@@ -68,12 +68,10 @@ class CacheMixin:
        from ..hooks import (
            FasterCacheConfig,
            FirstBlockCacheConfig,
            MagCacheConfig,
            PyramidAttentionBroadcastConfig,
            TaylorSeerCacheConfig,
            apply_faster_cache,
            apply_first_block_cache,
            apply_mag_cache,
            apply_pyramid_attention_broadcast,
            apply_taylorseer_cache,
        )
@@ -87,8 +85,6 @@ class CacheMixin:
            apply_faster_cache(self, config)
        elif isinstance(config, FirstBlockCacheConfig):
            apply_first_block_cache(self, config)
        elif isinstance(config, MagCacheConfig):
            apply_mag_cache(self, config)
        elif isinstance(config, PyramidAttentionBroadcastConfig):
            apply_pyramid_attention_broadcast(self, config)
        elif isinstance(config, TaylorSeerCacheConfig):
@@ -103,13 +99,11 @@ class CacheMixin:
            FasterCacheConfig,
            FirstBlockCacheConfig,
            HookRegistry,
            MagCacheConfig,
            PyramidAttentionBroadcastConfig,
            TaylorSeerCacheConfig,
        )
        from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
        from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
        from ..hooks.mag_cache import _MAG_CACHE_BLOCK_HOOK, _MAG_CACHE_LEADER_BLOCK_HOOK
        from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
        from ..hooks.taylorseer_cache import _TAYLORSEER_CACHE_HOOK

@@ -124,9 +118,6 @@ class CacheMixin:
        elif isinstance(self._cache_config, FirstBlockCacheConfig):
            registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True)
            registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True)
        elif isinstance(self._cache_config, MagCacheConfig):
            registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK, recurse=True)
            registry.remove_hook(_MAG_CACHE_BLOCK_HOOK, recurse=True)
        elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
            registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
        elif isinstance(self._cache_config, TaylorSeerCacheConfig):
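Since the `CacheMixin` branches above route `MagCacheConfig` through the standard cache interface, here is a hedged sketch of the higher-level path (it assumes the same illustrative `pipe` as the previous example):

```python
import torch
from diffusers import MagCacheConfig

# enable_cache() dispatches MagCacheConfig to apply_mag_cache() on the denoiser,
# mirroring the isinstance() branches shown above.
pipe.transformer.enable_cache(
    MagCacheConfig(num_inference_steps=28, mag_ratios=torch.ones(28), threshold=0.06)
)
image = pipe("a castle above the clouds", num_inference_steps=28).images[0]

# disable_cache() removes the MagCache leader/block hooks again.
pipe.transformer.disable_cache()
```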
@@ -125,9 +125,9 @@ class BriaFiboAttnProcessor:
        encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
            [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
        )
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[0](hidden_states.contiguous())
        hidden_states = attn.to_out[1](hidden_states)
        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
        encoder_hidden_states = attn.to_add_out(encoder_hidden_states.contiguous())

        return hidden_states, encoder_hidden_states
    else:

@@ -130,9 +130,9 @@ class FluxAttnProcessor:
        encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
            [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
        )
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[0](hidden_states.contiguous())
        hidden_states = attn.to_out[1](hidden_states)
        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
        encoder_hidden_states = attn.to_add_out(encoder_hidden_states.contiguous())

        return hidden_states, encoder_hidden_states
    else:
@@ -324,7 +324,6 @@ class ComponentsManager:
        "has_hook",
        "execution_device",
        "ip_adapter",
        "quantization",
    ]

    def __init__(self):
@@ -357,9 +356,7 @@ class ComponentsManager:
                ids_by_name.add(component_id)
        else:
            ids_by_name = set(components.keys())
        if collection and collection not in self.collections:
            return set()
        elif collection and collection in self.collections:
        if collection:
            ids_by_collection = set()
            for component_id, component in components.items():
                if component_id in self.collections[collection]:
@@ -426,8 +423,7 @@ class ComponentsManager:

        # add component to components manager
        self.components[component_id] = component
        if is_new_component:
            self.added_time[component_id] = time.time()
        self.added_time[component_id] = time.time()

        if collection:
            if collection not in self.collections:
@@ -764,6 +760,7 @@ class ComponentsManager:
        self.model_hooks = None
        self._auto_offload_enabled = False

    # YiYi TODO: (1) add quantization info
    def get_model_info(
        self,
        component_id: str,
@@ -839,17 +836,6 @@ class ComponentsManager:
            if scales:
                info["ip_adapter"] = summarize_dict_by_value_and_parts(scales)

        # Check for quantization
        hf_quantizer = getattr(component, "hf_quantizer", None)
        if hf_quantizer is not None:
            quant_config = hf_quantizer.quantization_config
            if hasattr(quant_config, "to_diff_dict"):
                info["quantization"] = quant_config.to_diff_dict()
            else:
                info["quantization"] = quant_config.to_dict()
        else:
            info["quantization"] = None

        # If fields specified, filter info
        if fields is not None:
            return {k: v for k, v in info.items() if k in fields}
@@ -980,16 +966,12 @@ class ComponentsManager:
        output += "\nAdditional Component Info:\n" + "=" * 50 + "\n"
        for name in self.components:
            info = self.get_model_info(name)
            if info is not None and (
                info.get("adapters") is not None or info.get("ip_adapter") or info.get("quantization")
            ):
            if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")):
                output += f"\n{name}:\n"
                if info.get("adapters") is not None:
                    output += f" Adapters: {info['adapters']}\n"
                if info.get("ip_adapter"):
                    output += " IP-Adapter: Enabled\n"
                if info.get("quantization"):
                    output += f" Quantization: {info['quantization']}\n"

        return output
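The `ComponentsManager` hunks above add a `quantization` entry alongside the adapter info; here is a small sketch of reading it (the registration call below is an assumption about the manager's API, shown only to make the lookup concrete):

```python
import torch
from diffusers import ComponentsManager

manager = ComponentsManager()
dummy = torch.nn.Linear(4, 4)  # stand-in for a real model component

# NOTE: the exact registration method is assumed here; use whatever your version of
# ComponentsManager exposes for adding components.
component_id = manager.add("text_encoder", dummy)

# With the change above, get_model_info() can report the quantization config (or None).
info = manager.get_model_info(component_id, fields=["quantization"])
print(info)  # e.g. {"quantization": None} for an unquantized component
```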
File diff suppressed because it is too large
@@ -34,7 +34,6 @@ from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
from .components_manager import ComponentsManager
from .modular_pipeline_utils import (
    MODULAR_MODEL_CARD_TEMPLATE,
    ComponentSpec,
    ConfigSpec,
    InputParam,
@@ -42,7 +41,6 @@ from .modular_pipeline_utils import (
    OutputParam,
    format_components,
    format_configs,
    generate_modular_model_card_content,
    make_doc_string,
)

@@ -1755,19 +1753,9 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
            repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id

        # Generate modular pipeline card content
        card_content = generate_modular_model_card_content(self.blocks)

        # Create a new empty model card and eventually tag it
        model_card = load_or_create_model_card(
            repo_id,
            token=token,
            is_pipeline=True,
            model_description=MODULAR_MODEL_CARD_TEMPLATE.format(**card_content),
            is_modular=True,
        )
        model_card = populate_model_card(model_card, tags=card_content["tags"])

        model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True)
        model_card = populate_model_card(model_card)
        model_card.save(os.path.join(save_directory, "README.md"))

        # YiYi TODO: maybe order the json file to make it more readable: configs first, then components
@@ -2155,8 +2143,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
            name
            for name in self._component_specs.keys()
            if self._component_specs[name].default_creation_method == "from_pretrained"
            and self._component_specs[name].pretrained_model_name_or_path is not None
            and getattr(self, name, None) is None
        ]
        elif isinstance(names, str):
            names = [names]
@@ -15,7 +15,7 @@
|
||||
import inspect
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass, field, fields
|
||||
from typing import Any, Dict, List, Literal, Optional, Type, Union
|
||||
|
||||
import PIL.Image
|
||||
@@ -23,7 +23,7 @@ import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, FrozenDict
|
||||
from ..loaders.single_file_utils import _is_single_file_path_or_url
|
||||
from ..utils import DIFFUSERS_LOAD_ID_FIELDS, is_torch_available, logging
|
||||
from ..utils import is_torch_available, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -31,30 +31,6 @@ if is_torch_available():
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# Template for modular pipeline model card description with placeholders
|
||||
MODULAR_MODEL_CARD_TEMPLATE = """{model_description}
|
||||
|
||||
## Example Usage
|
||||
|
||||
[TODO]
|
||||
|
||||
## Pipeline Architecture
|
||||
|
||||
This modular pipeline is composed of the following blocks:
|
||||
|
||||
{blocks_description} {trigger_inputs_section}
|
||||
|
||||
## Model Components
|
||||
|
||||
{components_description} {configs_section}
|
||||
|
||||
## Input/Output Specification
|
||||
|
||||
### Inputs {inputs_description}
|
||||
|
||||
### Outputs {outputs_description}
|
||||
"""
|
||||
|
||||
|
||||
class InsertableDict(OrderedDict):
|
||||
def insert(self, key, value, index):
|
||||
@@ -210,7 +186,7 @@ class ComponentSpec:
|
||||
"""
|
||||
Return the names of all loading‐related fields (i.e. those whose field.metadata["loading"] is True).
|
||||
"""
|
||||
return DIFFUSERS_LOAD_ID_FIELDS.copy()
|
||||
return [f.name for f in fields(cls) if f.metadata.get("loading", False)]
|
||||
|
||||
@property
|
||||
def load_id(self) -> str:
|
||||
@@ -222,7 +198,7 @@ class ComponentSpec:
|
||||
return "null"
|
||||
parts = [getattr(self, k) for k in self.loading_fields()]
|
||||
parts = ["null" if p is None else p for p in parts]
|
||||
return "|".join(parts)
|
||||
return "|".join(p for p in parts if p)
|
||||
|
||||
@classmethod
|
||||
def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]:
|
||||
@@ -544,7 +520,6 @@ class InputParam:
|
||||
required: bool = False
|
||||
description: str = ""
|
||||
kwargs_type: str = None
|
||||
metadata: Dict[str, Any] = None
|
||||
|
||||
def __repr__(self):
|
||||
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
|
||||
@@ -578,7 +553,6 @@ class OutputParam:
|
||||
type_hint: Any = None
|
||||
description: str = ""
|
||||
kwargs_type: str = None
|
||||
metadata: Dict[str, Any] = None
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
@@ -940,178 +914,3 @@ def make_doc_string(
|
||||
output += format_output_params(outputs, indent_level=2)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def generate_modular_model_card_content(blocks) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate model card content for a modular pipeline.
|
||||
|
||||
This function creates a comprehensive model card with descriptions of the pipeline's architecture, components,
|
||||
configurations, inputs, and outputs.
|
||||
|
||||
Args:
|
||||
blocks: The pipeline's blocks object containing all pipeline specifications
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A dictionary containing formatted content sections:
|
||||
- pipeline_name: Name of the pipeline
|
||||
- model_description: Overall description with pipeline type
|
||||
- blocks_description: Detailed architecture of blocks
|
||||
- components_description: List of required components
|
||||
- configs_section: Configuration parameters section
|
||||
- inputs_description: Input parameters specification
|
||||
- outputs_description: Output parameters specification
|
||||
- trigger_inputs_section: Conditional execution information
|
||||
- tags: List of relevant tags for the model card
|
||||
"""
|
||||
blocks_class_name = blocks.__class__.__name__
|
||||
pipeline_name = blocks_class_name.replace("Blocks", " Pipeline")
|
||||
description = getattr(blocks, "description", "A modular diffusion pipeline.")
|
||||
|
||||
# generate blocks architecture description
|
||||
blocks_desc_parts = []
|
||||
sub_blocks = getattr(blocks, "sub_blocks", None) or {}
|
||||
if sub_blocks:
|
||||
for i, (name, block) in enumerate(sub_blocks.items()):
|
||||
block_class = block.__class__.__name__
|
||||
block_desc = block.description.split("\n")[0] if getattr(block, "description", "") else ""
|
||||
blocks_desc_parts.append(f"{i + 1}. **{name}** (`{block_class}`)")
|
||||
if block_desc:
|
||||
blocks_desc_parts.append(f" - {block_desc}")
|
||||
|
||||
# add sub-blocks if any
|
||||
if hasattr(block, "sub_blocks") and block.sub_blocks:
|
||||
for sub_name, sub_block in block.sub_blocks.items():
|
||||
sub_class = sub_block.__class__.__name__
|
||||
sub_desc = sub_block.description.split("\n")[0] if getattr(sub_block, "description", "") else ""
|
||||
blocks_desc_parts.append(f" - *{sub_name}*: `{sub_class}`")
|
||||
if sub_desc:
|
||||
blocks_desc_parts.append(f" - {sub_desc}")
|
||||
|
||||
blocks_description = "\n".join(blocks_desc_parts) if blocks_desc_parts else "No blocks defined."
|
||||
|
||||
components = getattr(blocks, "expected_components", [])
|
||||
if components:
|
||||
components_str = format_components(components, indent_level=0, add_empty_lines=False)
|
||||
# remove the "Components:" header since template has its own
|
||||
components_description = components_str.replace("Components:\n", "").strip()
|
||||
if components_description:
|
||||
# Convert to enumerated list
|
||||
lines = [line.strip() for line in components_description.split("\n") if line.strip()]
|
||||
enumerated_lines = [f"{i + 1}. {line}" for i, line in enumerate(lines)]
|
||||
components_description = "\n".join(enumerated_lines)
|
||||
else:
|
||||
components_description = "No specific components required."
|
||||
else:
|
||||
components_description = "No specific components required. Components can be loaded dynamically."
|
||||
|
||||
configs = getattr(blocks, "expected_configs", [])
|
||||
configs_section = ""
|
||||
if configs:
|
||||
configs_str = format_configs(configs, indent_level=0, add_empty_lines=False)
|
||||
configs_description = configs_str.replace("Configs:\n", "").strip()
|
||||
if configs_description:
|
||||
configs_section = f"\n\n## Configuration Parameters\n\n{configs_description}"
|
||||
|
||||
inputs = blocks.inputs
|
||||
outputs = blocks.outputs
|
||||
|
||||
# format inputs as markdown list
|
||||
inputs_parts = []
|
||||
required_inputs = [inp for inp in inputs if inp.required]
|
||||
optional_inputs = [inp for inp in inputs if not inp.required]
|
||||
|
||||
if required_inputs:
|
||||
inputs_parts.append("**Required:**\n")
|
||||
for inp in required_inputs:
|
||||
if hasattr(inp.type_hint, "__name__"):
|
||||
type_str = inp.type_hint.__name__
|
||||
elif inp.type_hint is not None:
|
||||
type_str = str(inp.type_hint).replace("typing.", "")
|
||||
else:
|
||||
type_str = "Any"
|
||||
desc = inp.description or "No description provided"
|
||||
inputs_parts.append(f"- `{inp.name}` (`{type_str}`): {desc}")
|
||||
|
||||
if optional_inputs:
|
||||
if required_inputs:
|
||||
inputs_parts.append("")
|
||||
inputs_parts.append("**Optional:**\n")
|
||||
for inp in optional_inputs:
|
||||
if hasattr(inp.type_hint, "__name__"):
|
||||
type_str = inp.type_hint.__name__
|
||||
elif inp.type_hint is not None:
|
||||
type_str = str(inp.type_hint).replace("typing.", "")
|
||||
else:
|
||||
type_str = "Any"
|
||||
desc = inp.description or "No description provided"
|
||||
default_str = f", default: `{inp.default}`" if inp.default is not None else ""
|
||||
inputs_parts.append(f"- `{inp.name}` (`{type_str}`){default_str}: {desc}")
|
||||
|
||||
inputs_description = "\n".join(inputs_parts) if inputs_parts else "No specific inputs defined."
|
||||
|
||||
# format outputs as markdown list
|
||||
outputs_parts = []
|
||||
for out in outputs:
|
||||
if hasattr(out.type_hint, "__name__"):
|
||||
type_str = out.type_hint.__name__
|
||||
elif out.type_hint is not None:
|
||||
type_str = str(out.type_hint).replace("typing.", "")
|
||||
else:
|
||||
type_str = "Any"
|
||||
desc = out.description or "No description provided"
|
||||
outputs_parts.append(f"- `{out.name}` (`{type_str}`): {desc}")
|
||||
|
||||
outputs_description = "\n".join(outputs_parts) if outputs_parts else "Standard pipeline outputs."
|
||||
|
||||
trigger_inputs_section = ""
|
||||
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
|
||||
trigger_inputs_list = sorted([t for t in blocks.trigger_inputs if t is not None])
|
||||
if trigger_inputs_list:
|
||||
trigger_inputs_str = ", ".join(f"`{t}`" for t in trigger_inputs_list)
|
||||
trigger_inputs_section = f"""
|
||||
### Conditional Execution
|
||||
|
||||
This pipeline contains blocks that are selected at runtime based on inputs:
|
||||
- **Trigger Inputs**: {trigger_inputs_str}
|
||||
"""
|
||||
|
||||
# generate tags based on pipeline characteristics
|
||||
tags = ["modular-diffusers", "diffusers"]
|
||||
|
||||
if hasattr(blocks, "model_name") and blocks.model_name:
|
||||
tags.append(blocks.model_name)
|
||||
|
||||
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
|
||||
triggers = blocks.trigger_inputs
|
||||
if any(t in triggers for t in ["mask", "mask_image"]):
|
||||
tags.append("inpainting")
|
||||
if any(t in triggers for t in ["image", "image_latents"]):
|
||||
tags.append("image-to-image")
|
||||
if any(t in triggers for t in ["control_image", "controlnet_cond"]):
|
||||
tags.append("controlnet")
|
||||
if not any(t in triggers for t in ["image", "mask", "image_latents", "mask_image"]):
|
||||
tags.append("text-to-image")
|
||||
else:
|
||||
tags.append("text-to-image")
|
||||
|
||||
block_count = len(blocks.sub_blocks)
|
||||
model_description = f"""This is a modular diffusion pipeline built with 🧨 Diffusers' modular pipeline framework.
|
||||
|
||||
**Pipeline Type**: {blocks_class_name}
|
||||
|
||||
**Description**: {description}
|
||||
|
||||
This pipeline uses a {block_count}-block architecture that can be customized and extended."""
|
||||
|
||||
return {
|
||||
"pipeline_name": pipeline_name,
|
||||
"model_description": model_description,
|
||||
"blocks_description": blocks_description,
|
||||
"components_description": components_description,
|
||||
"configs_section": configs_section,
|
||||
"inputs_description": inputs_description,
|
||||
"outputs_description": outputs_description,
|
||||
"trigger_inputs_section": trigger_inputs_section,
|
||||
"tags": tags,
|
||||
}
|
||||
|
||||
@@ -23,7 +23,6 @@ from .constants import (
    DEFAULT_HF_PARALLEL_LOADING_WORKERS,
    DEPRECATED_REVISION_ARGS,
    DIFFUSERS_DYNAMIC_MODULE_NAME,
    DIFFUSERS_LOAD_ID_FIELDS,
    FLAX_WEIGHTS_NAME,
    GGUF_FILE_EXTENSION,
    HF_ENABLE_PARALLEL_LOADING,

@@ -73,11 +73,3 @@ DECODE_ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoint
ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/"
ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/"
ENCODE_ENDPOINT_FLUX = "https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/"


DIFFUSERS_LOAD_ID_FIELDS = [
    "pretrained_model_name_or_path",
    "subfolder",
    "variant",
    "revision",
]
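The `DIFFUSERS_LOAD_ID_FIELDS` constant above pairs with the `AutoModel` change earlier in this diff, which joins these fields into a load id; a hypothetical helper illustrating that composition (the function name and model id below are made up for the example):

```python
DIFFUSERS_LOAD_ID_FIELDS = ["pretrained_model_name_or_path", "subfolder", "variant", "revision"]


def build_load_id(**load_kwargs) -> str:
    # Missing or None fields become "null"; parts are joined with "|" in a fixed order.
    parts = [load_kwargs.get(field, "null") for field in DIFFUSERS_LOAD_ID_FIELDS]
    return "|".join("null" if p is None else p for p in parts)


print(build_load_id(pretrained_model_name_or_path="some-org/some-model", subfolder="unet"))
# -> "some-org/some-model|unet|null|null"
```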
@@ -227,21 +227,6 @@ class LayerSkipConfig(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


class MagCacheConfig(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class PyramidAttentionBroadcastConfig(metaclass=DummyObject):
    _backends = ["torch"]

@@ -299,10 +284,6 @@ def apply_layer_skip(*args, **kwargs):
    requires_backends(apply_layer_skip, ["torch"])


def apply_mag_cache(*args, **kwargs):
    requires_backends(apply_mag_cache, ["torch"])


def apply_pyramid_attention_broadcast(*args, **kwargs):
    requires_backends(apply_pyramid_attention_broadcast, ["torch"])
@@ -107,7 +107,6 @@ def load_or_create_model_card(
    license: Optional[str] = None,
    widget: Optional[List[dict]] = None,
    inference: Optional[bool] = None,
    is_modular: bool = False,
) -> ModelCard:
    """
    Loads or creates a model card.
@@ -132,8 +131,6 @@
        widget (`List[dict]`, *optional*): Widget to accompany a gallery template.
        inference: (`bool`, optional): Whether to turn on inference widget. Helpful when using
            `load_or_create_model_card` from a training script.
        is_modular: (`bool`, optional): Boolean flag to denote if the model card is for a modular pipeline.
            When True, uses model_description as-is without additional template formatting.
    """
    if not is_jinja_available():
        raise ValueError(
@@ -162,14 +159,10 @@
        )
    else:
        card_data = ModelCardData()
        if is_modular and model_description is not None:
            model_card = ModelCard(model_description)
            model_card.data = card_data
        else:
            component = "pipeline" if is_pipeline else "model"
            if model_description is None:
                model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated."
            model_card = ModelCard.from_template(card_data, model_description=model_description)
        component = "pipeline" if is_pipeline else "model"
        if model_description is None:
            model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated."
        model_card = ModelCard.from_template(card_data, model_description=model_description)

    return model_card
@@ -1,244 +0,0 @@
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import MagCacheConfig, apply_mag_cache
|
||||
from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry
|
||||
from diffusers.models import ModelMixin
|
||||
from diffusers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class DummyBlock(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
|
||||
# Output is double input
|
||||
# This ensures Residual = 2*Input - Input = Input
|
||||
return hidden_states * 2.0
|
||||
|
||||
|
||||
class DummyTransformer(ModelMixin):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.transformer_blocks = torch.nn.ModuleList([DummyBlock(), DummyBlock()])
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None):
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class TupleOutputBlock(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
|
||||
# Returns a tuple
|
||||
return hidden_states * 2.0, encoder_hidden_states
|
||||
|
||||
|
||||
class TupleTransformer(ModelMixin):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.transformer_blocks = torch.nn.ModuleList([TupleOutputBlock()])
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None):
|
||||
for block in self.transformer_blocks:
|
||||
# Emulate Flux-like behavior
|
||||
output = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||
hidden_states = output[0]
|
||||
encoder_hidden_states = output[1]
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class MagCacheTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# Register standard dummy block
|
||||
TransformerBlockRegistry.register(
|
||||
DummyBlock,
|
||||
TransformerBlockMetadata(return_hidden_states_index=None, return_encoder_hidden_states_index=None),
|
||||
)
|
||||
# Register tuple block (Flux style)
|
||||
TransformerBlockRegistry.register(
|
||||
TupleOutputBlock,
|
||||
TransformerBlockMetadata(return_hidden_states_index=0, return_encoder_hidden_states_index=1),
|
||||
)
|
||||
|
||||
def _set_context(self, model, context_name):
|
||||
"""Helper to set context on all hooks in the model."""
|
||||
for module in model.modules():
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
module._diffusers_hook._set_context(context_name)
|
||||
|
||||
def _get_calibration_data(self, model):
|
||||
for module in model.modules():
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
hook = module._diffusers_hook.get_hook("mag_cache_block_hook")
|
||||
if hook:
|
||||
return hook.state_manager.get_state().calibration_ratios
|
||||
return []
|
||||
|
||||
def test_mag_cache_validation(self):
|
||||
"""Test that missing mag_ratios raises ValueError."""
|
||||
with self.assertRaises(ValueError):
|
||||
MagCacheConfig(num_inference_steps=10, calibrate=False)
|
||||
|
||||
def test_mag_cache_skipping_logic(self):
|
||||
"""
|
||||
Tests that MagCache correctly calculates residuals and skips blocks when conditions are met.
|
||||
"""
|
||||
model = DummyTransformer()
|
||||
|
||||
# Dummy ratios: [1.0, 1.0] implies 0 accumulated error if we skip
|
||||
ratios = np.array([1.0, 1.0])
|
||||
|
||||
config = MagCacheConfig(
|
||||
threshold=100.0,
|
||||
num_inference_steps=2,
|
||||
retention_ratio=0.0, # Enable immediate skipping
|
||||
max_skip_steps=5,
|
||||
mag_ratios=ratios,
|
||||
)
|
||||
|
||||
apply_mag_cache(model, config)
|
||||
self._set_context(model, "test_context")
|
||||
|
||||
# Step 0: Input 10.0 -> Output 40.0 (2 blocks * 2x each)
|
||||
# HeadInput=10. Output=40. Residual=30.
|
||||
input_t0 = torch.tensor([[[10.0]]])
|
||||
output_t0 = model(input_t0)
|
||||
self.assertTrue(torch.allclose(output_t0, torch.tensor([[[40.0]]])), "Step 0 failed")
|
||||
|
||||
# Step 1: Input 11.0.
|
||||
# If Skipped: Output = Input(11) + Residual(30) = 41.0
|
||||
# If Computed: Output = 11 * 4 = 44.0
|
||||
input_t1 = torch.tensor([[[11.0]]])
|
||||
output_t1 = model(input_t1)
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(output_t1, torch.tensor([[[41.0]]])), f"Expected Skip (41.0), got {output_t1.item()}"
|
||||
)
|
||||
|
||||
def test_mag_cache_retention(self):
|
||||
"""Test that retention_ratio prevents skipping even if error is low."""
|
||||
model = DummyTransformer()
|
||||
# Ratios that imply 0 error, so it *would* skip if retention allowed it
|
||||
ratios = np.array([1.0, 1.0])
|
||||
|
||||
config = MagCacheConfig(
|
||||
threshold=100.0,
|
||||
num_inference_steps=2,
|
||||
retention_ratio=1.0, # Force retention for ALL steps
|
||||
mag_ratios=ratios,
|
||||
)
|
||||
|
||||
apply_mag_cache(model, config)
|
||||
self._set_context(model, "test_context")
|
||||
|
||||
# Step 0
|
||||
model(torch.tensor([[[10.0]]]))
|
||||
|
||||
# Step 1: Should COMPUTE (44.0) not SKIP (41.0) because of retention
|
||||
input_t1 = torch.tensor([[[11.0]]])
|
||||
output_t1 = model(input_t1)
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(output_t1, torch.tensor([[[44.0]]])),
|
||||
f"Expected Compute (44.0) due to retention, got {output_t1.item()}",
|
||||
)
|
||||
|
||||
def test_mag_cache_tuple_outputs(self):
|
||||
"""Test compatibility with models returning (hidden, encoder_hidden) like Flux."""
|
||||
model = TupleTransformer()
|
||||
ratios = np.array([1.0, 1.0])
|
||||
|
||||
config = MagCacheConfig(threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=ratios)
|
||||
|
||||
apply_mag_cache(model, config)
|
||||
self._set_context(model, "test_context")
|
||||
|
||||
# Step 0: Compute. Input 10.0 -> Output 20.0 (1 block * 2x)
|
||||
# Residual = 10.0
|
||||
input_t0 = torch.tensor([[[10.0]]])
|
||||
enc_t0 = torch.tensor([[[1.0]]])
|
||||
out_0, _ = model(input_t0, encoder_hidden_states=enc_t0)
|
||||
self.assertTrue(torch.allclose(out_0, torch.tensor([[[20.0]]])))
|
||||
|
||||
# Step 1: Skip. Input 11.0.
|
||||
# Skipped Output = 11 + 10 = 21.0
|
||||
input_t1 = torch.tensor([[[11.0]]])
|
||||
out_1, _ = model(input_t1, encoder_hidden_states=enc_t0)
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(out_1, torch.tensor([[[21.0]]])), f"Tuple skip failed. Expected 21.0, got {out_1.item()}"
|
||||
)
|
||||
|
||||
def test_mag_cache_reset(self):
|
||||
"""Test that state resets correctly after num_inference_steps."""
|
||||
model = DummyTransformer()
|
||||
config = MagCacheConfig(
|
||||
threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=np.array([1.0, 1.0])
|
||||
)
|
||||
apply_mag_cache(model, config)
|
||||
self._set_context(model, "test_context")
|
||||
|
||||
input_t = torch.ones(1, 1, 1)
|
||||
|
||||
model(input_t) # Step 0
|
||||
model(input_t) # Step 1 (Skipped)
|
||||
|
||||
# Step 2 (Reset -> Step 0) -> Should Compute
|
||||
# Input 2.0 -> Output 8.0
|
||||
input_t2 = torch.tensor([[[2.0]]])
|
||||
output_t2 = model(input_t2)
|
||||
|
||||
self.assertTrue(torch.allclose(output_t2, torch.tensor([[[8.0]]])), "State did not reset correctly")
|
||||
|
||||
def test_mag_cache_calibration(self):
|
||||
"""Test that calibration mode records ratios."""
|
||||
model = DummyTransformer()
|
||||
config = MagCacheConfig(num_inference_steps=2, calibrate=True)
|
||||
apply_mag_cache(model, config)
|
||||
self._set_context(model, "test_context")
|
||||
|
||||
# Step 0
|
||||
# HeadInput = 10. Output = 40. Residual = 30.
|
||||
# Ratio 0 is placeholder 1.0
|
||||
model(torch.tensor([[[10.0]]]))
|
||||
|
||||
# Check intermediate state
|
||||
ratios = self._get_calibration_data(model)
|
||||
self.assertEqual(len(ratios), 1)
|
||||
self.assertEqual(ratios[0], 1.0)
|
||||
|
||||
# Step 1
|
||||
# HeadInput = 10. Output = 40. Residual = 30.
|
||||
# PrevResidual = 30. CurrResidual = 30.
|
||||
# Ratio = 30/30 = 1.0
|
||||
model(torch.tensor([[[10.0]]]))
|
||||
|
||||
# Verify it computes fully (no skip)
|
||||
# If it skipped, output would be 41.0. It should be 40.0
|
||||
# Actually in test setup, input is same (10.0) so output 40.0.
|
||||
# Let's ensure list is empty after reset (end of step 1)
|
||||
ratios_after = self._get_calibration_data(model)
|
||||
self.assertEqual(ratios_after, [])
|
||||
@@ -8,13 +8,6 @@ import torch
|
||||
import diffusers
|
||||
from diffusers import ComponentsManager, ModularPipeline, ModularPipelineBlocks
|
||||
from diffusers.guiders import ClassifierFreeGuidance
|
||||
from diffusers.modular_pipelines.modular_pipeline_utils import (
|
||||
ComponentSpec,
|
||||
ConfigSpec,
|
||||
InputParam,
|
||||
OutputParam,
|
||||
generate_modular_model_card_content,
|
||||
)
|
||||
from diffusers.utils import logging
|
||||
|
||||
from ..testing_utils import backend_empty_cache, numpy_cosine_similarity_distance, require_accelerator, torch_device
|
||||
@@ -342,239 +335,3 @@ class ModularGuiderTesterMixin:
|
||||
assert out_cfg.shape == out_no_cfg.shape
|
||||
max_diff = torch.abs(out_cfg - out_no_cfg).max()
|
||||
assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"
|
||||
|
||||
|
||||
class TestModularModelCardContent:
|
||||
def create_mock_block(self, name="TestBlock", description="Test block description"):
|
||||
class MockBlock:
|
||||
def __init__(self, name, description):
|
||||
self.__class__.__name__ = name
|
||||
self.description = description
|
||||
self.sub_blocks = {}
|
||||
|
||||
return MockBlock(name, description)
|
||||
|
||||
def create_mock_blocks(
|
||||
self,
|
||||
class_name="TestBlocks",
|
||||
description="Test pipeline description",
|
||||
num_blocks=2,
|
||||
components=None,
|
||||
configs=None,
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
trigger_inputs=None,
|
||||
model_name=None,
|
||||
):
|
||||
class MockBlocks:
|
||||
def __init__(self):
|
||||
self.__class__.__name__ = class_name
|
||||
self.description = description
|
||||
self.sub_blocks = {}
|
||||
self.expected_components = components or []
|
||||
self.expected_configs = configs or []
|
||||
self.inputs = inputs or []
|
||||
self.outputs = outputs or []
|
||||
self.trigger_inputs = trigger_inputs
|
||||
self.model_name = model_name
|
||||
|
||||
blocks = MockBlocks()
|
||||
|
||||
# Add mock sub-blocks
|
||||
for i in range(num_blocks):
|
||||
block_name = f"block_{i}"
|
||||
blocks.sub_blocks[block_name] = self.create_mock_block(f"Block{i}", f"Description for block {i}")
|
||||
|
||||
return blocks
|
||||
|
||||
def test_basic_model_card_content_structure(self):
|
||||
"""Test that all expected keys are present in the output."""
|
||||
blocks = self.create_mock_blocks()
|
||||
content = generate_modular_model_card_content(blocks)
|
||||
|
||||
expected_keys = [
|
||||
"pipeline_name",
|
||||
"model_description",
|
||||
"blocks_description",
|
||||
"components_description",
|
||||
"configs_section",
|
||||
"inputs_description",
|
||||
"outputs_description",
|
||||
"trigger_inputs_section",
|
||||
"tags",
|
||||
]
|
||||
|
||||
for key in expected_keys:
|
||||
assert key in content, f"Expected key '{key}' not found in model card content"
|
||||
|
||||
assert isinstance(content["tags"], list), "Tags should be a list"
|
||||
|
||||
def test_pipeline_name_generation(self):
|
||||
"""Test that pipeline name is correctly generated from blocks class name."""
|
||||
blocks = self.create_mock_blocks(class_name="StableDiffusionBlocks")
|
||||
content = generate_modular_model_card_content(blocks)
|
||||
|
||||
assert content["pipeline_name"] == "StableDiffusion Pipeline"
|
||||
|
||||
def test_tags_generation_text_to_image(self):
|
||||
"""Test that text-to-image tags are correctly generated."""
|
||||
blocks = self.create_mock_blocks(trigger_inputs=None)
|
||||
content = generate_modular_model_card_content(blocks)
|
||||
|
||||
assert "modular-diffusers" in content["tags"]
|
||||
assert "diffusers" in content["tags"]
|
||||
assert "text-to-image" in content["tags"]
|
||||
|
||||
def test_tags_generation_with_trigger_inputs(self):
|
||||
"""Test that tags are correctly generated based on trigger inputs."""
|
||||
# Test inpainting
|
||||
blocks = self.create_mock_blocks(trigger_inputs=["mask", "prompt"])
|
||||
content = generate_modular_model_card_content(blocks)
|
||||
assert "inpainting" in content["tags"]
|
||||
|
||||
# Test image-to-image
|
||||
blocks = self.create_mock_blocks(trigger_inputs=["image", "prompt"])
|
||||
content = generate_modular_model_card_content(blocks)
|
||||
assert "image-to-image" in content["tags"]
|
||||
|
||||
# Test controlnet
|
||||
blocks = self.create_mock_blocks(trigger_inputs=["control_image", "prompt"])
|
||||
content = generate_modular_model_card_content(blocks)
|
||||
assert "controlnet" in content["tags"]
|
||||
|
||||
def test_tags_with_model_name(self):
|
||||
"""Test that model name is included in tags when present."""
|
||||
blocks = self.create_mock_blocks(model_name="stable-diffusion-xl")
|
||||
content = generate_modular_model_card_content(blocks)
|
||||
|
||||
assert "stable-diffusion-xl" in content["tags"]
|
||||
|
||||
def test_components_description_formatting(self):
|
||||
"""Test that components are correctly formatted."""
|
||||
components = [
|
||||
ComponentSpec(name="vae", description="VAE component"),
|
||||
ComponentSpec(name="text_encoder", description="Text encoder component"),
|
||||
]
|
||||
blocks = self.create_mock_blocks(components=components)
|
||||
content = generate_modular_model_card_content(blocks)
|
||||
|
||||
assert "vae" in content["components_description"]
|
||||
assert "text_encoder" in content["components_description"]
|
||||
# Should be enumerated
|
||||
assert "1." in content["components_description"]
|
||||
|
||||
def test_components_description_empty(self):
|
||||
"""Test handling of pipelines without components."""
|
||||
blocks = self.create_mock_blocks(components=None)
|
||||
content = generate_modular_model_card_content(blocks)
|
||||
|
||||
assert "No specific components required" in content["components_description"]
|
||||
|
||||
def test_configs_section_with_configs(self):
|
||||
"""Test that configs section is generated when configs are present."""
|
||||
configs = [
|
||||
ConfigSpec(name="num_train_timesteps", default=1000, description="Number of training timesteps"),
|
||||
]
|
||||
blocks = self.create_mock_blocks(configs=configs)
|
||||
content = generate_modular_model_card_content(blocks)
|
||||
|
||||
assert "## Configuration Parameters" in content["configs_section"]
|
||||
|
||||
def test_configs_section_empty(self):
|
||||
"""Test that configs section is empty when no configs are present."""
|
||||
blocks = self.create_mock_blocks(configs=None)
|
||||
content = generate_modular_model_card_content(blocks)
|
||||
|
||||
assert content["configs_section"] == ""
|
||||
|
||||
def test_inputs_description_required_and_optional(self):
|
||||
"""Test that required and optional inputs are correctly formatted."""
|
||||
inputs = [
|
||||
InputParam(name="prompt", type_hint=str, required=True, description="The input prompt"),
|
||||
InputParam(name="num_steps", type_hint=int, required=False, default=50, description="Number of steps"),
|
||||
]
|
||||
blocks = self.create_mock_blocks(inputs=inputs)
|
||||
content = generate_modular_model_card_content(blocks)
|
||||
|
||||
assert "**Required:**" in content["inputs_description"]
|
||||
assert "**Optional:**" in content["inputs_description"]
|
||||
assert "prompt" in content["inputs_description"]
|
||||
assert "num_steps" in content["inputs_description"]
|
||||
assert "default: `50`" in content["inputs_description"]
|
||||
|
||||
def test_inputs_description_empty(self):
|
||||
"""Test handling of pipelines without specific inputs."""
|
||||
blocks = self.create_mock_blocks(inputs=[])
|
||||
content = generate_modular_model_card_content(blocks)
|
||||
|
||||
assert "No specific inputs defined" in content["inputs_description"]
|
||||
|
||||
def test_outputs_description_formatting(self):
|
||||
"""Test that outputs are correctly formatted."""
|
||||
outputs = [
|
||||
OutputParam(name="images", type_hint=torch.Tensor, description="Generated images"),
|
||||
]
|
||||
blocks = self.create_mock_blocks(outputs=outputs)
|
||||
content = generate_modular_model_card_content(blocks)
|
||||
|
||||
assert "images" in content["outputs_description"]
|
||||
assert "Generated images" in content["outputs_description"]
|
||||
|
||||
def test_outputs_description_empty(self):
|
||||
"""Test handling of pipelines without specific outputs."""
|
||||
blocks = self.create_mock_blocks(outputs=[])
|
||||
content = generate_modular_model_card_content(blocks)
|
||||
|
||||
assert "Standard pipeline outputs" in content["outputs_description"]
|
||||
|
||||
def test_trigger_inputs_section_with_triggers(self):
|
||||
"""Test that trigger inputs section is generated when present."""
|
||||
blocks = self.create_mock_blocks(trigger_inputs=["mask", "image"])
|
||||
content = generate_modular_model_card_content(blocks)
|
||||
|
||||
assert "### Conditional Execution" in content["trigger_inputs_section"]
|
||||
assert "`mask`" in content["trigger_inputs_section"]
|
||||
assert "`image`" in content["trigger_inputs_section"]
|
||||
|
||||
def test_trigger_inputs_section_empty(self):
|
||||
"""Test that trigger inputs section is empty when not present."""
|
||||
blocks = self.create_mock_blocks(trigger_inputs=None)
|
||||
content = generate_modular_model_card_content(blocks)
|
||||
|
||||
assert content["trigger_inputs_section"] == ""
|
||||
|
||||
def test_blocks_description_with_sub_blocks(self):
|
||||
"""Test that blocks with sub-blocks are correctly described."""
|
||||
|
||||
class MockBlockWithSubBlocks:
|
||||
def __init__(self):
|
||||
self.__class__.__name__ = "ParentBlock"
|
||||
self.description = "Parent block"
|
||||
self.sub_blocks = {
|
||||
"child1": self.create_child_block("ChildBlock1", "Child 1 description"),
|
||||
"child2": self.create_child_block("ChildBlock2", "Child 2 description"),
|
||||
}
|
||||
|
||||
def create_child_block(self, name, desc):
|
||||
class ChildBlock:
|
||||
def __init__(self):
|
||||
self.__class__.__name__ = name
|
||||
self.description = desc
|
||||
|
||||
return ChildBlock()
|
||||
|
||||
blocks = self.create_mock_blocks()
|
||||
blocks.sub_blocks["parent"] = MockBlockWithSubBlocks()
|
||||
|
||||
content = generate_modular_model_card_content(blocks)
|
||||
|
||||
assert "parent" in content["blocks_description"]
|
||||
assert "child1" in content["blocks_description"]
|
||||
assert "child2" in content["blocks_description"]
|
||||
|
||||
def test_model_description_includes_block_count(self):
|
||||
"""Test that model description includes the number of blocks."""
|
||||
blocks = self.create_mock_blocks(num_blocks=5)
|
||||
content = generate_modular_model_card_content(blocks)
|
||||
|
||||
assert "5-block architecture" in content["model_description"]
|
||||
|
||||
@@ -27,7 +27,6 @@ from ..test_pipelines_common import (
    FasterCacheTesterMixin,
    FirstBlockCacheTesterMixin,
    FluxIPAdapterTesterMixin,
    MagCacheTesterMixin,
    PipelineTesterMixin,
    PyramidAttentionBroadcastTesterMixin,
    TaylorSeerCacheTesterMixin,
@@ -42,7 +41,6 @@ class FluxPipelineFastTests(
    FasterCacheTesterMixin,
    FirstBlockCacheTesterMixin,
    TaylorSeerCacheTesterMixin,
    MagCacheTesterMixin,
    unittest.TestCase,
):
    pipeline_class = FluxPipeline

@@ -35,7 +35,6 @@ from diffusers import (
from diffusers.hooks import apply_group_offloading
from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook
from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
from diffusers.hooks.mag_cache import MagCacheConfig
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
from diffusers.hooks.taylorseer_cache import TaylorSeerCacheConfig
from diffusers.image_processor import VaeImageProcessor
@@ -2977,59 +2976,6 @@ class TaylorSeerCacheTesterMixin:
    )


class MagCacheTesterMixin:
    mag_cache_config = MagCacheConfig(
        threshold=0.06,
        max_skip_steps=3,
        retention_ratio=0.2,
        num_inference_steps=50,
        mag_ratios=torch.ones(50),
    )

    def test_mag_cache_inference(self, expected_atol: float = 0.1):
        device = "cpu"

        def create_pipe():
            torch.manual_seed(0)
            num_layers = 2
            components = self.get_dummy_components(num_layers=num_layers)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(device)
            pipe.set_progress_bar_config(disable=None)
            return pipe

        def run_forward(pipe):
            torch.manual_seed(0)
            inputs = self.get_dummy_inputs(device)
            # Match the config steps
            inputs["num_inference_steps"] = 50
            return pipe(**inputs)[0]

        # 1. Run inference without MagCache (Baseline)
        pipe = create_pipe()
        output = run_forward(pipe).flatten()
        original_image_slice = np.concatenate((output[:8], output[-8:]))

        # 2. Run inference with MagCache ENABLED
        pipe = create_pipe()
        pipe.transformer.enable_cache(self.mag_cache_config)
        output = run_forward(pipe).flatten()
        image_slice_enabled = np.concatenate((output[:8], output[-8:]))

        # 3. Run inference with MagCache DISABLED
        pipe.transformer.disable_cache()
        output = run_forward(pipe).flatten()
        image_slice_disabled = np.concatenate((output[:8], output[-8:]))

        assert np.allclose(original_image_slice, image_slice_enabled, atol=expected_atol), (
            "MagCache outputs should not differ too much from baseline."
        )

        assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-4), (
            "Outputs after disabling cache should match original inference exactly."
        )


# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
# reference image.