# Quickstart

Modular Diffusers is a framework for quickly building flexible and customizable pipelines. At the core of Modular Diffusers are [`ModularPipelineBlocks`] that can be combined with other blocks to adapt to new workflows. The blocks are converted into a [`ModularPipeline`], a friendly user-facing interface for running generation tasks.

This guide shows you how to implement a [Differential Diffusion](https://differential-diffusion.github.io/) pipeline with the modular framework, and along the way how to run a modular pipeline, understand its structure, and customize it by modifying the blocks that compose it.

## ModularPipelineBlocks

[`ModularPipelineBlocks`] are *definitions* that specify the components, inputs, outputs, and computation logic for a single step in a pipeline. There are four types of blocks.

- [`ModularPipelineBlocks`] is the most basic block for a single step.
- [`SequentialPipelineBlocks`] is a multi-block that composes other blocks linearly, where the outputs of one block are the inputs to the next.
- [`LoopSequentialPipelineBlocks`] is a multi-block that runs its sub-blocks iteratively, designed for loops such as denoising.
- [`AutoPipelineBlocks`] is a collection of blocks for different workflows that selects which block to run based on the inputs. It conveniently packages multiple workflows into a single pipeline.

[Differential Diffusion](https://differential-diffusion.github.io/) is an image-to-image workflow. Start with the `IMAGE2IMAGE_BLOCKS` preset, a collection of `ModularPipelineBlocks` for image-to-image generation.

```py
from diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS

# the preset is an InsertableDict mapping block names to ModularPipelineBlocks classes
IMAGE2IMAGE_BLOCKS = InsertableDict([
    ("text_encoder", StableDiffusionXLTextEncoderStep),
    ("image_encoder", StableDiffusionXLVaeEncoderStep),
    ("input", StableDiffusionXLInputStep),
    ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
    ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
    ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
    ("denoise", StableDiffusionXLDenoiseStep),
    ("decode", StableDiffusionXLDecodeStep)
])
```
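
Because the preset is an `InsertableDict`, entries can be replaced by key or inserted at a position. These are the same operations used to customize presets throughout the rest of this guide; a minimal sketch (reusing block classes from the preset above, which are assumed importable from the same module):

```py
from diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLAutoIPAdapterStep

# copy the preset so the original stays intact
blocks = IMAGE2IMAGE_BLOCKS.copy()

# replace an entry by key
blocks["denoise"] = StableDiffusionXLDenoiseStep

# insert a new entry at position 0
blocks.insert("ip_adapter", StableDiffusionXLAutoIPAdapterStep, 0)
```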

## Pipeline and block states

Modular Diffusers uses *state* to communicate data between blocks. There are two types of states.

- [`PipelineState`] is the global state that tracks all inputs and outputs across all blocks.
- [`BlockState`] is a local view of the relevant variables from [`PipelineState`] for an individual block.
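
To make the relationship concrete, here is a minimal sketch (not from the original docs) of a single-step block that pulls its [`BlockState`] from the global [`PipelineState`], updates a variable, and writes it back. It assumes `ModularPipelineBlocks`, `PipelineState`, `InputParam`, and `OutputParam` are importable from `diffusers.modular_pipelines`, and follows the `get_block_state`/`set_block_state` pattern shown in the blocks later in this guide:

```py
import torch
from diffusers.modular_pipelines import ModularPipelineBlocks, PipelineState, InputParam, OutputParam

class ScaleLatentsStep(ModularPipelineBlocks):
    @property
    def inputs(self):
        return [InputParam("latents", required=True, type_hint=torch.Tensor)]

    @property
    def intermediate_outputs(self):
        return [OutputParam("latents", type_hint=torch.Tensor)]

    def __call__(self, components, state: PipelineState):
        # get a local BlockState view of only the variables declared above
        block_state = self.get_block_state(state)
        block_state.latents = block_state.latents * 0.5  # hypothetical computation
        # write the updated variables back into the global PipelineState
        self.set_block_state(state, block_state)
        return components, state
```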

## Customizing blocks

[Differential Diffusion](https://differential-diffusion.github.io/) differs from standard image-to-image in its `prepare_latents` and `denoise` blocks. All the other blocks can be reused, but these two need to be modified.

Create placeholder `ModularPipelineBlocks` for `prepare_latents` and `denoise` by copying and modifying the existing ones.

Print the `denoise` block to see that it is a [`LoopSequentialPipelineBlocks`] composed of three sub-blocks, `before_denoiser`, `denoiser`, and `after_denoiser`. Only the `before_denoiser` sub-block needs to be modified to prepare the latent input for the denoiser based on the change map.

```py
denoise_blocks = IMAGE2IMAGE_BLOCKS["denoise"]()
print(denoise_blocks)
```

Replace the `StableDiffusionXLLoopBeforeDenoiser` sub-block with a new `SDXLDiffDiffLoopBeforeDenoiser` block.

```py
# Copy existing blocks as placeholders
class SDXLDiffDiffPrepareLatentsStep(ModularPipelineBlocks):
    """Copied from StableDiffusionXLImg2ImgPrepareLatentsStep - will modify later"""
    # ... same implementation as StableDiffusionXLImg2ImgPrepareLatentsStep

class SDXLDiffDiffDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
    block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLLoopDenoiser, StableDiffusionXLLoopAfterDenoiser]
    block_names = ["before_denoiser", "denoiser", "after_denoiser"]
```

### prepare_latents

The `prepare_latents` block requires the following changes.

- a `mask_processor` component to process the change map
- new `inputs` to accept the user-provided change map (`diffdiff_map`), the `timesteps` for precomputing all the latents, and `num_inference_steps` to create the masks that control which image regions are updated
- an updated `__call__` method that processes the change map, creates the masks, and stores them in the [`BlockState`]

```diff
class SDXLDiffDiffPrepareLatentsStep(ModularPipelineBlocks):
    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKL),
            ComponentSpec("scheduler", EulerDiscreteScheduler),
+           ComponentSpec("mask_processor", VaeImageProcessor, config=FrozenDict({"do_normalize": False, "do_convert_grayscale": True}))
        ]

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            InputParam("generator"),
+           InputParam("diffdiff_map", required=True),
-           InputParam("latent_timestep", required=True, type_hint=torch.Tensor),
+           InputParam("timesteps", type_hint=torch.Tensor),
+           InputParam("num_inference_steps", type_hint=int),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
+           OutputParam("original_latents", type_hint=torch.Tensor),
+           OutputParam("diffdiff_masks", type_hint=torch.Tensor),
        ]

    def __call__(self, components, state: PipelineState):
        # ... existing logic ...
+       # Process change map and create masks
+       diffdiff_map = components.mask_processor.preprocess(block_state.diffdiff_map, height=latent_height, width=latent_width)
+       thresholds = torch.arange(block_state.num_inference_steps, dtype=diffdiff_map.dtype) / block_state.num_inference_steps
+       block_state.diffdiff_masks = diffdiff_map > (thresholds + (block_state.denoising_start or 0))
+       block_state.original_latents = block_state.latents
```
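
The thresholding line above is the heart of Differential Diffusion: each latent pixel gets a per-step boolean that says whether it is still being denoised. A standalone sketch of what `diffdiff_masks` looks like, in plain `torch` with simplified shapes:

```py
import torch

num_inference_steps = 4
# change map values: 0 = keep the original pixel, 1 = fully regenerate it
diffdiff_map = torch.tensor([0.0, 0.3, 0.6, 1.0])

thresholds = torch.arange(num_inference_steps, dtype=diffdiff_map.dtype) / num_inference_steps
# masks[i] is True where a pixel is still being denoised at step i
masks = diffdiff_map > thresholds[:, None]
print(masks)
# tensor([[False,  True,  True,  True],
#         [False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True]])
# at step i, the denoise loop keeps denoising where masks[i] is True and
# restores the precomputed original latents everywhere else
```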

### denoise

The `before_denoiser` sub-block requires the following changes.

- new `inputs` to accept a `denoising_start` parameter, plus the `original_latents` and `diffdiff_masks` produced by the `prepare_latents` block
- an updated `__call__` method that applies the Differential Diffusion masking

```diff
class SDXLDiffDiffLoopBeforeDenoiser(ModularPipelineBlocks):
    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop for differential diffusion that prepares the latent input for the denoiser"
        )

    @property
    def inputs(self) -> List[str]:
        return [
            InputParam("latents", required=True, type_hint=torch.Tensor),
+           InputParam("denoising_start"),
+           InputParam("original_latents", type_hint=torch.Tensor),
+           InputParam("diffdiff_masks", type_hint=torch.Tensor),
        ]

    def __call__(self, components, block_state, i, t):
+       # Apply differential diffusion logic
+       if i == 0 and block_state.denoising_start is None:
+           block_state.latents = block_state.original_latents[:1]
+       else:
+           block_state.mask = block_state.diffdiff_masks[i].unsqueeze(0).unsqueeze(1)
+           block_state.latents = block_state.original_latents[i] * block_state.mask + block_state.latents * (1 - block_state.mask)

        # ... rest of existing logic ...
```

## Assembling the blocks

You should have all the blocks you need at this point to create a [`ModularPipeline`].

Copy the existing `IMAGE2IMAGE_BLOCKS` preset. For the `set_timesteps` block, use the one from `TEXT2IMAGE_BLOCKS` because Differential Diffusion doesn't use a `strength` parameter.

Set the `prepare_latents` and `denoise` blocks to the `SDXLDiffDiffPrepareLatentsStep` and `SDXLDiffDiffDenoiseStep` blocks you just modified.

Call [`SequentialPipelineBlocks.from_blocks_dict`] on the blocks to create a [`SequentialPipelineBlocks`].

```py
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS

DIFFDIFF_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
DIFFDIFF_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]
DIFFDIFF_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
DIFFDIFF_BLOCKS["denoise"] = SDXLDiffDiffDenoiseStep

dd_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_BLOCKS)
print(dd_blocks)
```

## ModularPipeline

Convert the [`SequentialPipelineBlocks`] into a [`ModularPipeline`] with the [`~ModularPipeline.init_pipeline`] method. This initializes the loading specs for the expected components from a `modular_model_index.json` file, but it doesn't load the components yet. Explicitly load the components by calling [`~ModularPipeline.load_components`].

It is a good idea to initialize a [`ComponentsManager`] with the pipeline to help manage the different components. Once you call [`~ModularPipeline.load_components`], the components are registered to the [`ComponentsManager`] and can be shared between workflows. The example below uses the `collection` argument to assign the components a `"diffdiff"` label for better organization.

```py
import torch
from diffusers.modular_pipelines import ComponentsManager

components = ComponentsManager()

dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", components_manager=components, collection="diffdiff")
dd_pipeline.load_components(torch_dtype=torch.float16)
dd_pipeline.to("cuda")
```
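
With the components loaded, the pipeline runs like any other. A sketch of a run, reusing the example image, change map, and parameters from the IP-Adapter and ControlNet sections below:

```py
from diffusers.utils import load_image

image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")

image = dd_pipeline(
    prompt="a green pear",
    negative_prompt="blurry",
    num_inference_steps=25,
    diffdiff_map=mask,
    image=image,
    output="images",
)[0]
```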

## Adding workflows

Other workflows can be added to the [`ModularPipeline`] to support additional features without rewriting the entire pipeline from scratch.

This section demonstrates how to add an IP-Adapter or ControlNet.

### IP-Adapter

Stable Diffusion XL already has a preset IP-Adapter block that you can use without any changes to the existing Differential Diffusion blocks.

```py
from diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLAutoIPAdapterStep

ip_adapter_block = StableDiffusionXLAutoIPAdapterStep()
```

Use the `sub_blocks.insert` method to insert it into the [`SequentialPipelineBlocks`]. The example below inserts the `ip_adapter_block` at position `0`. Print the blocks to see that the `ip_adapter_block` was added and that it requires an `ip_adapter_image` input. It also adds two components to the pipeline, the `image_encoder` and `feature_extractor`.

```py
dd_blocks.sub_blocks.insert("ip_adapter", ip_adapter_block, 0)
print(dd_blocks)
```

Call [`~ModularPipeline.init_pipeline`] to initialize a [`ModularPipeline`] and use [`~ModularPipeline.load_components`] to load the model components. Load and set the IP-Adapter to run the pipeline.

```py
import torch
from diffusers.utils import load_image

device = "cuda"

dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
dd_pipeline.load_components(torch_dtype=torch.float16)
dd_pipeline.loader.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
dd_pipeline.loader.set_ip_adapter_scale(0.6)
dd_pipeline = dd_pipeline.to(device)

ip_adapter_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_orange.jpeg")
image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")

prompt = "a green pear"
negative_prompt = "blurry"
generator = torch.Generator(device=device).manual_seed(42)

image = dd_pipeline(
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=25,
    generator=generator,
    ip_adapter_image=ip_adapter_image,
    diffdiff_map=mask,
    image=image,
    output="images"
)[0]
```

### ControlNet

Stable Diffusion XL already has a preset ControlNet block that can readily be used.

```py
from diffusers.modular_pipelines.stable_diffusion_xl.modular_blocks import StableDiffusionXLAutoControlNetInputStep

control_input_block = StableDiffusionXLAutoControlNetInputStep()
```

However, it requires modifying the `denoise` block because that's where the ControlNet injects the control information into the UNet.

Modify the `denoise` block by replacing the `StableDiffusionXLLoopDenoiser` sub-block with the `StableDiffusionXLControlNetLoopDenoiser`.

```py
class SDXLDiffDiffControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
    block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLControlNetLoopDenoiser, StableDiffusionXLLoopAfterDenoiser]
    block_names = ["before_denoiser", "denoiser", "after_denoiser"]

controlnet_denoise_block = SDXLDiffDiffControlNetDenoiseStep()
```

Insert the `controlnet_input` block and replace the `denoise` block with the new `controlnet_denoise_block`. Initialize a [`ModularPipeline`] and use [`~ModularPipeline.load_components`] to load the model components into it.

```py
dd_blocks.sub_blocks.insert("controlnet_input", control_input_block, 7)
dd_blocks.sub_blocks["denoise"] = controlnet_denoise_block

dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
dd_pipeline.load_components(torch_dtype=torch.float16)
dd_pipeline = dd_pipeline.to(device)

control_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_tomato_canny.jpeg")
image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")

prompt = "a green pear"
negative_prompt = "blurry"
generator = torch.Generator(device=device).manual_seed(42)

image = dd_pipeline(
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=25,
    generator=generator,
    control_image=control_image,
    controlnet_conditioning_scale=0.5,
    diffdiff_map=mask,
    image=image,
    output="images"
)[0]
```

### AutoPipelineBlocks

The Differential Diffusion, IP-Adapter, and ControlNet workflows can be bundled into a single [`ModularPipeline`] with [`AutoPipelineBlocks`]. It automatically selects which sub-blocks to run based on inputs like `control_image` or `ip_adapter_image`. If none of these inputs are passed, it defaults to the Differential Diffusion workflow.

Use `block_trigger_inputs` to run the `SDXLDiffDiffControlNetDenoiseStep` block only if a `controlnet_cond` input is present (it is produced by the `controlnet_input` block when you pass a `control_image`). Otherwise, the `SDXLDiffDiffDenoiseStep` is used.

```py
class SDXLDiffDiffAutoDenoiseStep(AutoPipelineBlocks):
    block_classes = [SDXLDiffDiffControlNetDenoiseStep, SDXLDiffDiffDenoiseStep]
    block_names = ["controlnet_denoise", "denoise"]
    block_trigger_inputs = ["controlnet_cond", None]
```

Add the `ip_adapter` and `controlnet_input` blocks.

```py
DIFFDIFF_AUTO_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
DIFFDIFF_AUTO_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
DIFFDIFF_AUTO_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]
DIFFDIFF_AUTO_BLOCKS["denoise"] = SDXLDiffDiffAutoDenoiseStep
DIFFDIFF_AUTO_BLOCKS.insert("ip_adapter", StableDiffusionXLAutoIPAdapterStep, 0)
DIFFDIFF_AUTO_BLOCKS.insert("controlnet_input", StableDiffusionXLAutoControlNetInputStep, 7)
```

Call [`SequentialPipelineBlocks.from_blocks_dict`] to create a [`SequentialPipelineBlocks`], convert it into a [`ModularPipeline`], and load the model components to run it.

```py
dd_auto_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_AUTO_BLOCKS)
dd_pipeline = dd_auto_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
dd_pipeline.load_components(torch_dtype=torch.float16)
```
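
Because the denoise step is now an [`AutoPipelineBlocks`], the same pipeline dispatches based on the inputs you pass. A sketch, reloading the inputs used in the earlier sections:

```py
from diffusers.utils import load_image

dd_pipeline.to("cuda")

image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
control_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_tomato_canny.jpeg")

# without control_image, the plain SDXLDiffDiffDenoiseStep runs
no_control = dd_pipeline(
    prompt="a green pear",
    num_inference_steps=25,
    diffdiff_map=mask,
    image=image,
    output="images",
)[0]

# passing control_image makes the controlnet_input block produce controlnet_cond,
# which triggers SDXLDiffDiffControlNetDenoiseStep instead
with_control = dd_pipeline(
    prompt="a green pear",
    num_inference_steps=25,
    diffdiff_map=mask,
    image=image,
    control_image=control_image,
    controlnet_conditioning_scale=0.5,
    output="images",
)[0]
```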

## Share

Add your [`ModularPipeline`] to the Hub with [`~ModularPipeline.save_pretrained`] and set the `push_to_hub` argument to `True`.

```py
dd_pipeline.save_pretrained("YiYiXu/test_modular_doc", push_to_hub=True)
```

Other users can load the [`ModularPipeline`] with [`~ModularPipeline.from_pretrained`]. [`ModularPipeline`] is the main interface for loading, running, and managing modular pipelines.

```py
import torch
from diffusers.modular_pipelines import ModularPipeline, ComponentsManager

components = ComponentsManager()

pipe = ModularPipeline.from_pretrained("Qwen/Qwen-Image")
pipe.load_components(torch_dtype=torch.bfloat16)
pipe.to("cuda")

diffdiff_pipeline = ModularPipeline.from_pretrained("YiYiXu/modular-diffdiff-0704", trust_remote_code=True, components_manager=components, collection="diffdiff")
diffdiff_pipeline.load_components(torch_dtype=torch.float16)

image = pipe(
    prompt="cat wizard with red hat, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney",
).images[0]
image
```

[`~ModularPipeline.from_pretrained`] uses lazy loading - it reads the configuration and knows where to load each component from, but doesn't actually load the model weights until you call [`~ModularPipeline.load_components`]. This gives you control over when and how components are loaded.
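
A sketch of what this looks like in practice, reusing the Differential Diffusion repository from above (the exact `print` output depends on the pipeline):

```py
pipe = ModularPipeline.from_pretrained("YiYiXu/modular-diffdiff-0704", trust_remote_code=True)
# at this point only the configuration has been read - no model weights are in memory
print(pipe)  # inspect the expected components and where they will load from

pipe.load_components(torch_dtype=torch.float16)  # weights are downloaded and loaded here
```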

Learn more about creating and loading pipelines in the [Creating a pipeline](https://huggingface.co/docs/diffusers/modular_diffusers/modular_pipeline#creating-a-pipeline) and [Loading components](https://huggingface.co/docs/diffusers/modular_diffusers/modular_pipeline#loading-components) guides.

## Understand the structure

The pipeline you loaded from `"Qwen/Qwen-Image"` is built from a [`ModularPipelineBlocks`] called `QwenImageAutoBlocks`. Print it to see its structure.

```py
print(pipe.blocks)
```

```
QwenImageAutoBlocks(
  Class: SequentialPipelineBlocks

  Description: Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.

  Supported workflows:
  - `text2image`: requires `prompt`
  - `image2image`: requires `prompt`, `image`
  - `inpainting`: requires `prompt`, `mask_image`, `image`
  - `controlnet_text2image`: requires `prompt`, `control_image`
  ...

  Components:
      text_encoder (`Qwen2_5_VLForConditionalGeneration`)
      vae (`AutoencoderKLQwenImage`)
      transformer (`QwenImageTransformer2DModel`)
      ...

  Sub-Blocks:
    [0] text_encoder (QwenImageAutoTextEncoderStep)
    [1] vae_encoder (QwenImageAutoVaeEncoderStep)
    [2] controlnet_vae_encoder (QwenImageOptionalControlNetVaeEncoderStep)
    [3] denoise (QwenImageAutoCoreDenoiseStep)
    [4] decode (QwenImageAutoDecodeStep)
)
```

From this output you can see two things:

- It supports multiple **workflows** (text2image, image2image, inpainting, etc.)
- It's composed of **sub_blocks** (text_encoder, vae_encoder, denoise, decode)

### Workflows

This pipeline supports multiple workflows and adapts its behavior based on the inputs you provide. For example, if you pass `image` to the pipeline, it runs an image-to-image workflow instead of text-to-image.

```py
from diffusers.utils import load_image

input_image = load_image("https://github.com/Trgtuan10/Image_storage/blob/main/cute_cat.png?raw=true")

image = pipe(
    prompt="cat wizard with red hat, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney",
    image=input_image,
).images[0]
```

Learn more about conditional blocks in the [AutoPipelineBlocks](https://huggingface.co/docs/diffusers/modular_diffusers/auto_pipeline_blocks) guide.

Use `get_workflow()` to extract the blocks for a specific workflow.

```py
img2img_blocks = pipe.blocks.get_workflow("image2image")
```
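
The extracted workflow is itself a [`ModularPipelineBlocks`], so it can become a standalone pipeline with the same `init_pipeline` and `load_components` pattern used elsewhere in this guide. A sketch:

```py
# turn the image2image workflow into its own runnable pipeline
img2img_pipe = img2img_blocks.init_pipeline("Qwen/Qwen-Image")
img2img_pipe.load_components(torch_dtype=torch.bfloat16)
img2img_pipe.to("cuda")
```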

### Sub-blocks

Blocks are the fundamental units of the modular system. They are *definitions* that specify the inputs, outputs, and computation logic for a step - and they can be composed together in different ways.

`QwenImageAutoBlocks` is itself composed of smaller blocks: `text_encoder`, `vae_encoder`, `controlnet_vae_encoder`, `denoise`, and `decode`. Access them through the `sub_blocks` property.

Let's take a look at the `vae_encoder` block as an example. Use the `doc` property to see the full documentation for any block, including its inputs, outputs, and components.

```py
vae_encoder_block = pipe.blocks.sub_blocks["vae_encoder"]
print(vae_encoder_block.doc)
```

Just like `QwenImageAutoBlocks`, this block can be converted to a pipeline and run on its own.

```py
vae_encoder_pipe = vae_encoder_block.init_pipeline()

# Reuse the VAE we already loaded with the update_components() method
vae_encoder_pipe.update_components(vae=pipe.vae)

# Run just this block
image_latents = vae_encoder_pipe(image=input_image).image_latents
print(image_latents.shape)
```

This reuses the VAE from our original pipeline instead of loading it again, keeping memory usage efficient. Learn more in the [Loading components](https://huggingface.co/docs/diffusers/modular_diffusers/modular_pipeline#loading-components) guide.

You can also add new blocks to compose new workflows. Let's add a canny edge detection block to create a ControlNet pipeline.

First, load the canny block from the Hub and insert it into the controlnet workflow. If you want to learn how to create your own custom blocks and share them on the Hub, check out the [Building Custom Blocks](https://huggingface.co/docs/diffusers/modular_diffusers/custom_blocks) guide.

```py
from diffusers.modular_pipelines import ModularPipelineBlocks

# Load a canny block from the Hub
canny_block = ModularPipelineBlocks.from_pretrained(
    "diffusers-internal-dev/canny-filtering",
    trust_remote_code=True,
)

print(canny_block.doc)
```

```
class CannyBlock

  Inputs:
      image (`Union[Image, ndarray]`):
          Image to compute canny filter on
      low_threshold (`int`, *optional*, defaults to 50):
          Low threshold for the canny filter.
      high_threshold (`int`, *optional*, defaults to 200):
          High threshold for the canny filter.
      ...

  Outputs:
      control_image (`PIL.Image`):
          Canny map for input image
```

The canny block takes an `image` as input and outputs a `control_image` - exactly what we need for ControlNet. Now get the controlnet workflow and check its inputs.

```py
# Get the controlnet workflow
blocks = pipe.blocks.get_workflow("controlnet_text2image")
print(blocks.doc)
```

```
class SequentialPipelineBlocks

  Inputs:
      prompt (`str`):
          The prompt or prompts to guide image generation.
      control_image (`Image`):
          Control image for ControlNet conditioning.
      ...
```

Notice it requires `control_image` as input. After inserting the canny block, the pipeline will accept a regular `image` instead.

```py
# insert canny at the beginning
blocks.sub_blocks.insert("canny", canny_block, 0)

# Check the updated structure - notice the pipeline now takes "image" as input
# even though it's a controlnet pipeline, because canny preprocesses it into control_image
print(blocks.doc)
```

```
class SequentialPipelineBlocks

  Inputs:
      image (`Union[Image, ndarray]`):
          Image to compute canny filter on
      low_threshold (`int`, *optional*, defaults to 50):
          Low threshold for the canny filter.
      high_threshold (`int`, *optional*, defaults to 200):
          High threshold for the canny filter.
      prompt (`str`):
          The prompt or prompts to guide image generation.
      ...
```

Now the pipeline takes `image` as input - the canny block will preprocess it into `control_image` automatically.

Create a pipeline from the modified blocks and load a ControlNet model.

```py
pipeline = blocks.init_pipeline("Qwen/Qwen-Image")
pipeline.load_components(torch_dtype=torch.bfloat16)

# Load the ControlNet model
controlnet_spec = pipeline.get_component_spec("controlnet")
controlnet_spec.pretrained_model_name_or_path = "InstantX/Qwen-Image-ControlNet-Union"
controlnet = controlnet_spec.load(torch_dtype=torch.bfloat16)
pipeline.update_components(controlnet=controlnet)
pipeline.to("cuda")
```

Now run the pipeline - the canny block preprocesses the image for ControlNet.

```py
from diffusers.utils import load_image

prompt = "cat wizard with red hat, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney"
image = load_image("https://github.com/Trgtuan10/Image_storage/blob/main/cute_cat.png?raw=true")

output = pipeline(
    prompt=prompt,
    image=image,
).images[0]
output
```

## Next steps

<hfoptions id="next">
<hfoption id="Build custom blocks">

Learn how to create your own blocks with custom logic in the [Building Custom Blocks](./custom_blocks) guide.

</hfoption>
<hfoption id="Share components">

Use [`ComponentsManager`](./components_manager) to share models across multiple pipelines and manage memory efficiently.

</hfoption>
<hfoption id="Visual interface">

Connect modular pipelines to [Mellon](https://github.com/cubiq/Mellon), a visual node-based interface for building workflows. Custom blocks built with Modular Diffusers work out of the box with Mellon - no UI code required. Read more in the Mellon guide.

</hfoption>
</hfoptions>