mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 22:15:00 +08:00
Compare commits
8 Commits
enable-cp-
...
modular-do
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
077b6972cf | ||
|
|
7dc454f466 | ||
|
|
b6d05bb69c | ||
|
|
8483c06993 | ||
|
|
fe4e4d7b66 | ||
|
|
56dd6cca15 | ||
|
|
809fc36e03 | ||
|
|
318f2bf918 |
@@ -24,7 +24,7 @@ The Modular Diffusers docs are organized as shown below.
|
||||
|
||||
## Quickstart
|
||||
|
||||
- A [quickstart](./quickstart) demonstrating how to implement an example workflow with Modular Diffusers.
|
||||
- The [quickstart](./quickstart) shows you how to run a modular pipeline, understand its structure, and customize it by modifying the blocks that compose it.
|
||||
|
||||
## ModularPipelineBlocks
|
||||
|
||||
|
||||
@@ -12,333 +12,250 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Quickstart
|
||||
|
||||
Modular Diffusers is a framework for quickly building flexible and customizable pipelines. At the core of Modular Diffusers are [`ModularPipelineBlocks`] that can be combined with other blocks to adapt to new workflows. The blocks are converted into a [`ModularPipeline`], a friendly user-facing interface developers can use.
|
||||
Modular Diffusers is a framework for quickly building flexible and customizable pipelines. At the core of Modular Diffusers are [`ModularPipelineBlocks`] that can be combined with other blocks to adapt to new workflows. The blocks are converted into a [`ModularPipeline`], a friendly user-facing interface for running generation tasks.
|
||||
|
||||
This doc will show you how to implement a [Differential Diffusion](https://differential-diffusion.github.io/) pipeline with the modular framework.
|
||||
This guide shows you how to run a modular pipeline, understand its structure, and customize it by modifying the blocks that compose it.
|
||||
|
||||
## ModularPipelineBlocks
|
||||
## Run a pipeline
|
||||
|
||||
[`ModularPipelineBlocks`] are *definitions* that specify the components, inputs, outputs, and computation logic for a single step in a pipeline. There are four types of blocks.
|
||||
|
||||
- [`ModularPipelineBlocks`] is the most basic block for a single step.
|
||||
- [`SequentialPipelineBlocks`] is a multi-block that composes other blocks linearly. The outputs of one block are the inputs to the next block.
|
||||
- [`LoopSequentialPipelineBlocks`] is a multi-block that runs iteratively and is designed for iterative workflows.
|
||||
- [`AutoPipelineBlocks`] is a collection of blocks for different workflows and it selects which block to run based on the input. It is designed to conveniently package multiple workflows into a single pipeline.
|
||||
|
||||
[Differential Diffusion](https://differential-diffusion.github.io/) is an image-to-image workflow. Start with the `IMAGE2IMAGE_BLOCKS` preset, a collection of `ModularPipelineBlocks` for image-to-image generation.
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS
|
||||
IMAGE2IMAGE_BLOCKS = InsertableDict([
|
||||
("text_encoder", StableDiffusionXLTextEncoderStep),
|
||||
("image_encoder", StableDiffusionXLVaeEncoderStep),
|
||||
("input", StableDiffusionXLInputStep),
|
||||
("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
|
||||
("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
|
||||
("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
|
||||
("denoise", StableDiffusionXLDenoiseStep),
|
||||
("decode", StableDiffusionXLDecodeStep)
|
||||
])
|
||||
```
|
||||
|
||||
## Pipeline and block states
|
||||
|
||||
Modular Diffusers uses *state* to communicate data between blocks. There are two types of states.
|
||||
|
||||
- [`PipelineState`] is a global state that can be used to track all inputs and outputs across all blocks.
|
||||
- [`BlockState`] is a local view of relevant variables from [`PipelineState`] for an individual block.
|
||||
|
||||
## Customizing blocks
|
||||
|
||||
[Differential Diffusion](https://differential-diffusion.github.io/) differs from standard image-to-image in its `prepare_latents` and `denoise` blocks. All the other blocks can be reused, but you'll need to modify these two.
|
||||
|
||||
Create placeholder `ModularPipelineBlocks` for `prepare_latents` and `denoise` by copying and modifying the existing ones.
|
||||
|
||||
Print the `denoise` block to see that it is composed of [`LoopSequentialPipelineBlocks`] with three sub-blocks, `before_denoiser`, `denoiser`, and `after_denoiser`. Only the `before_denoiser` sub-block needs to be modified to prepare the latent input for the denoiser based on the change map.
|
||||
|
||||
```py
|
||||
denoise_blocks = IMAGE2IMAGE_BLOCKS["denoise"]()
|
||||
print(denoise_blocks)
|
||||
```
|
||||
|
||||
Replace the `StableDiffusionXLLoopBeforeDenoiser` sub-block with the new `SDXLDiffDiffLoopBeforeDenoiser` block.
|
||||
|
||||
```py
|
||||
# Copy existing blocks as placeholders
|
||||
class SDXLDiffDiffPrepareLatentsStep(ModularPipelineBlocks):
|
||||
"""Copied from StableDiffusionXLImg2ImgPrepareLatentsStep - will modify later"""
|
||||
# ... same implementation as StableDiffusionXLImg2ImgPrepareLatentsStep
|
||||
|
||||
class SDXLDiffDiffDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
|
||||
block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLLoopDenoiser, StableDiffusionXLLoopAfterDenoiser]
|
||||
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||
```
|
||||
|
||||
### prepare_latents
|
||||
|
||||
The `prepare_latents` block requires the following changes.
|
||||
|
||||
- a processor to process the change map
|
||||
- a new `inputs` to accept the user-provided change map, `timestep` for precomputing all the latents and `num_inference_steps` to create the mask for updating the image regions
|
||||
- update the computation in the `__call__` method for processing the change map and creating the masks, and storing it in the [`BlockState`]
|
||||
|
||||
```diff
|
||||
class SDXLDiffDiffPrepareLatentsStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKL),
|
||||
ComponentSpec("scheduler", EulerDiscreteScheduler),
|
||||
+ ComponentSpec("mask_processor", VaeImageProcessor, config=FrozenDict({"do_normalize": False, "do_convert_grayscale": True}))
|
||||
]
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
+ InputParam("diffdiff_map", required=True),
|
||||
- InputParam("latent_timestep", required=True, type_hint=torch.Tensor),
|
||||
+ InputParam("timesteps", type_hint=torch.Tensor),
|
||||
+ InputParam("num_inference_steps", type_hint=int),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
+ OutputParam("original_latents", type_hint=torch.Tensor),
|
||||
+ OutputParam("diffdiff_masks", type_hint=torch.Tensor),
|
||||
]
|
||||
def __call__(self, components, state: PipelineState):
|
||||
# ... existing logic ...
|
||||
+ # Process change map and create masks
|
||||
+ diffdiff_map = components.mask_processor.preprocess(block_state.diffdiff_map, height=latent_height, width=latent_width)
|
||||
+ thresholds = torch.arange(block_state.num_inference_steps, dtype=diffdiff_map.dtype) / block_state.num_inference_steps
|
||||
+ block_state.diffdiff_masks = diffdiff_map > (thresholds + (block_state.denoising_start or 0))
|
||||
+ block_state.original_latents = block_state.latents
|
||||
```
|
||||
|
||||
### denoise
|
||||
|
||||
The `before_denoiser` sub-block requires the following changes.
|
||||
|
||||
- a new `inputs` to accept a `denoising_start` parameter, `original_latents` and `diffdiff_masks` from the `prepare_latents` block
|
||||
- update the computation in the `__call__` method for applying Differential Diffusion
|
||||
|
||||
```diff
|
||||
class SDXLDiffDiffLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Step within the denoising loop for differential diffusion that prepare the latent input for the denoiser"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("latents", required=True, type_hint=torch.Tensor),
|
||||
+ InputParam("denoising_start"),
|
||||
+ InputParam("original_latents", type_hint=torch.Tensor),
|
||||
+ InputParam("diffdiff_masks", type_hint=torch.Tensor),
|
||||
]
|
||||
|
||||
def __call__(self, components, block_state, i, t):
|
||||
+ # Apply differential diffusion logic
|
||||
+ if i == 0 and block_state.denoising_start is None:
|
||||
+ block_state.latents = block_state.original_latents[:1]
|
||||
+ else:
|
||||
+ block_state.mask = block_state.diffdiff_masks[i].unsqueeze(0).unsqueeze(1)
|
||||
+ block_state.latents = block_state.original_latents[i] * block_state.mask + block_state.latents * (1 - block_state.mask)
|
||||
|
||||
# ... rest of existing logic ...
|
||||
```
|
||||
|
||||
## Assembling the blocks
|
||||
|
||||
You should have all the blocks you need at this point to create a [`ModularPipeline`].
|
||||
|
||||
Copy the existing `IMAGE2IMAGE_BLOCKS` preset and for the `set_timesteps` block, use the `set_timesteps` from the `TEXT2IMAGE_BLOCKS` because Differential Diffusion doesn't require a `strength` parameter.
|
||||
|
||||
Set the `prepare_latents` and `denoise` blocks to the `SDXLDiffDiffPrepareLatentsStep` and `SDXLDiffDiffDenoiseStep` blocks you just modified.
|
||||
|
||||
Call [`SequentialPipelineBlocks.from_blocks_dict`] on the blocks to create a `SequentialPipelineBlocks`.
|
||||
|
||||
```py
|
||||
DIFFDIFF_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
|
||||
DIFFDIFF_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]
|
||||
DIFFDIFF_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
|
||||
DIFFDIFF_BLOCKS["denoise"] = SDXLDiffDiffDenoiseStep
|
||||
|
||||
dd_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_BLOCKS)
|
||||
print(dd_blocks)
|
||||
```
|
||||
|
||||
## ModularPipeline
|
||||
|
||||
Convert the [`SequentialPipelineBlocks`] into a [`ModularPipeline`] with the [`ModularPipeline.init_pipeline`] method. This initializes the expected components to load from a `modular_model_index.json` file. Explicitly load the components by calling [`ModularPipeline.load_components`].
|
||||
|
||||
It is a good idea to initialize the [`ComponentManager`] with the pipeline to help manage the different components. Once you call [`~ModularPipeline.load_components`], the components are registered to the [`ComponentManager`] and can be shared between workflows. The example below uses the `collection` argument to assign the components a `"diffdiff"` label for better organization.
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines import ComponentsManager
|
||||
|
||||
components = ComponentManager()
|
||||
|
||||
dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", components_manager=components, collection="diffdiff")
|
||||
dd_pipeline.load_componenets(torch_dtype=torch.float16)
|
||||
dd_pipeline.to("cuda")
|
||||
```
|
||||
|
||||
## Adding workflows
|
||||
|
||||
Other workflows can be added to the [`ModularPipeline`] to support additional features without rewriting the entire pipeline from scratch.
|
||||
|
||||
This section demonstrates how to add an IP-Adapter or ControlNet.
|
||||
|
||||
### IP-Adapter
|
||||
|
||||
Stable Diffusion XL already has a preset IP-Adapter block that you can use and doesn't require any changes to the existing Differential Diffusion pipeline.
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLAutoIPAdapterStep
|
||||
|
||||
ip_adapter_block = StableDiffusionXLAutoIPAdapterStep()
|
||||
```
|
||||
|
||||
Use the [`sub_blocks.insert`] method to insert it into the [`ModularPipeline`]. The example below inserts the `ip_adapter_block` at position `0`. Print the pipeline to see that the `ip_adapter_block` is added and it requires an `ip_adapter_image`. This also added two components to the pipeline, the `image_encoder` and `feature_extractor`.
|
||||
|
||||
```py
|
||||
dd_blocks.sub_blocks.insert("ip_adapter", ip_adapter_block, 0)
|
||||
```
|
||||
|
||||
Call [`~ModularPipeline.init_pipeline`] to initialize a [`ModularPipeline`] and use [`~ModularPipeline.load_components`] to load the model components. Load and set the IP-Adapter to run the pipeline.
|
||||
|
||||
```py
|
||||
dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
|
||||
dd_pipeline.load_components(torch_dtype=torch.float16)
|
||||
dd_pipeline.loader.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
|
||||
dd_pipeline.loader.set_ip_adapter_scale(0.6)
|
||||
dd_pipeline = dd_pipeline.to(device)
|
||||
|
||||
ip_adapter_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_orange.jpeg")
|
||||
image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
|
||||
mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
|
||||
|
||||
prompt = "a green pear"
|
||||
negative_prompt = "blurry"
|
||||
generator = torch.Generator(device=device).manual_seed(42)
|
||||
|
||||
image = dd_pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=25,
|
||||
generator=generator,
|
||||
ip_adapter_image=ip_adapter_image,
|
||||
diffdiff_map=mask,
|
||||
image=image,
|
||||
output="images"
|
||||
)[0]
|
||||
```
|
||||
|
||||
### ControlNet
|
||||
|
||||
Stable Diffusion XL already has a preset ControlNet block that can readily be used.
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines.stable_diffusion_xl.modular_blocks import StableDiffusionXLAutoControlNetInputStep
|
||||
|
||||
control_input_block = StableDiffusionXLAutoControlNetInputStep()
|
||||
```
|
||||
|
||||
However, it requires modifying the `denoise` block because that's where the ControlNet injects the control information into the UNet.
|
||||
|
||||
Modify the `denoise` block by replacing the `StableDiffusionXLLoopDenoiser` sub-block with the `StableDiffusionXLControlNetLoopDenoiser`.
|
||||
|
||||
```py
|
||||
class SDXLDiffDiffControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
|
||||
block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLControlNetLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser]
|
||||
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||
|
||||
controlnet_denoise_block = SDXLDiffDiffControlNetDenoiseStep()
|
||||
```
|
||||
|
||||
Insert the `controlnet_input` block and replace the `denoise` block with the new `controlnet_denoise_block`. Initialize a [`ModularPipeline`] and [`~ModularPipeline.load_components`] into it.
|
||||
|
||||
```py
|
||||
dd_blocks.sub_blocks.insert("controlnet_input", control_input_block, 7)
|
||||
dd_blocks.sub_blocks["denoise"] = controlnet_denoise_block
|
||||
|
||||
dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
|
||||
dd_pipeline.load_components(torch_dtype=torch.float16)
|
||||
dd_pipeline = dd_pipeline.to(device)
|
||||
|
||||
control_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_tomato_canny.jpeg")
|
||||
image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
|
||||
mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
|
||||
|
||||
prompt = "a green pear"
|
||||
negative_prompt = "blurry"
|
||||
generator = torch.Generator(device=device).manual_seed(42)
|
||||
|
||||
image = dd_pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=25,
|
||||
generator=generator,
|
||||
control_image=control_image,
|
||||
controlnet_conditioning_scale=0.5,
|
||||
diffdiff_map=mask,
|
||||
image=image,
|
||||
output="images"
|
||||
)[0]
|
||||
```
|
||||
|
||||
### AutoPipelineBlocks
|
||||
|
||||
The Differential Diffusion, IP-Adapter, and ControlNet workflows can be bundled into a single [`ModularPipeline`] by using [`AutoPipelineBlocks`]. This allows automatically selecting which sub-blocks to run based on the inputs like `control_image` or `ip_adapter_image`. If none of these inputs are passed, then it defaults to the Differential Diffusion.
|
||||
|
||||
Use `block_trigger_inputs` to only run the `SDXLDiffDiffControlNetDenoiseStep` block if a `control_image` input is provided. Otherwise, the `SDXLDiffDiffDenoiseStep` is used.
|
||||
|
||||
```py
|
||||
class SDXLDiffDiffAutoDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [SDXLDiffDiffControlNetDenoiseStep, SDXLDiffDiffDenoiseStep]
|
||||
block_names = ["controlnet_denoise", "denoise"]
|
||||
block_trigger_inputs = ["controlnet_cond", None]
|
||||
```
|
||||
|
||||
Add the `ip_adapter` and `controlnet_input` blocks.
|
||||
|
||||
```py
|
||||
DIFFDIFF_AUTO_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
|
||||
DIFFDIFF_AUTO_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
|
||||
DIFFDIFF_AUTO_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]
|
||||
DIFFDIFF_AUTO_BLOCKS["denoise"] = SDXLDiffDiffAutoDenoiseStep
|
||||
DIFFDIFF_AUTO_BLOCKS.insert("ip_adapter", StableDiffusionXLAutoIPAdapterStep, 0)
|
||||
DIFFDIFF_AUTO_BLOCKS.insert("controlnet_input",StableDiffusionXLControlNetAutoInput, 7)
|
||||
```
|
||||
|
||||
Call [`SequentialPipelineBlocks.from_blocks_dict`] to create a [`SequentialPipelineBlocks`] and create a [`ModularPipeline`] and load in the model components to run.
|
||||
|
||||
```py
|
||||
dd_auto_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_AUTO_BLOCKS)
|
||||
dd_pipeline = dd_auto_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
|
||||
dd_pipeline.load_components(torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
## Share
|
||||
|
||||
Add your [`ModularPipeline`] to the Hub with [`~ModularPipeline.save_pretrained`] and set `push_to_hub` argument to `True`.
|
||||
|
||||
```py
|
||||
dd_pipeline.save_pretrained("YiYiXu/test_modular_doc", push_to_hub=True)
|
||||
```
|
||||
|
||||
Other users can load the [`ModularPipeline`] with [`~ModularPipeline.from_pretrained`].
|
||||
[`ModularPipeline`] is the main interface for loading, running, and managing modular pipelines.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers.modular_pipelines import ModularPipeline, ComponentsManager
|
||||
from diffusers import ModularPipeline
|
||||
|
||||
components = ComponentsManager()
|
||||
pipe = ModularPipeline.from_pretrained("Qwen/Qwen-Image")
|
||||
pipe.load_components(torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
diffdiff_pipeline = ModularPipeline.from_pretrained("YiYiXu/modular-diffdiff-0704", trust_remote_code=True, components_manager=components, collection="diffdiff")
|
||||
diffdiff_pipeline.load_components(torch_dtype=torch.float16)
|
||||
image = pipe(
|
||||
prompt="cat wizard with red hat, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney",
|
||||
).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
[`~ModularPipeline.from_pretrained`] uses lazy loading - it reads the configuration to learn where to load each component from, but doesn't actually load the model weights until you call [`~ModularPipeline.load_components`]. This gives you control over when and how components are loaded.
|
||||
|
||||
Learn more about creating and loading pipelines in the [Creating a pipeline](https://huggingface.co/docs/diffusers/modular_diffusers/modular_pipeline#creating-a-pipeline) and [Loading components](https://huggingface.co/docs/diffusers/modular_diffusers/modular_pipeline#loading-components) guides.
|
||||
|
||||
## Understand the structure
|
||||
|
||||
A [`ModularPipeline`] has two parts:
|
||||
- **State**: the loaded components (models, schedulers, processors) and configuration
|
||||
- **Definition**: the [`ModularPipelineBlocks`] that specify inputs, outputs, expected components and computation logic
|
||||
|
||||
The blocks define *what* the pipeline does. Access them through `pipe.blocks`.
|
||||
```py
|
||||
print(pipe.blocks)
|
||||
```
|
||||
|
||||
```
|
||||
QwenImageAutoBlocks(
|
||||
Class: SequentialPipelineBlocks
|
||||
|
||||
Description: Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.
|
||||
|
||||
Supported workflows:
|
||||
- `text2image`: requires `prompt`
|
||||
- `image2image`: requires `prompt`, `image`
|
||||
- `inpainting`: requires `prompt`, `mask_image`, `image`
|
||||
- `controlnet_text2image`: requires `prompt`, `control_image`
|
||||
...
|
||||
|
||||
Components:
|
||||
text_encoder (`Qwen2_5_VLForConditionalGeneration`)
|
||||
vae (`AutoencoderKLQwenImage`)
|
||||
transformer (`QwenImageTransformer2DModel`)
|
||||
...
|
||||
|
||||
Sub-Blocks:
|
||||
[0] text_encoder (QwenImageAutoTextEncoderStep)
|
||||
[1] vae_encoder (QwenImageAutoVaeEncoderStep)
|
||||
[2] controlnet_vae_encoder (QwenImageOptionalControlNetVaeEncoderStep)
|
||||
[3] denoise (QwenImageAutoCoreDenoiseStep)
|
||||
[4] decode (QwenImageAutoDecodeStep)
|
||||
)
|
||||
```
|
||||
|
||||
The output returns:
|
||||
- The supported workflows (text2image, image2image, inpainting, etc.)
|
||||
- The Sub-Blocks it's composed of (text_encoder, vae_encoder, denoise, decode)
|
||||
|
||||
### Workflows
|
||||
|
||||
`QwenImageAutoBlocks` is a [`ConditionalPipelineBlocks`], so this pipeline supports multiple workflows and adapts its behavior based on the inputs you provide. For example, if you pass `image` to the pipeline, it runs an image-to-image workflow instead of text-to-image.
|
||||
```py
|
||||
from diffusers.utils import load_image
|
||||
|
||||
input_image = load_image("https://github.com/Trgtuan10/Image_storage/blob/main/cute_cat.png?raw=true")
|
||||
|
||||
image = pipe(
|
||||
prompt="cat wizard with red hat, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney",
|
||||
image=input_image,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
Use `get_workflow()` to extract the blocks for a specific workflow.
|
||||
```py
|
||||
img2img_blocks = pipe.blocks.get_workflow("image2image")
|
||||
```
|
||||
|
||||
Conditional blocks are convenient for users, but their conditional logic adds complexity when customizing or debugging. Extracting a workflow gives you the specific blocks relevant to your workflow, making it easier to work with. Learn more in the [AutoPipelineBlocks](https://huggingface.co/docs/diffusers/modular_diffusers/auto_pipeline_blocks) guide.
|
||||
|
||||
### Sub-blocks
|
||||
|
||||
`QwenImageAutoBlocks` is itself composed of smaller blocks: `text_encoder`, `vae_encoder`, `controlnet_vae_encoder`, `denoise`, and `decode`. Access them through the `sub_blocks` property.
|
||||
|
||||
The `doc` property is useful for seeing the full documentation of any block, including its inputs, outputs, and components.
|
||||
|
||||
```py
|
||||
vae_encoder_block = pipe.blocks.sub_blocks["vae_encoder"]
|
||||
print(vae_encoder_block.doc)
|
||||
```
|
||||
|
||||
This block can be converted to a pipeline and run on its own with [`~ModularPipelineBlocks.init_pipeline`].
|
||||
```py
|
||||
vae_encoder_pipe = vae_encoder_block.init_pipeline()
|
||||
|
||||
# Reuse the VAE we already loaded, we can reuse it with update_components() method
|
||||
vae_encoder_pipe.update_components(vae=pipe.vae)
|
||||
|
||||
# Run just this block
|
||||
image_latents = vae_encoder_pipe(image=input_image).image_latents
|
||||
print(image_latents.shape)
|
||||
```
|
||||
|
||||
It reuses the VAE from our original pipeline instead of reloading it, keeping memory usage efficient. Learn more in the [Loading components](https://huggingface.co/docs/diffusers/modular_diffusers/modular_pipeline#loading-components) guide.
|
||||
|
||||
Since blocks are composable, you can modify the pipeline's definition by adding, removing, or swapping blocks to create new workflows. In the next section, we'll add a canny edge detection block to a ControlNet pipeline, so you can pass a regular image instead of a pre-processed canny edge map.
|
||||
|
||||
## Compose new workflows
|
||||
|
||||
Let's add a canny edge detection block to a ControlNet pipeline. First, load a pre-built canny block from the Hub (see [Building Custom Blocks](https://huggingface.co/docs/diffusers/modular_diffusers/custom_blocks) to create your own).
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines import ModularPipelineBlocks
|
||||
|
||||
# Load a canny block from the Hub
|
||||
canny_block = ModularPipelineBlocks.from_pretrained(
|
||||
"diffusers-internal-dev/canny-filtering",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
print(canny_block.doc)
|
||||
```
|
||||
```
|
||||
class CannyBlock
|
||||
|
||||
Inputs:
|
||||
image (`Union[Image, ndarray]`):
|
||||
Image to compute canny filter on
|
||||
low_threshold (`int`, *optional*, defaults to 50):
|
||||
Low threshold for the canny filter.
|
||||
high_threshold (`int`, *optional*, defaults to 200):
|
||||
High threshold for the canny filter.
|
||||
...
|
||||
|
||||
Outputs:
|
||||
control_image (`PIL.Image`):
|
||||
Canny map for input image
|
||||
```
|
||||
|
||||
Use `get_workflow` to extract the ControlNet workflow from [`QwenImageAutoBlocks`].
|
||||
|
||||
```py
|
||||
# Get the controlnet workflow that we want to work with
|
||||
blocks = pipe.blocks.get_workflow("controlnet_text2image")
|
||||
print(blocks.doc)
|
||||
```
|
||||
```
|
||||
class SequentialPipelineBlocks
|
||||
|
||||
Inputs:
|
||||
prompt (`str`):
|
||||
The prompt or prompts to guide image generation.
|
||||
control_image (`Image`):
|
||||
Control image for ControlNet conditioning.
|
||||
...
|
||||
```
|
||||
It requires control_image as input. After inserting the canny block, the pipeline will accept a regular image instead.
|
||||
|
||||
```py
|
||||
# and insert canny at the beginning
|
||||
blocks.sub_blocks.insert("canny", canny_block, 0)
|
||||
|
||||
# Check the updated structure: CannyBlock is now listed as first sub-block
|
||||
print(blocks)
|
||||
# Check the updated doc: notice the pipeline now takes "image" as input
|
||||
# even though it's a controlnet pipeline, because canny preprocesses it into control_image
|
||||
print(blocks.doc)
|
||||
```
|
||||
```
|
||||
class SequentialPipelineBlocks
|
||||
|
||||
Inputs:
|
||||
image (`Union[Image, ndarray]`):
|
||||
Image to compute canny filter on
|
||||
low_threshold (`int`, *optional*, defaults to 50):
|
||||
Low threshold for the canny filter.
|
||||
high_threshold (`int`, *optional*, defaults to 200):
|
||||
High threshold for the canny filter.
|
||||
prompt (`str`):
|
||||
The prompt or prompts to guide image generation.
|
||||
...
|
||||
```
|
||||
|
||||
Now the pipeline takes `image` as input - the canny block will preprocess it into `control_image` automatically.
|
||||
|
||||
Create a pipeline from the modified blocks and load a ControlNet model. We use [`ComponentsManager`] to enable CPU offloading for reduced memory usage (learn more in the [ComponentsManager](./components_manager) guide).
|
||||
```py
|
||||
from diffusers import ComponentsManager
|
||||
|
||||
manager = ComponentsManager()
|
||||
manager.enable_auto_cpu_offload(device="cuda:0")
|
||||
|
||||
pipeline = blocks.init_pipeline("Qwen/Qwen-Image", components_manager=manager)
|
||||
|
||||
pipeline.load_components(torch_dtype=torch.bfloat16)
|
||||
|
||||
# Load the ControlNet model
|
||||
controlnet_spec = pipeline.get_component_spec("controlnet")
|
||||
controlnet_spec.pretrained_model_name_or_path = "InstantX/Qwen-Image-ControlNet-Union"
|
||||
controlnet = controlnet_spec.load(torch_dtype=torch.bfloat16)
|
||||
pipeline.update_components(controlnet=controlnet)
|
||||
```
|
||||
|
||||
Now run the pipeline - the canny block preprocesses the image for ControlNet.
|
||||
|
||||
```py
|
||||
from diffusers.utils import load_image
|
||||
|
||||
prompt = "cat wizard with red hat, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney"
|
||||
image = load_image("https://github.com/Trgtuan10/Image_storage/blob/main/cute_cat.png?raw=true")
|
||||
|
||||
output = pipeline(
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
).images[0]
|
||||
output
|
||||
```
|
||||
|
||||
## Next steps
|
||||
|
||||
<hfoptions id="next">
|
||||
<hfoption id="Build custom blocks">
|
||||
|
||||
Learn how to create your own blocks with custom logic in the [Building Custom Blocks](./custom_blocks) guide.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Share components">
|
||||
|
||||
Use [`ComponentsManager`](./components_manager) to share models across multiple pipelines and manage memory efficiently.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Visual interface">
|
||||
|
||||
Connect modular pipelines to [Mellon](https://github.com/cubiq/Mellon), a visual node-based interface for building workflows. Custom blocks built with Modular Diffusers work out of the box with Mellon - no UI code required. Read more in Mellon guide.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Reference in New Issue
Block a user