Compare commits

..

30 Commits

Author SHA1 Message Date
YiYi Xu
b73cc50e48 Merge branch 'main' into modular-workflow 2026-01-31 09:51:11 -10:00
yiyixuxu
20c35da75c up up 2026-01-25 12:11:37 +01:00
yiyixuxu
6a549f5f55 initial support: workflow 2026-01-25 11:40:52 +01:00
Sayak Paul
412e51c856 include auto-docstring check in the modular ci. (#13004) 2026-01-23 22:34:24 -10:00
github-actions[bot]
23d06423ab Apply style fixes 2026-01-19 09:23:31 +00:00
YiYi Xu
aba551c868 Merge branch 'main' into modular-doc-improv 2026-01-18 23:20:36 -10:00
yiyixuxu
1f9576a2ca fix 2026-01-19 09:56:14 +01:00
yiyixuxu
d75fbc43c7 Merge branch 'modular-doc-improv' of github.com:huggingface/diffusers into modular-doc-improv 2026-01-19 09:54:46 +01:00
yiyixuxu
b7127ce7a7 revert change in z 2026-01-19 09:54:40 +01:00
YiYi Xu
7e9d2b954e Apply suggestions from code review 2026-01-18 22:44:44 -10:00
yiyixuxu
94525200fd rmove space in make docstring 2026-01-19 09:35:39 +01:00
yiyixuxu
f056af1fbb make style 2026-01-19 09:27:40 +01:00
yiyixuxu
8d45ff5bf6 apply auto docstring 2026-01-19 09:22:04 +01:00
yiyixuxu
fb15752d55 up up up 2026-01-19 08:10:31 +01:00
yiyixuxu
1f2dbc9dd2 up 2026-01-19 04:10:17 +01:00
yiyixuxu
002c3e8239 add template method 2026-01-19 03:24:34 +01:00
yiyixuxu
de03d7f100 refactor based on dhruv's feedback: remove the class method 2026-01-18 00:35:01 +01:00
yiyixuxu
25c968a38f add TODO in the description for empty docstring 2026-01-17 09:57:56 +01:00
yiyixuxu
aea0d046f6 address feedbacks 2026-01-17 09:36:58 +01:00
yiyixuxu
1c90ce33f2 up 2026-01-10 12:21:26 +01:00
yiyixuxu
507953f415 more more 2026-01-10 12:19:14 +01:00
yiyixuxu
f0555af1c6 up up up 2026-01-10 12:15:53 +01:00
yiyixuxu
2a81f2ec54 style 2026-01-10 12:15:36 +01:00
yiyixuxu
d20f413f78 more auto docstring 2026-01-10 12:11:28 +01:00
yiyixuxu
ff09bf1a63 add modular_auto_docstring! 2026-01-10 11:55:03 +01:00
yiyixuxu
34a743e2dc style 2026-01-10 10:57:27 +01:00
yiyixuxu
43ab14845d update outputs 2026-01-10 10:56:54 +01:00
YiYi Xu
fbfe5c8d6b Merge branch 'main' into modular-doc-improv 2026-01-09 23:54:23 -10:00
yiyixuxu
b29873dee7 up up 2026-01-10 10:52:53 +01:00
yiyixuxu
7b499de6d0 up 2026-01-10 03:35:15 +01:00
20 changed files with 1028 additions and 1027 deletions

View File

@@ -415,7 +415,6 @@ else:
"Flux2AutoBlocks",
"Flux2KleinAutoBlocks",
"Flux2KleinBaseAutoBlocks",
"Flux2KleinBaseModularPipeline",
"Flux2KleinModularPipeline",
"Flux2ModularPipeline",
"FluxAutoBlocks",
@@ -432,13 +431,8 @@ else:
"QwenImageModularPipeline",
"StableDiffusionXLAutoBlocks",
"StableDiffusionXLModularPipeline",
"Wan22Blocks",
"Wan22Image2VideoBlocks",
"Wan22Image2VideoModularPipeline",
"Wan22ModularPipeline",
"WanBlocks",
"WanImage2VideoAutoBlocks",
"WanImage2VideoModularPipeline",
"Wan22AutoBlocks",
"WanAutoBlocks",
"WanModularPipeline",
"ZImageAutoBlocks",
"ZImageModularPipeline",
@@ -1157,7 +1151,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Flux2AutoBlocks,
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
Flux2KleinBaseModularPipeline,
Flux2KleinModularPipeline,
Flux2ModularPipeline,
FluxAutoBlocks,
@@ -1174,13 +1167,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
QwenImageModularPipeline,
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
Wan22Blocks,
Wan22Image2VideoBlocks,
Wan22Image2VideoModularPipeline,
Wan22ModularPipeline,
WanBlocks,
WanImage2VideoAutoBlocks,
WanImage2VideoModularPipeline,
Wan22AutoBlocks,
WanAutoBlocks,
WanModularPipeline,
ZImageAutoBlocks,
ZImageModularPipeline,

View File

@@ -45,16 +45,7 @@ else:
"InsertableDict",
]
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
_import_structure["wan"] = [
"WanBlocks",
"Wan22Blocks",
"WanImage2VideoAutoBlocks",
"Wan22Image2VideoBlocks",
"WanModularPipeline",
"Wan22ModularPipeline",
"WanImage2VideoModularPipeline",
"Wan22Image2VideoModularPipeline",
]
_import_structure["wan"] = ["WanAutoBlocks", "Wan22AutoBlocks", "WanModularPipeline"]
_import_structure["flux"] = [
"FluxAutoBlocks",
"FluxModularPipeline",
@@ -67,7 +58,6 @@ else:
"Flux2KleinBaseAutoBlocks",
"Flux2ModularPipeline",
"Flux2KleinModularPipeline",
"Flux2KleinBaseModularPipeline",
]
_import_structure["qwenimage"] = [
"QwenImageAutoBlocks",
@@ -98,7 +88,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Flux2AutoBlocks,
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
Flux2KleinBaseModularPipeline,
Flux2KleinModularPipeline,
Flux2ModularPipeline,
)
@@ -123,16 +112,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
QwenImageModularPipeline,
)
from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
from .wan import (
Wan22Blocks,
Wan22Image2VideoBlocks,
Wan22Image2VideoModularPipeline,
Wan22ModularPipeline,
WanBlocks,
WanImage2VideoAutoBlocks,
WanImage2VideoModularPipeline,
WanModularPipeline,
)
from .wan import Wan22AutoBlocks, WanAutoBlocks, WanModularPipeline
from .z_image import ZImageAutoBlocks, ZImageModularPipeline
else:
import sys

View File

@@ -55,11 +55,7 @@ else:
"Flux2VaeEncoderSequentialStep",
]
_import_structure["modular_blocks_flux2_klein"] = ["Flux2KleinAutoBlocks", "Flux2KleinBaseAutoBlocks"]
_import_structure["modular_pipeline"] = [
"Flux2ModularPipeline",
"Flux2KleinModularPipeline",
"Flux2KleinBaseModularPipeline",
]
_import_structure["modular_pipeline"] = ["Flux2ModularPipeline", "Flux2KleinModularPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -105,7 +101,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
)
from .modular_pipeline import Flux2KleinBaseModularPipeline, Flux2KleinModularPipeline, Flux2ModularPipeline
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
else:
import sys

View File

@@ -13,6 +13,8 @@
# limitations under the License.
from typing import Any, Dict, Optional
from ...loaders import Flux2LoraLoaderMixin
from ...utils import logging
from ..modular_pipeline import ModularPipeline
@@ -57,36 +59,47 @@ class Flux2ModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
return num_channels_latents
class Flux2KleinModularPipeline(Flux2ModularPipeline):
class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
"""
A ModularPipeline for Flux2-Klein (distilled model).
> [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "Flux2KleinAutoBlocks"
@property
def requires_unconditional_embeds(self):
if hasattr(self.config, "is_distilled") and self.config.is_distilled:
return False
requires_unconditional_embeds = False
if hasattr(self, "guider") and self.guider is not None:
requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
return requires_unconditional_embeds
class Flux2KleinBaseModularPipeline(Flux2ModularPipeline):
"""
A ModularPipeline for Flux2-Klein (base model).
A ModularPipeline for Flux2-Klein.
> [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "Flux2KleinBaseAutoBlocks"
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
if config_dict is not None and "is_distilled" in config_dict and config_dict["is_distilled"]:
return "Flux2KleinAutoBlocks"
else:
return "Flux2KleinBaseAutoBlocks"
@property
def default_height(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_width(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_sample_size(self):
return 128
@property
def vae_scale_factor(self):
vae_scale_factor = 8
if getattr(self, "vae", None) is not None:
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
return vae_scale_factor
@property
def num_channels_latents(self):
num_channels_latents = 32
if getattr(self, "transformer", None):
num_channels_latents = self.transformer.config.in_channels // 4
return num_channels_latents
@property
def requires_unconditional_embeds(self):
if hasattr(self.config, "is_distilled") and self.config.is_distilled:

View File

@@ -39,8 +39,11 @@ from .modular_pipeline_utils import (
InputParam,
InsertableDict,
OutputParam,
combine_inputs,
combine_outputs,
format_components,
format_configs,
format_workflow,
make_doc_string,
)
@@ -52,61 +55,19 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# map regular pipeline to modular pipeline class name
def _create_default_map_fn(pipeline_class_name: str):
"""Create a mapping function that always returns the same pipeline class."""
def _map_fn(config_dict=None):
return pipeline_class_name
return _map_fn
def _flux2_klein_map_fn(config_dict=None):
if config_dict is None:
return "Flux2KleinModularPipeline"
if "is_distilled" in config_dict and config_dict["is_distilled"]:
return "Flux2KleinModularPipeline"
else:
return "Flux2KleinBaseModularPipeline"
def _wan_map_fn(config_dict=None):
if config_dict is None:
return "WanModularPipeline"
if "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None:
return "Wan22ModularPipeline"
else:
return "WanModularPipeline"
def _wan_i2v_map_fn(config_dict=None):
if config_dict is None:
return "WanImage2VideoModularPipeline"
if "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None:
return "Wan22Image2VideoModularPipeline"
else:
return "WanImage2VideoModularPipeline"
MODULAR_PIPELINE_MAPPING = OrderedDict(
[
("stable-diffusion-xl", _create_default_map_fn("StableDiffusionXLModularPipeline")),
("wan", _wan_map_fn),
("wan-i2v", _wan_i2v_map_fn),
("flux", _create_default_map_fn("FluxModularPipeline")),
("flux-kontext", _create_default_map_fn("FluxKontextModularPipeline")),
("flux2", _create_default_map_fn("Flux2ModularPipeline")),
("flux2-klein", _flux2_klein_map_fn),
("qwenimage", _create_default_map_fn("QwenImageModularPipeline")),
("qwenimage-edit", _create_default_map_fn("QwenImageEditModularPipeline")),
("qwenimage-edit-plus", _create_default_map_fn("QwenImageEditPlusModularPipeline")),
("qwenimage-layered", _create_default_map_fn("QwenImageLayeredModularPipeline")),
("z-image", _create_default_map_fn("ZImageModularPipeline")),
("stable-diffusion-xl", "StableDiffusionXLModularPipeline"),
("wan", "WanModularPipeline"),
("flux", "FluxModularPipeline"),
("flux-kontext", "FluxKontextModularPipeline"),
("flux2", "Flux2ModularPipeline"),
("flux2-klein", "Flux2KleinModularPipeline"),
("qwenimage", "QwenImageModularPipeline"),
("qwenimage-edit", "QwenImageEditModularPipeline"),
("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),
("qwenimage-layered", "QwenImageLayeredModularPipeline"),
("z-image", "ZImageModularPipeline"),
]
)
@@ -285,6 +246,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
config_name = "modular_config.json"
model_name = None
_workflow_map = None
@classmethod
def _get_signature_keys(cls, obj):
@@ -340,6 +302,35 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
def outputs(self) -> List[OutputParam]:
return self._get_outputs()
# currently only ConditionalPipelineBlocks and SequentialPipelineBlocks support `get_execution_blocks`
def get_execution_blocks(self, **kwargs):
"""
Get the block(s) that would execute given the inputs. Must be implemented by subclasses that support
conditional block selection.
Args:
**kwargs: Input names and values. Only trigger inputs affect block selection.
"""
raise NotImplementedError(f"`get_execution_blocks` is not implemented for {self.__class__.__name__}")
# currently only SequentialPipelineBlocks support workflows
@property
def workflow_names(self):
"""
Returns a list of available workflow names. Must be implemented by subclasses that define `_workflow_map`.
"""
raise NotImplementedError(f"`workflow_names` is not implemented for {self.__class__.__name__}")
def get_workflow(self, workflow_name: str):
"""
Get the execution blocks for a specific workflow. Must be implemented by subclasses that define
`_workflow_map`.
Args:
workflow_name: Name of the workflow to retrieve.
"""
raise NotImplementedError(f"`get_workflow` is not implemented for {self.__class__.__name__}")
@classmethod
def from_pretrained(
cls,
@@ -408,8 +399,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
"""
create a ModularPipeline, optionally accept pretrained_model_name_or_path to load from hub.
"""
map_fn = MODULAR_PIPELINE_MAPPING.get(self.model_name, _create_default_map_fn("ModularPipeline"))
pipeline_class_name = map_fn()
pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(self.model_name, ModularPipeline.__name__)
diffusers_module = importlib.import_module("diffusers")
pipeline_class = getattr(diffusers_module, pipeline_class_name)
@@ -478,72 +468,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
if current_value is not param: # Using identity comparison to check if object was modified
state.set(param_name, param, input_param.kwargs_type)
@staticmethod
def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
"""
Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if
current default value is None and new default value is not None. Warns if multiple non-None default values
exist for the same input.
Args:
named_input_lists: List of tuples containing (block_name, input_param_list) pairs
Returns:
List[InputParam]: Combined list of unique InputParam objects
"""
combined_dict = {} # name -> InputParam
value_sources = {} # name -> block_name
for block_name, inputs in named_input_lists:
for input_param in inputs:
if input_param.name is None and input_param.kwargs_type is not None:
input_name = "*_" + input_param.kwargs_type
else:
input_name = input_param.name
if input_name in combined_dict:
current_param = combined_dict[input_name]
if (
current_param.default is not None
and input_param.default is not None
and current_param.default != input_param.default
):
warnings.warn(
f"Multiple different default values found for input '{input_name}': "
f"{current_param.default} (from block '{value_sources[input_name]}') and "
f"{input_param.default} (from block '{block_name}'). Using {current_param.default}."
)
if current_param.default is None and input_param.default is not None:
combined_dict[input_name] = input_param
value_sources[input_name] = block_name
else:
combined_dict[input_name] = input_param
value_sources[input_name] = block_name
return list(combined_dict.values())
@staticmethod
def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]:
"""
Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, keeps the first
occurrence of each output name.
Args:
named_output_lists: List of tuples containing (block_name, output_param_list) pairs
Returns:
List[OutputParam]: Combined list of unique OutputParam objects
"""
combined_dict = {} # name -> OutputParam
for block_name, outputs in named_output_lists:
for output_param in outputs:
if (output_param.name not in combined_dict) or (
combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None
):
combined_dict[output_param.name] = output_param
return list(combined_dict.values())
@property
def input_names(self) -> List[str]:
return [input_param.name for input_param in self.inputs if input_param.name is not None]
@@ -575,7 +499,8 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
class ConditionalPipelineBlocks(ModularPipelineBlocks):
"""
A Pipeline Blocks that conditionally selects a block to run based on the inputs. Subclasses must implement the
`select_block` method to define the logic for selecting the block.
`select_block` method to define the logic for selecting the block. Currently, we only support selection logic based
on the presence or absence of inputs (i.e., whether they are `None` or not)
This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
library implements for all the pipeline blocks (such as loading or saving etc.)
@@ -583,15 +508,20 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
> [!WARNING] > This is an experimental feature and is likely to change in the future.
Attributes:
block_classes: List of block classes to be used
block_names: List of prefixes for each block
block_trigger_inputs: List of input names that select_block() uses to determine which block to run
block_classes: List of block classes to be used. Must have the same length as `block_names`.
block_names: List of names for each block. Must have the same length as `block_classes`.
block_trigger_inputs: List of input names that `select_block()` uses to determine which block to run.
For `ConditionalPipelineBlocks`, this does not need to correspond to `block_names` and `block_classes`. For
`AutoPipelineBlocks`, this must have the same length as `block_names` and `block_classes`, where each
element specifies the trigger input for the corresponding block.
default_block_name: Name of the default block to run when no trigger inputs match.
If None, this block can be skipped entirely when no trigger inputs are provided.
"""
block_classes = []
block_names = []
block_trigger_inputs = []
default_block_name = None # name of the default block if no trigger inputs are provided, if None, this block can be skipped if no trigger inputs are provided
default_block_name = None
def __init__(self):
sub_blocks = InsertableDict()
@@ -655,7 +585,7 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
@property
def inputs(self) -> List[Tuple[str, Any]]:
named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()]
combined_inputs = self.combine_inputs(*named_inputs)
combined_inputs = combine_inputs(*named_inputs)
# mark Required inputs only if that input is required by all the blocks
for input_param in combined_inputs:
if input_param.name in self.required_inputs:
@@ -667,15 +597,16 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
@property
def intermediate_outputs(self) -> List[str]:
named_outputs = [(name, block.intermediate_outputs) for name, block in self.sub_blocks.items()]
combined_outputs = self.combine_outputs(*named_outputs)
combined_outputs = combine_outputs(*named_outputs)
return combined_outputs
@property
def outputs(self) -> List[str]:
named_outputs = [(name, block.outputs) for name, block in self.sub_blocks.items()]
combined_outputs = self.combine_outputs(*named_outputs)
combined_outputs = combine_outputs(*named_outputs)
return combined_outputs
# used for `__repr__`
def _get_trigger_inputs(self) -> set:
"""
Returns a set of all unique trigger input values found in this block and nested blocks.
@@ -704,11 +635,6 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
return all_triggers
@property
def trigger_inputs(self):
"""All trigger inputs including from nested blocks."""
return self._get_trigger_inputs()
def select_block(self, **kwargs) -> Optional[str]:
"""
Select the block to run based on the trigger inputs. Subclasses must implement this method to define the logic
@@ -748,6 +674,39 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
logger.error(error_msg)
raise
def get_execution_blocks(self, **kwargs) -> Optional["ModularPipelineBlocks"]:
"""
Get the block(s) that would execute given the inputs.
Recursively resolves nested ConditionalPipelineBlocks until reaching either:
- A leaf block (no sub_blocks) → returns single `ModularPipelineBlocks`
- A `SequentialPipelineBlocks` → delegates to its `get_execution_blocks()` which returns
a `SequentialPipelineBlocks` containing the resolved execution blocks
Args:
**kwargs: Input names and values. Only trigger inputs affect block selection.
Returns:
- `ModularPipelineBlocks`: A leaf block or resolved `SequentialPipelineBlocks`
- `None`: If this block would be skipped (no trigger matched and no default)
"""
trigger_kwargs = {name: kwargs.get(name) for name in self.block_trigger_inputs if name is not None}
block_name = self.select_block(**trigger_kwargs)
if block_name is None:
block_name = self.default_block_name
if block_name is None:
return None
block = self.sub_blocks[block_name]
# Recursively resolve until we hit a leaf block or a SequentialPipelineBlocks
if block.sub_blocks:
return block.get_execution_blocks(**kwargs)
return block
def __repr__(self):
class_name = self.__class__.__name__
base_class = self.__class__.__bases__[0].__name__
@@ -755,11 +714,11 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
f"{class_name}(\n Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n"
)
if self.trigger_inputs:
if self._get_trigger_inputs():
header += "\n"
header += " " + "=" * 100 + "\n"
header += " This pipeline contains blocks that are selected at runtime based on inputs.\n"
header += f" Trigger Inputs: {sorted(self.trigger_inputs)}\n"
header += f" Trigger Inputs: {sorted(self._get_trigger_inputs())}\n"
header += " " + "=" * 100 + "\n\n"
# Format description with proper indentation
@@ -826,24 +785,56 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
class AutoPipelineBlocks(ConditionalPipelineBlocks):
"""
A Pipeline Blocks that automatically selects a block to run based on the presence of trigger inputs.
A Pipeline Blocks that automatically selects a block to run based on the presence of trigger inputs.
This is a specialized version of `ConditionalPipelineBlocks` where:
- Each block has one corresponding trigger input (1:1 mapping)
- Block selection is automatic: the first block whose trigger input is present gets selected
- `block_trigger_inputs` must have the same length as `block_names` and `block_classes`
- Use `None` in `block_trigger_inputs` to specify the default block, i.e. the block that will run if no trigger
inputs are present
Attributes:
block_classes:
List of block classes to be used. Must have the same length as `block_names` and
`block_trigger_inputs`.
block_names:
List of names for each block. Must have the same length as `block_classes` and `block_trigger_inputs`.
block_trigger_inputs:
List of input names where each element specifies the trigger input for the corresponding block. Use
`None` to mark the default block.
Example:
```python
class MyAutoBlock(AutoPipelineBlocks):
block_classes = [InpaintEncoderBlock, ImageEncoderBlock, TextEncoderBlock]
block_names = ["inpaint", "img2img", "text2img"]
block_trigger_inputs = ["mask_image", "image", None] # text2img is the default
```
With this definition:
- As long as `mask_image` is provided, "inpaint" block runs (regardless of `image` being provided or not)
- If `mask_image` is not provided but `image` is provided, "img2img" block runs
- Otherwise, "text2img" block runs (default, trigger is `None`)
"""
def __init__(self):
super().__init__()
if self.default_block_name is not None:
raise ValueError(
f"In {self.__class__.__name__}, do not set `default_block_name` for AutoPipelineBlocks. "
f"Use `None` in `block_trigger_inputs` to specify the default block."
)
if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
raise ValueError(
f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same."
)
@property
def default_block_name(self) -> Optional[str]:
"""Derive default_block_name from block_trigger_inputs (None entry)."""
if None in self.block_trigger_inputs:
idx = self.block_trigger_inputs.index(None)
return self.block_names[idx]
return None
self.default_block_name = self.block_names[idx]
def select_block(self, **kwargs) -> Optional[str]:
"""Select block based on which trigger input is present (not None)."""
@@ -897,6 +888,29 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
expected_configs.append(config)
return expected_configs
@property
def workflow_names(self):
if self._workflow_map is None:
raise NotImplementedError(
f"workflows is not supported because _workflow_map is not set for {self.__class__.__name__}"
)
return list(self._workflow_map.keys())
def get_workflow(self, workflow_name: str):
if self._workflow_map is None:
raise NotImplementedError(
f"workflows is not supported because _workflow_map is not set for {self.__class__.__name__}"
)
if workflow_name not in self._workflow_map:
raise ValueError(f"Workflow {workflow_name} not found in {self.__class__.__name__}")
trigger_inputs = self._workflow_map[workflow_name]
workflow_blocks = self.get_execution_blocks(**trigger_inputs)
return workflow_blocks
@classmethod
def from_blocks_dict(
cls, blocks_dict: Dict[str, Any], description: Optional[str] = None
@@ -992,7 +1006,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
# filter out them here so they do not end up as intermediate_outputs
if name not in inp_names:
named_outputs.append((name, block.intermediate_outputs))
combined_outputs = self.combine_outputs(*named_outputs)
combined_outputs = combine_outputs(*named_outputs)
return combined_outputs
# YiYi TODO: I think we can remove the outputs property
@@ -1016,6 +1030,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
raise
return pipeline, state
# used for `trigger_inputs` property
def _get_trigger_inputs(self):
"""
Returns a set of all unique trigger input values found in the blocks.
@@ -1039,89 +1054,50 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
return fn_recursive_get_trigger(self.sub_blocks)
@property
def trigger_inputs(self):
return self._get_trigger_inputs()
def _traverse_trigger_blocks(self, active_inputs):
def get_execution_blocks(self, **kwargs) -> "SequentialPipelineBlocks":
"""
Traverse blocks and select which ones would run given the active inputs.
Get the blocks that would execute given the specified inputs.
Args:
active_inputs: Dict of input names to values that are "present"
**kwargs: Input names and values. Only trigger inputs affect block selection.
Returns:
OrderedDict of block_name -> block that would execute
SequentialPipelineBlocks containing only the blocks that would execute
"""
# Copy kwargs so we can add outputs as we traverse
active_inputs = dict(kwargs)
def fn_recursive_traverse(block, block_name, active_inputs):
result_blocks = OrderedDict()
# ConditionalPipelineBlocks (includes AutoPipelineBlocks)
if isinstance(block, ConditionalPipelineBlocks):
trigger_kwargs = {name: active_inputs.get(name) for name in block.block_trigger_inputs}
selected_block_name = block.select_block(**trigger_kwargs)
if selected_block_name is None:
selected_block_name = block.default_block_name
if selected_block_name is None:
block = block.get_execution_blocks(**active_inputs)
if block is None:
return result_blocks
selected_block = block.sub_blocks[selected_block_name]
if selected_block.sub_blocks:
result_blocks.update(fn_recursive_traverse(selected_block, block_name, active_inputs))
else:
result_blocks[block_name] = selected_block
if hasattr(selected_block, "outputs"):
for out in selected_block.outputs:
active_inputs[out.name] = True
return result_blocks
# SequentialPipelineBlocks or LoopSequentialPipelineBlocks
if block.sub_blocks:
# Has sub_blocks (SequentialPipelineBlocks/ConditionalPipelineBlocks)
if block.sub_blocks and not isinstance(block, LoopSequentialPipelineBlocks):
for sub_block_name, sub_block in block.sub_blocks.items():
blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_inputs)
blocks_to_update = {f"{block_name}.{k}": v for k, v in blocks_to_update.items()}
result_blocks.update(blocks_to_update)
nested_blocks = fn_recursive_traverse(sub_block, sub_block_name, active_inputs)
nested_blocks = {f"{block_name}.{k}": v for k, v in nested_blocks.items()}
result_blocks.update(nested_blocks)
else:
# Leaf block: single ModularPipelineBlocks or LoopSequentialPipelineBlocks
result_blocks[block_name] = block
if hasattr(block, "outputs"):
for out in block.outputs:
# Add outputs to active_inputs so subsequent blocks can use them as triggers
if hasattr(block, "intermediate_outputs"):
for out in block.intermediate_outputs:
active_inputs[out.name] = True
return result_blocks
all_blocks = OrderedDict()
for block_name, block in self.sub_blocks.items():
blocks_to_update = fn_recursive_traverse(block, block_name, active_inputs)
all_blocks.update(blocks_to_update)
return all_blocks
nested_blocks = fn_recursive_traverse(block, block_name, active_inputs)
all_blocks.update(nested_blocks)
def get_execution_blocks(self, **kwargs):
"""
Get the blocks that would execute given the specified inputs.
Args:
**kwargs: Input names and values. Only trigger inputs affect block selection.
Pass any inputs that would be non-None at runtime.
Returns:
SequentialPipelineBlocks containing only the blocks that would execute
Example:
# Get blocks for inpainting workflow blocks = pipeline.get_execution_blocks(prompt="a cat", mask=mask,
image=image)
# Get blocks for text2image workflow blocks = pipeline.get_execution_blocks(prompt="a cat")
"""
# Filter out None values
active_inputs = {k: v for k, v in kwargs.items() if v is not None}
blocks_triggered = self._traverse_trigger_blocks(active_inputs)
return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered)
return SequentialPipelineBlocks.from_blocks_dict(all_blocks)
def __repr__(self):
class_name = self.__class__.__name__
@@ -1130,18 +1106,23 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
f"{class_name}(\n Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n"
)
if self.trigger_inputs:
if self._workflow_map is None and self._get_trigger_inputs():
header += "\n"
header += " " + "=" * 100 + "\n"
header += " This pipeline contains blocks that are selected at runtime based on inputs.\n"
header += f" Trigger Inputs: {[inp for inp in self.trigger_inputs if inp is not None]}\n"
header += f" Trigger Inputs: {[inp for inp in self._get_trigger_inputs() if inp is not None]}\n"
# Get first trigger input as example
example_input = next(t for t in self.trigger_inputs if t is not None)
example_input = next(t for t in self._get_trigger_inputs() if t is not None)
header += f" Use `get_execution_blocks()` to see selected blocks (e.g. `get_execution_blocks({example_input}=...)`).\n"
header += " " + "=" * 100 + "\n\n"
description = self.description
if self._workflow_map is not None:
workflow_str = format_workflow(self._workflow_map)
description = f"{self.description}\n\n{workflow_str}"
# Format description with proper indentation
desc_lines = self.description.split("\n")
desc_lines = description.split("\n")
desc = []
# First line with "Description:" label
desc.append(f" Description: {desc_lines[0]}")
@@ -1189,10 +1170,15 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
@property
def doc(self):
description = self.description
if self._workflow_map is not None:
workflow_str = format_workflow(self._workflow_map)
description = f"{self.description}\n\n{workflow_str}"
return make_doc_string(
self.inputs,
self.outputs,
self.description,
description=description,
class_name=self.__class__.__name__,
expected_components=self.expected_components,
expected_configs=self.expected_configs,
@@ -1325,7 +1311,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
@property
def intermediate_outputs(self) -> List[str]:
named_outputs = [(name, block.intermediate_outputs) for name, block in self.sub_blocks.items()]
combined_outputs = self.combine_outputs(*named_outputs)
combined_outputs = combine_outputs(*named_outputs)
for output in self.loop_intermediate_outputs:
if output.name not in {output.name for output in combined_outputs}:
combined_outputs.append(output)
@@ -1588,7 +1574,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
if modular_config_dict is not None:
blocks_class_name = modular_config_dict.get("_blocks_class_name")
else:
blocks_class_name = self.default_blocks_name
blocks_class_name = self.get_default_blocks_name(config_dict)
if blocks_class_name is not None:
diffusers_module = importlib.import_module("diffusers")
blocks_class = getattr(diffusers_module, blocks_class_name)
@@ -1660,6 +1646,9 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
params[input_param.name] = input_param.default
return params
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
return self.default_blocks_name
@classmethod
def _load_pipeline_config(
cls,
@@ -1755,8 +1744,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
logger.debug(" try to determine the modular pipeline class from model_index.json")
standard_pipeline_class = _get_pipeline_class(cls, config=config_dict)
model_name = _get_model(standard_pipeline_class.__name__)
map_fn = MODULAR_PIPELINE_MAPPING.get(model_name, _create_default_map_fn("ModularPipeline"))
pipeline_class_name = map_fn(config_dict)
pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(model_name, ModularPipeline.__name__)
diffusers_module = importlib.import_module("diffusers")
pipeline_class = getattr(diffusers_module, pipeline_class_name)
else:

View File

@@ -14,9 +14,10 @@
import inspect
import re
import warnings
from collections import OrderedDict
from dataclasses import dataclass, field, fields
from typing import Any, Dict, List, Literal, Optional, Type, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
import PIL.Image
import torch
@@ -860,6 +861,30 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines
return "\n".join(formatted_configs)
def format_workflow(workflow_map):
    """Render a workflow map as a human-readable, multi-line summary.

    Args:
        workflow_map: Mapping of workflow name -> ``{input_name: trigger_value}``;
            truthy trigger values mark inputs that the workflow requires.

    Returns:
        A formatted string listing every workflow, or ``""`` when the map is None.
    """
    if workflow_map is None:
        return ""
    rendered = ["Supported workflows:"]
    for name, triggers in workflow_map.items():
        needed = [key for key, flag in triggers.items() if flag]
        if not needed:
            # No trigger inputs at all: this is the fallback/default workflow.
            rendered.append(f"  - `{name}`: default (no additional inputs required)")
        else:
            joined = ", ".join(f"`{key}`" for key in needed)
            rendered.append(f"  - `{name}`: requires {joined}")
    return "\n".join(rendered)
def make_doc_string(
inputs,
outputs,
@@ -914,3 +939,69 @@ def make_doc_string(
output += format_output_params(outputs, indent_level=2)
return output
def combine_inputs(*named_input_lists: "List[Tuple[str, List[InputParam]]]") -> "List[InputParam]":
    """Merge InputParam lists coming from several blocks into one de-duplicated list.

    For a duplicated input name the earlier entry is kept, except that an entry whose
    default is ``None`` is replaced by a later entry that carries a real default. When
    two blocks declare different non-``None`` defaults for the same input, a warning is
    emitted and the first default wins.

    Args:
        named_input_lists: ``(block_name, input_param_list)`` pairs.

    Returns:
        List[InputParam]: the combined list of unique input parameters.
    """
    merged = {}  # input key -> winning InputParam
    origin = {}  # input key -> name of the block that supplied the winning default
    for block_name, params in named_input_lists:
        for param in params:
            # Anonymous kwargs-style inputs are keyed by their kwargs_type instead of a name.
            key = "*_" + param.kwargs_type if param.name is None and param.kwargs_type is not None else param.name
            if key not in merged:
                merged[key] = param
                origin[key] = block_name
                continue
            existing = merged[key]
            if existing.default is not None and param.default is not None and existing.default != param.default:
                warnings.warn(
                    f"Multiple different default values found for input '{key}': "
                    f"{existing.default} (from block '{origin[key]}') and "
                    f"{param.default} (from block '{block_name}'). Using {existing.default}."
                )
            if existing.default is None and param.default is not None:
                # Upgrade a default-less entry to one that actually carries a default.
                merged[key] = param
                origin[key] = block_name
    return list(merged.values())
def combine_outputs(*named_output_lists: "List[Tuple[str, List[OutputParam]]]") -> "List[OutputParam]":
    """Merge OutputParam lists from several blocks, dropping duplicate names.

    The first occurrence of each output name is kept, except that a later entry
    carrying a ``kwargs_type`` replaces an earlier one that has none.

    Args:
        named_output_lists: ``(block_name, output_param_list)`` pairs.

    Returns:
        List[OutputParam]: the combined list of unique output parameters.
    """
    merged = {}  # output name -> winning OutputParam
    for _block_name, params in named_output_lists:
        for param in params:
            existing = merged.get(param.name)
            upgrade = existing is not None and existing.kwargs_type is None and param.kwargs_type is not None
            if existing is None or upgrade:
                merged[param.name] = param
    return list(merged.values())

View File

@@ -1113,10 +1113,14 @@ AUTO_BLOCKS = InsertableDict(
class QwenImageAutoBlocks(SequentialPipelineBlocks):
"""
Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.
- for image-to-image generation, you need to provide `image`
- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`.
- to run the controlnet workflow, you need to provide `control_image`
- for text-to-image generation, all you need to provide is `prompt`
Supported workflows:
- `text2image`: requires `prompt`
- `image2image`: requires `prompt`, `image`
- `inpainting`: requires `prompt`, `mask_image`, `image`
- `controlnet_text2image`: requires `prompt`, `control_image`
- `controlnet_image2image`: requires `prompt`, `image`, `control_image`
- `controlnet_inpainting`: requires `prompt`, `mask_image`, `image`, `control_image`
Components:
text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use tokenizer (`Qwen2Tokenizer`):
@@ -1197,15 +1201,24 @@ class QwenImageAutoBlocks(SequentialPipelineBlocks):
block_classes = AUTO_BLOCKS.values()
block_names = AUTO_BLOCKS.keys()
# Workflow map defines the trigger conditions for each workflow.
# How to define:
# - Only include required inputs and trigger inputs (inputs that determine which blocks run)
# - `True` means the workflow triggers when the input is not None (most common case)
# - Use specific values (e.g., `{"strength": 0.5}`) if your `select_block` logic depends on the value
_workflow_map = {
"text2image": {"prompt": True},
"image2image": {"prompt": True, "image": True},
"inpainting": {"prompt": True, "mask_image": True, "image": True},
"controlnet_text2image": {"prompt": True, "control_image": True},
"controlnet_image2image": {"prompt": True, "image": True, "control_image": True},
"controlnet_inpainting": {"prompt": True, "mask_image": True, "image": True, "control_image": True},
}
@property
def description(self):
return (
"Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.\n"
+ "- for image-to-image generation, you need to provide `image`\n"
+ "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`.\n"
+ "- to run the controlnet workflow, you need to provide `control_image`\n"
+ "- for text-to-image generation, all you need to provide is `prompt`"
)
return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage."
@property
def outputs(self):

View File

@@ -21,16 +21,16 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["modular_blocks_wan"] = ["WanBlocks"]
_import_structure["modular_blocks_wan22"] = ["Wan22Blocks"]
_import_structure["modular_blocks_wan22_i2v"] = ["Wan22Image2VideoBlocks"]
_import_structure["modular_blocks_wan_i2v"] = ["WanImage2VideoAutoBlocks"]
_import_structure["modular_pipeline"] = [
"Wan22Image2VideoModularPipeline",
"Wan22ModularPipeline",
"WanImage2VideoModularPipeline",
"WanModularPipeline",
_import_structure["decoders"] = ["WanImageVaeDecoderStep"]
_import_structure["encoders"] = ["WanTextEncoderStep"]
_import_structure["modular_blocks"] = [
"ALL_BLOCKS",
"Wan22AutoBlocks",
"WanAutoBlocks",
"WanAutoImageEncoderStep",
"WanAutoVaeImageEncoderStep",
]
_import_structure["modular_pipeline"] = ["WanModularPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -39,16 +39,16 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .modular_blocks_wan import WanBlocks
from .modular_blocks_wan22 import Wan22Blocks
from .modular_blocks_wan22_i2v import Wan22Image2VideoBlocks
from .modular_blocks_wan_i2v import WanImage2VideoAutoBlocks
from .modular_pipeline import (
Wan22Image2VideoModularPipeline,
Wan22ModularPipeline,
WanImage2VideoModularPipeline,
WanModularPipeline,
from .decoders import WanImageVaeDecoderStep
from .encoders import WanTextEncoderStep
from .modular_blocks import (
ALL_BLOCKS,
Wan22AutoBlocks,
WanAutoBlocks,
WanAutoImageEncoderStep,
WanAutoVaeImageEncoderStep,
)
from .modular_pipeline import WanModularPipeline
else:
import sys

View File

@@ -280,7 +280,7 @@ class WanAdditionalInputsStep(ModularPipelineBlocks):
def __init__(
self,
image_latent_inputs: List[str] = ["image_condition_latents"],
image_latent_inputs: List[str] = ["first_frame_latents"],
additional_batch_inputs: List[str] = [],
):
"""Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
@@ -294,16 +294,20 @@ class WanAdditionalInputsStep(ModularPipelineBlocks):
Args:
image_latent_inputs (List[str], optional): Names of image latent tensors to process.
In additional to adjust batch size of these inputs, they will be used to determine height/width. Can be
a single string or list of strings. Defaults to ["image_condition_latents"].
a single string or list of strings. Defaults to ["first_frame_latents"].
additional_batch_inputs (List[str], optional):
Names of additional conditional input tensors to expand batch size. These tensors will only have their
batch dimensions adjusted to match the final batch size. Can be a single string or list of strings.
Defaults to [].
Examples:
# Configure to process image_condition_latents (default behavior) WanAdditionalInputsStep() # Configure to
process image latents and additional batch inputs WanAdditionalInputsStep(
image_latent_inputs=["image_condition_latents"], additional_batch_inputs=["image_embeds"]
# Configure to process first_frame_latents (default behavior) WanAdditionalInputsStep()
# Configure to process multiple image latent inputs
WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents", "last_frame_latents"])
# Configure to process image latents and additional batch inputs WanAdditionalInputsStep(
image_latent_inputs=["first_frame_latents"], additional_batch_inputs=["image_embeds"]
)
"""
if not isinstance(image_latent_inputs, list):
@@ -553,3 +557,81 @@ class WanPrepareLatentsStep(ModularPipelineBlocks):
self.set_block_state(state, block_state)
return components, state
class WanPrepareFirstFrameLatentsStep(ModularPipelineBlocks):
    """Prepends the frame mask to the first-frame latents used as the denoiser's latent condition."""

    model_name = "wan"
    @property
    def description(self) -> str:
        return "step that prepares the masked first frame latents and add it to the latent condition"
    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("first_frame_latents", type_hint=Optional[torch.Tensor]),
            InputParam("num_frames", type_hint=int),
        ]
    def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        # NOTE(review): assumes `first_frame_latents` is not None here (upstream blocks
        # only route to this step when it is provided) — confirm against the auto blocks.
        cond_latents = block_state.first_frame_latents
        num_frames = block_state.num_frames
        temporal_scale = components.vae_scale_factor_temporal
        batch_size, _, _, latent_height, latent_width = cond_latents.shape
        # Frame mask: 1 for the conditioned first frame, 0 for every frame to be generated.
        mask = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
        mask[:, :, 1:] = 0
        # The first pixel frame maps to `temporal_scale` latent frames, so repeat its mask entry.
        leading_mask = torch.repeat_interleave(mask[:, :, 0:1], dim=2, repeats=temporal_scale)
        mask = torch.concat([leading_mask, mask[:, :, 1:, :]], dim=2)
        # Fold the temporal scale factor into the channel dimension expected by the denoiser.
        mask = mask.view(batch_size, -1, temporal_scale, latent_height, latent_width)
        mask = mask.transpose(1, 2)
        mask = mask.to(cond_latents.device)
        block_state.first_frame_latents = torch.concat([mask, cond_latents], dim=1)
        self.set_block_state(state, block_state)
        return components, state
class WanPrepareFirstLastFrameLatentsStep(ModularPipelineBlocks):
    """Prepends the frame mask to the first/last-frame latents used as the denoiser's latent condition."""

    model_name = "wan"
    @property
    def description(self) -> str:
        return "step that prepares the masked latents with first and last frames and add it to the latent condition"
    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("first_last_frame_latents", type_hint=Optional[torch.Tensor]),
            InputParam("num_frames", type_hint=int),
        ]
    def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        # NOTE(review): assumes `first_last_frame_latents` is not None here (upstream blocks
        # only route to this step when it is provided) — confirm against the auto blocks.
        cond_latents = block_state.first_last_frame_latents
        num_frames = block_state.num_frames
        temporal_scale = components.vae_scale_factor_temporal
        batch_size, _, _, latent_height, latent_width = cond_latents.shape
        # Frame mask: 1 for the conditioned first and last frames, 0 for the frames in between.
        mask = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
        mask[:, :, 1 : num_frames - 1] = 0
        # The first pixel frame maps to `temporal_scale` latent frames, so repeat its mask entry.
        leading_mask = torch.repeat_interleave(mask[:, :, 0:1], dim=2, repeats=temporal_scale)
        mask = torch.concat([leading_mask, mask[:, :, 1:, :]], dim=2)
        # Fold the temporal scale factor into the channel dimension expected by the denoiser.
        mask = mask.view(batch_size, -1, temporal_scale, latent_height, latent_width)
        mask = mask.transpose(1, 2)
        mask = mask.to(cond_latents.device)
        block_state.first_last_frame_latents = torch.concat([mask, cond_latents], dim=1)
        self.set_block_state(state, block_state)
        return components, state

View File

@@ -29,7 +29,7 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class WanVaeDecoderStep(ModularPipelineBlocks):
class WanImageVaeDecoderStep(ModularPipelineBlocks):
model_name = "wan"
@property

View File

@@ -89,10 +89,52 @@ class WanImage2VideoLoopBeforeDenoiser(ModularPipelineBlocks):
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
),
InputParam(
"image_condition_latents",
"first_frame_latents",
required=True,
type_hint=torch.Tensor,
description="The image condition latents to use for the denoising process. Can be generated in prepare_first_frame_latents/prepare_first_last_frame_latents step.",
description="The first frame latents to use for the denoising process. Can be generated in prepare_first_frame_latents step.",
),
InputParam(
"dtype",
required=True,
type_hint=torch.dtype,
description="The dtype of the model inputs. Can be generated in input step.",
),
]
@torch.no_grad()
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
block_state.latent_model_input = torch.cat([block_state.latents, block_state.first_frame_latents], dim=1).to(
block_state.dtype
)
return components, block_state
class WanFLF2VLoopBeforeDenoiser(ModularPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return (
"step within the denoising loop that prepares the latent input for the denoiser. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `WanDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
),
InputParam(
"first_last_frame_latents",
required=True,
type_hint=torch.Tensor,
description="The first and last frame latents to use for the denoising process. Can be generated in prepare_first_last_frame_latents step.",
),
InputParam(
"dtype",
@@ -105,7 +147,7 @@ class WanImage2VideoLoopBeforeDenoiser(ModularPipelineBlocks):
@torch.no_grad()
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
block_state.latent_model_input = torch.cat(
[block_state.latents, block_state.image_condition_latents], dim=1
[block_state.latents, block_state.first_last_frame_latents], dim=1
).to(block_state.dtype)
return components, block_state
@@ -542,3 +584,29 @@ class Wan22Image2VideoDenoiseStep(WanDenoiseLoopWrapper):
" - `WanLoopAfterDenoiser`\n"
"This block supports image-to-video tasks for Wan2.2."
)
class WanFLF2VDenoiseStep(WanDenoiseLoopWrapper):
    """Denoise loop preset for Wan 2.1 first-last-frame-to-video (FLF2V) generation."""

    # Sub-blocks executed sequentially on every denoising iteration by the loop wrapper.
    block_classes = [
        WanFLF2VLoopBeforeDenoiser,
        # Maps guider fields to the conditioning tensors: (positive, negative) prompt
        # embeddings plus the image embeddings shared by both branches.
        WanLoopDenoiser(
            guider_input_fields={
                "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
                "encoder_hidden_states_image": "image_embeds",
            }
        ),
        WanLoopAfterDenoiser,
    ]
    block_names = ["before_denoiser", "denoiser", "after_denoiser"]
    @property
    def description(self) -> str:
        # Human-readable summary surfaced by the modular-pipeline doc tooling.
        return (
            "Denoise step that iteratively denoise the latents. \n"
            "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
            "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
            " - `WanFLF2VLoopBeforeDenoiser`\n"
            " - `WanLoopDenoiser`\n"
            " - `WanLoopAfterDenoiser`\n"
            "This block supports FLF2V tasks for wan2.1."
        )

View File

@@ -468,7 +468,7 @@ class WanFirstLastFrameImageEncoderStep(ModularPipelineBlocks):
return components, state
class WanVaeEncoderStep(ModularPipelineBlocks):
class WanVaeImageEncoderStep(ModularPipelineBlocks):
model_name = "wan"
@property
@@ -493,7 +493,7 @@ class WanVaeEncoderStep(ModularPipelineBlocks):
InputParam("resized_image", type_hint=PIL.Image.Image, required=True),
InputParam("height"),
InputParam("width"),
InputParam("num_frames", type_hint=int, default=81),
InputParam("num_frames"),
InputParam("generator"),
]
@@ -564,51 +564,7 @@ class WanVaeEncoderStep(ModularPipelineBlocks):
return components, state
class WanPrepareFirstFrameLatentsStep(ModularPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return "step that prepares the masked first frame latents and add it to the latent condition"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("first_frame_latents", type_hint=Optional[torch.Tensor]),
InputParam("num_frames", required=True),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("image_condition_latents", type_hint=Optional[torch.Tensor]),
]
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
batch_size, _, _, latent_height, latent_width = block_state.first_frame_latents.shape
mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
mask_lat_size[:, :, list(range(1, block_state.num_frames))] = 0
first_frame_mask = mask_lat_size[:, :, 0:1]
first_frame_mask = torch.repeat_interleave(
first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
)
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
mask_lat_size = mask_lat_size.view(
batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
)
mask_lat_size = mask_lat_size.transpose(1, 2)
mask_lat_size = mask_lat_size.to(block_state.first_frame_latents.device)
block_state.image_condition_latents = torch.concat([mask_lat_size, block_state.first_frame_latents], dim=1)
self.set_block_state(state, block_state)
return components, state
class WanFirstLastFrameVaeEncoderStep(ModularPipelineBlocks):
class WanFirstLastFrameVaeImageEncoderStep(ModularPipelineBlocks):
model_name = "wan"
@property
@@ -634,7 +590,7 @@ class WanFirstLastFrameVaeEncoderStep(ModularPipelineBlocks):
InputParam("resized_last_image", type_hint=PIL.Image.Image, required=True),
InputParam("height"),
InputParam("width"),
InputParam("num_frames", type_hint=int, default=81),
InputParam("num_frames"),
InputParam("generator"),
]
@@ -711,49 +667,3 @@ class WanFirstLastFrameVaeEncoderStep(ModularPipelineBlocks):
self.set_block_state(state, block_state)
return components, state
class WanPrepareFirstLastFrameLatentsStep(ModularPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return "step that prepares the masked latents with first and last frames and add it to the latent condition"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("first_last_frame_latents", type_hint=Optional[torch.Tensor]),
InputParam("num_frames", type_hint=int, required=True),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("image_condition_latents", type_hint=Optional[torch.Tensor]),
]
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
batch_size, _, _, latent_height, latent_width = block_state.first_last_frame_latents.shape
mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
mask_lat_size[:, :, list(range(1, block_state.num_frames - 1))] = 0
first_frame_mask = mask_lat_size[:, :, 0:1]
first_frame_mask = torch.repeat_interleave(
first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
)
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
mask_lat_size = mask_lat_size.view(
batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
)
mask_lat_size = mask_lat_size.transpose(1, 2)
mask_lat_size = mask_lat_size.to(block_state.first_last_frame_latents.device)
block_state.image_condition_latents = torch.concat(
[mask_lat_size, block_state.first_last_frame_latents], dim=1
)
self.set_block_state(state, block_state)
return components, state

View File

@@ -0,0 +1,474 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
WanAdditionalInputsStep,
WanPrepareFirstFrameLatentsStep,
WanPrepareFirstLastFrameLatentsStep,
WanPrepareLatentsStep,
WanSetTimestepsStep,
WanTextInputStep,
)
from .decoders import WanImageVaeDecoderStep
from .denoise import (
Wan22DenoiseStep,
Wan22Image2VideoDenoiseStep,
WanDenoiseStep,
WanFLF2VDenoiseStep,
WanImage2VideoDenoiseStep,
)
from .encoders import (
WanFirstLastFrameImageEncoderStep,
WanFirstLastFrameVaeImageEncoderStep,
WanImageCropResizeStep,
WanImageEncoderStep,
WanImageResizeStep,
WanTextEncoderStep,
WanVaeImageEncoderStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# wan2.1
# wan2.1: text2vid
class WanCoreDenoiseStep(SequentialPipelineBlocks):
    """Text-to-video denoise preset for Wan 2.1: input prep -> timesteps -> latents -> denoise."""

    # Executed in order by SequentialPipelineBlocks; names below label each stage.
    block_classes = [
        WanTextInputStep,
        WanSetTimestepsStep,
        WanPrepareLatentsStep,
        WanDenoiseStep,
    ]
    block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]
    @property
    def description(self):
        # Human-readable summary surfaced by the modular-pipeline doc tooling.
        return (
            "denoise block that takes encoded conditions and runs the denoising process.\n"
            + "This is a sequential pipeline blocks:\n"
            + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
            + " - `WanSetTimestepsStep` is used to set the timesteps\n"
            + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
            + " - `WanDenoiseStep` is used to denoise the latents\n"
        )
# wan2.1: image2video
## image encoder
class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks):
    """Image2video encoder preset: resize the input image, then encode it into image embeddings."""

    model_name = "wan"
    block_classes = [WanImageResizeStep, WanImageEncoderStep]
    block_names = ["image_resize", "image_encoder"]
    @property
    def description(self):
        return "Image2Video Image Encoder step that resize the image and encode the image to generate the image embeddings"
## vae encoder
class WanImage2VideoVaeImageEncoderStep(SequentialPipelineBlocks):
    """Image2video VAE encoder preset: resize the image, then VAE-encode the first frame into latents."""

    model_name = "wan"
    block_classes = [WanImageResizeStep, WanVaeImageEncoderStep]
    block_names = ["image_resize", "vae_encoder"]
    @property
    def description(self):
        return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation"
## denoise
class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks):
    """Image2video denoise preset for Wan 2.1: batches inputs, prepares masked first-frame latents, denoises."""

    # Executed in order; `WanAdditionalInputsStep` is configured to batch-expand the
    # first-frame latent condition and derive height/width from it.
    block_classes = [
        WanTextInputStep,
        WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"]),
        WanSetTimestepsStep,
        WanPrepareLatentsStep,
        WanPrepareFirstFrameLatentsStep,
        WanImage2VideoDenoiseStep,
    ]
    block_names = [
        "input",
        "additional_inputs",
        "set_timesteps",
        "prepare_latents",
        "prepare_first_frame_latents",
        "denoise",
    ]
    @property
    def description(self):
        # Human-readable summary surfaced by the modular-pipeline doc tooling.
        return (
            "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
            + "This is a sequential pipeline blocks:\n"
            + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
            + " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
            + " - `WanSetTimestepsStep` is used to set the timesteps\n"
            + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
            + " - `WanPrepareFirstFrameLatentsStep` is used to prepare the first frame latent conditions\n"
            + " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n"
        )
# wan2.1: FLF2v
## image encoder
class WanFLF2VImageEncoderStep(SequentialPipelineBlocks):
    """FLF2V encoder preset: resize the first and last frame images, then encode them into image embeddings."""

    model_name = "wan"
    block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameImageEncoderStep]
    block_names = ["image_resize", "last_image_resize", "image_encoder"]
    @property
    def description(self):
        # Fixed duplicated "and encode and encode" in the original wording.
        return "FLF2V Image Encoder step that resizes and encodes the first and last frame images to generate the image embeddings"
## vae encoder
class WanFLF2VVaeImageEncoderStep(SequentialPipelineBlocks):
    """FLF2V VAE encoder preset: resize the first and last frame images, then VAE-encode them into latent conditions."""

    model_name = "wan"
    block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameVaeImageEncoderStep]
    block_names = ["image_resize", "last_image_resize", "vae_encoder"]
    @property
    def description(self):
        # Fixed duplicated "and encode and encode" in the original wording.
        return "FLF2V Vae Image Encoder step that resizes and encodes the first and last frame images to generate the latent conditions"
## denoise
class WanFLF2VCoreDenoiseStep(SequentialPipelineBlocks):
    """FLF2V denoise preset for Wan 2.1: batches inputs, prepares masked first/last-frame latents, denoises."""

    # Executed in order; `WanAdditionalInputsStep` is configured to batch-expand the
    # first/last-frame latent condition and derive height/width from it.
    block_classes = [
        WanTextInputStep,
        WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"]),
        WanSetTimestepsStep,
        WanPrepareLatentsStep,
        WanPrepareFirstLastFrameLatentsStep,
        WanFLF2VDenoiseStep,
    ]
    block_names = [
        "input",
        "additional_inputs",
        "set_timesteps",
        "prepare_latents",
        "prepare_first_last_frame_latents",
        "denoise",
    ]
    @property
    def description(self):
        # Fixed the last bullet: the final sub-block is `WanFLF2VDenoiseStep`, not
        # `WanImage2VideoDenoiseStep` as the original text claimed.
        return (
            "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
            + "This is a sequential pipeline blocks:\n"
            + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
            + " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
            + " - `WanSetTimestepsStep` is used to set the timesteps\n"
            + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
            + " - `WanPrepareFirstLastFrameLatentsStep` is used to prepare the latent conditions\n"
            + " - `WanFLF2VDenoiseStep` is used to denoise the latents\n"
        )
# wan2.1: auto blocks
## image encoder
class WanAutoImageEncoderStep(AutoPipelineBlocks):
    """Auto image encoder: dispatches to the FLF2V or image2video encoder based on provided inputs."""

    block_classes = [WanFLF2VImageEncoderStep, WanImage2VideoImageEncoderStep]
    block_names = ["flf2v_image_encoder", "image2video_image_encoder"]
    # Checked in order: `last_image` selects FLF2V, otherwise `image` selects image2video.
    block_trigger_inputs = ["last_image", "image"]
    @property
    def description(self):
        # Added line breaks: the original implicit string concatenation ran the
        # sentences together ("...embeddingsThis is...").
        return (
            "Image Encoder step that encode the image to generate the image embeddings\n"
            + "This is an auto pipeline block that works for image2video tasks.\n"
            + " - `WanFLF2VImageEncoderStep` (flf2v) is used when `last_image` is provided.\n"
            + " - `WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided.\n"
            + " - if `last_image` or `image` is not provided, step will be skipped."
        )
## vae encoder
class WanAutoVaeImageEncoderStep(AutoPipelineBlocks):
    """Auto VAE image encoder: dispatches to the FLF2V or image2video VAE encoder based on provided inputs."""

    block_classes = [WanFLF2VVaeImageEncoderStep, WanImage2VideoVaeImageEncoderStep]
    block_names = ["flf2v_vae_encoder", "image2video_vae_encoder"]
    # Checked in order: `last_image` selects FLF2V, otherwise `image` selects image2video.
    block_trigger_inputs = ["last_image", "image"]
    @property
    def description(self):
        # Added line breaks: the original implicit string concatenation ran the
        # sentences together ("...latentsThis is...").
        return (
            "Vae Image Encoder step that encode the image to generate the image latents\n"
            + "This is an auto pipeline block that works for image2video tasks.\n"
            + " - `WanFLF2VVaeImageEncoderStep` (flf2v) is used when `last_image` is provided.\n"
            + " - `WanImage2VideoVaeImageEncoderStep` (image2video) is used when `image` is provided.\n"
            + " - if `last_image` or `image` is not provided, step will be skipped."
        )
## denoise
class WanAutoDenoiseStep(AutoPipelineBlocks):
    """Auto denoise step for Wan 2.1: selects the FLF2V, image2video, or text2video core denoise block."""

    block_classes = [
        WanFLF2VCoreDenoiseStep,
        WanImage2VideoCoreDenoiseStep,
        WanCoreDenoiseStep,
    ]
    block_names = ["flf2v", "image2video", "text2video"]
    # Checked in order; `None` marks the unconditional text2video fallback.
    block_trigger_inputs = ["first_last_frame_latents", "first_frame_latents", None]
    @property
    def description(self) -> str:
        # Fixed the original text, which referenced nonexistent class names
        # (`WanCoreImage2VideoCoreDenoiseStep` / `WanCoreImage2VideoDenoiseStep`)
        # and omitted the FLF2V branch entirely.
        return (
            "Denoise step that iteratively denoises the latents.\n"
            "This is an auto pipeline block that works for text2video, image2video and flf2v tasks.\n"
            " - `WanFLF2VCoreDenoiseStep` (flf2v) is used when `first_last_frame_latents` is provided.\n"
            " - `WanImage2VideoCoreDenoiseStep` (image2video) is used when `first_frame_latents` is provided.\n"
            " - `WanCoreDenoiseStep` (text2video) is used when neither latent condition is provided.\n"
        )
# auto pipeline blocks
class WanAutoBlocks(SequentialPipelineBlocks):
    """Top-level auto pipeline for Wan 2.1 covering text2video, image2video, and FLF2V workflows."""

    block_classes = [
        WanTextEncoderStep,
        WanAutoImageEncoderStep,
        WanAutoVaeImageEncoderStep,
        WanAutoDenoiseStep,
        WanImageVaeDecoderStep,
    ]
    block_names = [
        "text_encoder",
        "image_encoder",
        "vae_encoder",
        "denoise",
        "decode",
    ]
    @property
    def description(self):
        # The encoder/denoise sub-blocks auto-dispatch on `image` / `last_image`, so this
        # pipeline also covers image2video and FLF2V — not just text2video as the original
        # description claimed.
        return (
            "Auto Modular pipeline for text-to-video, image-to-video, and first-last-frame-to-video using Wan.\n"
            + "- for text-to-video generation, all you need to provide is `prompt`\n"
            + "- for image-to-video generation, you also need to provide `image`\n"
            + "- for first-last-frame-to-video generation, you also need to provide `image` and `last_image`"
        )
# wan22
# wan2.2: text2vid
## denoise
class Wan22CoreDenoiseStep(SequentialPipelineBlocks):
    """Text-to-video denoise preset for Wan 2.2: input prep -> timesteps -> latents -> denoise."""

    # Same stages as the Wan 2.1 preset, with the Wan 2.2-specific denoise block.
    block_classes = [
        WanTextInputStep,
        WanSetTimestepsStep,
        WanPrepareLatentsStep,
        Wan22DenoiseStep,
    ]
    block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]
    @property
    def description(self):
        # Human-readable summary surfaced by the modular-pipeline doc tooling.
        return (
            "denoise block that takes encoded conditions and runs the denoising process.\n"
            + "This is a sequential pipeline blocks:\n"
            + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
            + " - `WanSetTimestepsStep` is used to set the timesteps\n"
            + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
            + " - `Wan22DenoiseStep` is used to denoise the latents in wan2.2\n"
        )
# wan2.2: image2video
## denoise
class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks):
    """Image2video denoise preset for Wan 2.2: batches inputs, prepares masked first-frame latents, denoises."""

    # Executed in order; `WanAdditionalInputsStep` is configured to batch-expand the
    # first-frame latent condition and derive height/width from it.
    block_classes = [
        WanTextInputStep,
        WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"]),
        WanSetTimestepsStep,
        WanPrepareLatentsStep,
        WanPrepareFirstFrameLatentsStep,
        Wan22Image2VideoDenoiseStep,
    ]
    block_names = [
        "input",
        "additional_inputs",
        "set_timesteps",
        "prepare_latents",
        "prepare_first_frame_latents",
        "denoise",
    ]
    @property
    def description(self):
        # Human-readable summary surfaced by the modular-pipeline doc tooling.
        return (
            "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
            + "This is a sequential pipeline blocks:\n"
            + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
            + " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
            + " - `WanSetTimestepsStep` is used to set the timesteps\n"
            + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
            + " - `WanPrepareFirstFrameLatentsStep` is used to prepare the first frame latent conditions\n"
            + " - `Wan22Image2VideoDenoiseStep` is used to denoise the latents in wan2.2\n"
        )
class Wan22AutoDenoiseStep(AutoPipelineBlocks):
    """Auto denoise step for Wan 2.2: selects the image2video or text2video core denoise block."""

    block_classes = [
        Wan22Image2VideoCoreDenoiseStep,
        Wan22CoreDenoiseStep,
    ]
    block_names = ["image2video", "text2video"]
    # Checked in order; `None` marks the unconditional text2video fallback.
    block_trigger_inputs = ["first_frame_latents", None]
    @property
    def description(self) -> str:
        # Fixed "a auto" grammar and added line breaks: the original mixed plain
        # concatenation (no separators) with "\n"-terminated segments.
        return (
            "Denoise step that iteratively denoises the latents.\n"
            "This is an auto pipeline block that works for text2video and image2video tasks.\n"
            " - `Wan22Image2VideoCoreDenoiseStep` (image2video) is used when `first_frame_latents` is provided.\n"
            " - `Wan22CoreDenoiseStep` (text2video) is used when `first_frame_latents` is not provided.\n"
        )
class Wan22AutoBlocks(SequentialPipelineBlocks):
    """Auto Wan2.2 pipeline: text encode -> VAE image encode -> auto denoise -> decode."""

    block_classes = [
        WanTextEncoderStep,
        WanAutoVaeImageEncoderStep,
        Wan22AutoDenoiseStep,
        WanImageVaeDecoderStep,
    ]
    block_names = [
        "text_encoder",
        "vae_encoder",
        "denoise",
        "decode",
    ]

    @property
    def description(self):
        # Fixed: the block also covers image-to-video (Wan22AutoDenoiseStep switches on
        # `first_frame_latents`), so the description now mentions both workflows.
        # NOTE(review): assumes the VAE image encoder is triggered by `image` — confirm
        # against WanAutoVaeImageEncoderStep's trigger inputs.
        return (
            "Auto Modular pipeline for text-to-video and image-to-video using Wan2.2.\n"
            + "- for text-to-video generation, all you need to provide is `prompt`\n"
            + "- for image-to-video generation, you also need to provide `image`"
        )
# presets for wan2.1 and wan2.2
# YiYi Notes: should we move these to doc?
# wan2.1
# Preset for Wan2.1 text-to-video: text encoding, denoise loop, then VAE decode.
TEXT2VIDEO_BLOCKS = InsertableDict(
    [
        ("text_encoder", WanTextEncoderStep),
        ("input", WanTextInputStep),
        ("set_timesteps", WanSetTimestepsStep),
        ("prepare_latents", WanPrepareLatentsStep),
        ("denoise", WanDenoiseStep),
        ("decode", WanImageVaeDecoderStep),
    ]
)
# Preset for Wan2.1 image-to-video: the input image is resized, then encoded twice
# (image encoder for embeddings, VAE encoder for `first_frame_latents`) before the
# denoise loop; `additional_inputs` batch-adjusts the first-frame latent conditions.
IMAGE2VIDEO_BLOCKS = InsertableDict(
    [
        ("image_resize", WanImageResizeStep),
        ("image_encoder", WanImage2VideoImageEncoderStep),
        ("vae_encoder", WanImage2VideoVaeImageEncoderStep),
        ("input", WanTextInputStep),
        ("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"])),
        ("set_timesteps", WanSetTimestepsStep),
        ("prepare_latents", WanPrepareLatentsStep),
        ("prepare_first_frame_latents", WanPrepareFirstFrameLatentsStep),
        ("denoise", WanImage2VideoDenoiseStep),
        ("decode", WanImageVaeDecoderStep),
    ]
)
# Preset for Wan2.1 first-and-last-frame-to-video (FLF2V): both frames are resized
# and encoded, producing `first_last_frame_latents` consumed by the FLF2V denoise step.
FLF2V_BLOCKS = InsertableDict(
    [
        ("image_resize", WanImageResizeStep),
        ("last_image_resize", WanImageCropResizeStep),
        ("image_encoder", WanFLF2VImageEncoderStep),
        ("vae_encoder", WanFLF2VVaeImageEncoderStep),
        ("input", WanTextInputStep),
        ("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"])),
        ("set_timesteps", WanSetTimestepsStep),
        ("prepare_latents", WanPrepareLatentsStep),
        ("prepare_first_last_frame_latents", WanPrepareFirstLastFrameLatentsStep),
        ("denoise", WanFLF2VDenoiseStep),
        ("decode", WanImageVaeDecoderStep),
    ]
)
# Auto preset for Wan2.1: the "auto" encoder/denoise blocks pick the t2v/i2v/flf2v
# branch at runtime based on which inputs the caller provides.
AUTO_BLOCKS = InsertableDict(
    [
        ("text_encoder", WanTextEncoderStep),
        ("image_encoder", WanAutoImageEncoderStep),
        ("vae_encoder", WanAutoVaeImageEncoderStep),
        ("denoise", WanAutoDenoiseStep),
        ("decode", WanImageVaeDecoderStep),
    ]
)
# wan2.2 presets
# Preset for Wan2.2 text-to-video: same layout as the wan2.1 preset but using the
# Wan2.2 denoise step.
TEXT2VIDEO_BLOCKS_WAN22 = InsertableDict(
    [
        ("text_encoder", WanTextEncoderStep),
        ("input", WanTextInputStep),
        ("set_timesteps", WanSetTimestepsStep),
        ("prepare_latents", WanPrepareLatentsStep),
        ("denoise", Wan22DenoiseStep),
        ("decode", WanImageVaeDecoderStep),
    ]
)
# Preset for Wan2.2 image-to-video. Mirrors the composition of
# `Wan22Image2VideoCoreDenoiseStep`: the first-frame latents produced by the VAE
# encoder are batch-adjusted ("additional_inputs"), prepared, and consumed by the
# image2video denoise step. (Previously this preset ran the text2video
# `Wan22DenoiseStep` and never wired in the first-frame conditioning.)
IMAGE2VIDEO_BLOCKS_WAN22 = InsertableDict(
    [
        ("image_resize", WanImageResizeStep),
        ("vae_encoder", WanImage2VideoVaeImageEncoderStep),
        ("input", WanTextInputStep),
        ("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"])),
        ("set_timesteps", WanSetTimestepsStep),
        ("prepare_latents", WanPrepareLatentsStep),
        ("prepare_first_frame_latents", WanPrepareFirstFrameLatentsStep),
        ("denoise", Wan22Image2VideoDenoiseStep),
        ("decode", WanImageVaeDecoderStep),
    ]
)
# Auto preset for Wan2.2: no separate image encoder; `Wan22AutoDenoiseStep` selects
# the i2v branch when `first_frame_latents` is present, t2v otherwise.
AUTO_BLOCKS_WAN22 = InsertableDict(
    [
        ("text_encoder", WanTextEncoderStep),
        ("vae_encoder", WanAutoVaeImageEncoderStep),
        ("denoise", Wan22AutoDenoiseStep),
        ("decode", WanImageVaeDecoderStep),
    ]
)
# presets all blocks (wan and wan22)
# Registry mapping model version -> task name -> preset block dict.
# Note: wan2.2 has no "flf2v" entry (no FLF2V preset is defined for it here).
ALL_BLOCKS = {
    "wan2.1": {
        "text2video": TEXT2VIDEO_BLOCKS,
        "image2video": IMAGE2VIDEO_BLOCKS,
        "flf2v": FLF2V_BLOCKS,
        "auto": AUTO_BLOCKS,
    },
    "wan2.2": {
        "text2video": TEXT2VIDEO_BLOCKS_WAN22,
        "image2video": IMAGE2VIDEO_BLOCKS_WAN22,
        "auto": AUTO_BLOCKS_WAN22,
    },
}

View File

@@ -1,83 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
from .before_denoise import (
WanPrepareLatentsStep,
WanSetTimestepsStep,
WanTextInputStep,
)
from .decoders import WanVaeDecoderStep
from .denoise import (
WanDenoiseStep,
)
from .encoders import (
WanTextEncoderStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# ====================
# 1. DENOISE
# ====================
# inputs(text) -> set_timesteps -> prepare_latents -> denoise
class WanCoreDenoiseStep(SequentialPipelineBlocks):
    """Wan2.1 text2video core denoise: input -> set_timesteps -> prepare_latents -> denoise."""
    model_name = "wan"
    # Sub-blocks run sequentially; `block_names` gives each its key, in the same order.
    block_classes = [
        WanTextInputStep,
        WanSetTimestepsStep,
        WanPrepareLatentsStep,
        WanDenoiseStep,
    ]
    block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]
    @property
    def description(self):
        # Human-readable summary of this composite block.
        return (
            "denoise block that takes encoded conditions and runs the denoising process.\n"
            + "This is a sequential pipeline blocks:\n"
            + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
            + " - `WanSetTimestepsStep` is used to set the timesteps\n"
            + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
            + " - `WanDenoiseStep` is used to denoise the latents\n"
        )
# ====================
# 2. BLOCKS (Wan2.1 text2video)
# ====================
class WanBlocks(SequentialPipelineBlocks):
    """End-to-end Wan2.1 text2video pipeline: text encode -> denoise -> VAE decode."""
    model_name = "wan"
    block_classes = [
        WanTextEncoderStep,
        WanCoreDenoiseStep,
        WanVaeDecoderStep,
    ]
    block_names = ["text_encoder", "denoise", "decode"]
    @property
    def description(self):
        return (
            "Modular pipeline blocks for Wan2.1.\n"
            + "- `WanTextEncoderStep` is used to encode the text\n"
            + "- `WanCoreDenoiseStep` is used to denoise the latents\n"
            + "- `WanVaeDecoderStep` is used to decode the latents to images"
        )

View File

@@ -1,88 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
from .before_denoise import (
WanPrepareLatentsStep,
WanSetTimestepsStep,
WanTextInputStep,
)
from .decoders import WanVaeDecoderStep
from .denoise import (
Wan22DenoiseStep,
)
from .encoders import (
WanTextEncoderStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# ====================
# 1. DENOISE
# ====================
# inputs(text) -> set_timesteps -> prepare_latents -> denoise
class Wan22CoreDenoiseStep(SequentialPipelineBlocks):
    """Wan2.2 text2video core denoise: input -> set_timesteps -> prepare_latents -> denoise."""
    model_name = "wan"
    # Sub-blocks run sequentially; `block_names` gives each its key, in the same order.
    block_classes = [
        WanTextInputStep,
        WanSetTimestepsStep,
        WanPrepareLatentsStep,
        Wan22DenoiseStep,
    ]
    block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]
    @property
    def description(self):
        # Human-readable summary of this composite block.
        return (
            "denoise block that takes encoded conditions and runs the denoising process.\n"
            + "This is a sequential pipeline blocks:\n"
            + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
            + " - `WanSetTimestepsStep` is used to set the timesteps\n"
            + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
            + " - `Wan22DenoiseStep` is used to denoise the latents in wan2.2\n"
        )
# ====================
# 2. BLOCKS (Wan2.2 text2video)
# ====================
class Wan22Blocks(SequentialPipelineBlocks):
    """End-to-end Wan2.2 text2video pipeline: text encode -> denoise -> VAE decode."""

    model_name = "wan"
    block_classes = [
        WanTextEncoderStep,
        Wan22CoreDenoiseStep,
        WanVaeDecoderStep,
    ]
    block_names = [
        "text_encoder",
        "denoise",
        "decode",
    ]

    @property
    def description(self):
        # Fixed typo in user-facing text: "denoes" -> "denoises".
        return (
            "Modular pipeline for text-to-video using Wan2.2.\n"
            + " - `WanTextEncoderStep` encodes the text\n"
            + " - `Wan22CoreDenoiseStep` denoises the latents\n"
            + " - `WanVaeDecoderStep` decodes the latents to video frames\n"
        )

View File

@@ -1,117 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
from .before_denoise import (
WanAdditionalInputsStep,
WanPrepareLatentsStep,
WanSetTimestepsStep,
WanTextInputStep,
)
from .decoders import WanVaeDecoderStep
from .denoise import (
Wan22Image2VideoDenoiseStep,
)
from .encoders import (
WanImageResizeStep,
WanPrepareFirstFrameLatentsStep,
WanTextEncoderStep,
WanVaeEncoderStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# ====================
# 1. VAE ENCODER
# ====================
class WanImage2VideoVaeEncoderStep(SequentialPipelineBlocks):
    """Resize the input image, VAE-encode it, and prepare the first-frame latent conditions."""
    model_name = "wan-i2v"
    block_classes = [WanImageResizeStep, WanVaeEncoderStep, WanPrepareFirstFrameLatentsStep]
    block_names = ["image_resize", "vae_encoder", "prepare_first_frame_latents"]
    @property
    def description(self):
        return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation"
# ====================
# 2. DENOISE
# ====================
# inputs (text + image_condition_latents) -> set_timesteps -> prepare_latents -> denoise (latents)
class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks):
    """Wan2.2 image2video core denoise: batch-adjust text/image-latent inputs, set timesteps, prepare latents, denoise."""
    model_name = "wan-i2v"
    # Sub-blocks run sequentially; `block_names` gives each its key, in the same order.
    block_classes = [
        WanTextInputStep,
        WanAdditionalInputsStep(image_latent_inputs=["image_condition_latents"]),
        WanSetTimestepsStep,
        WanPrepareLatentsStep,
        Wan22Image2VideoDenoiseStep,
    ]
    block_names = [
        "input",
        "additional_inputs",
        "set_timesteps",
        "prepare_latents",
        "denoise",
    ]
    @property
    def description(self):
        # Human-readable summary of this composite block.
        return (
            "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
            + "This is a sequential pipeline blocks:\n"
            + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
            + " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
            + " - `WanSetTimestepsStep` is used to set the timesteps\n"
            + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
            + " - `Wan22Image2VideoDenoiseStep` is used to denoise the latents in wan2.2\n"
        )
# ====================
# 3. BLOCKS (Wan2.2 Image2Video)
# ====================
class Wan22Image2VideoBlocks(SequentialPipelineBlocks):
    """End-to-end Wan2.2 image2video pipeline: text encode -> image VAE encode -> denoise -> decode."""

    model_name = "wan-i2v"
    block_classes = [
        WanTextEncoderStep,
        WanImage2VideoVaeEncoderStep,
        Wan22Image2VideoCoreDenoiseStep,
        WanVaeDecoderStep,
    ]
    block_names = [
        "text_encoder",
        "vae_encoder",
        "denoise",
        "decode",
    ]

    @property
    def description(self):
        # Fixed typo in user-facing text: "denoes" -> "denoises".
        return (
            "Modular pipeline for image-to-video using Wan2.2.\n"
            + " - `WanTextEncoderStep` encodes the text\n"
            + " - `WanImage2VideoVaeEncoderStep` encodes the image\n"
            + " - `Wan22Image2VideoCoreDenoiseStep` denoises the latents\n"
            + " - `WanVaeDecoderStep` decodes the latents to video frames\n"
        )

View File

@@ -1,203 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from .before_denoise import (
WanAdditionalInputsStep,
WanPrepareLatentsStep,
WanSetTimestepsStep,
WanTextInputStep,
)
from .decoders import WanVaeDecoderStep
from .denoise import (
WanImage2VideoDenoiseStep,
)
from .encoders import (
WanFirstLastFrameImageEncoderStep,
WanFirstLastFrameVaeEncoderStep,
WanImageCropResizeStep,
WanImageEncoderStep,
WanImageResizeStep,
WanPrepareFirstFrameLatentsStep,
WanPrepareFirstLastFrameLatentsStep,
WanTextEncoderStep,
WanVaeEncoderStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# ====================
# 1. IMAGE ENCODER
# ====================
# wan2.1 I2V (first frame only)
class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks):
    """Resize the input image and run the image encoder to produce image embeddings (Wan2.1 I2V, first frame only)."""
    model_name = "wan-i2v"
    block_classes = [WanImageResizeStep, WanImageEncoderStep]
    block_names = ["image_resize", "image_encoder"]
    @property
    def description(self):
        return "Image2Video Image Encoder step that resize the image and encode the image to generate the image embeddings"
# wan2.1 FLF2V (first and last frame)
class WanFLF2VImageEncoderStep(SequentialPipelineBlocks):
    """Resize and image-encode the first and last frames for FLF2V conditioning."""

    model_name = "wan-i2v"
    block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameImageEncoderStep]
    block_names = ["image_resize", "last_image_resize", "image_encoder"]

    @property
    def description(self):
        # Fixed duplicated words in user-facing text ("resize and encode and encode").
        return "FLF2V Image Encoder step that resizes and encodes the first and last frame images to generate the image embeddings"
# wan2.1 Auto Image Encoder
class WanAutoImageEncoderStep(AutoPipelineBlocks):
    """Auto image-encoder that dispatches on which image inputs are present."""

    model_name = "wan-i2v"
    # Checked in order: `last_image` selects FLF2V, otherwise `image` selects I2V;
    # with neither input the step is skipped.
    block_classes = [WanFLF2VImageEncoderStep, WanImage2VideoImageEncoderStep]
    block_names = ["flf2v_image_encoder", "image2video_image_encoder"]
    block_trigger_inputs = ["last_image", "image"]

    @property
    def description(self):
        # Fixed: the literals previously concatenated with no separator, producing
        # run-together text ("...embeddingsThis is...").
        return (
            "Image Encoder step that encode the image to generate the image embeddings\n"
            + "This is an auto pipeline block that works for image2video tasks.\n"
            + " - `WanFLF2VImageEncoderStep` (flf2v) is used when `last_image` is provided.\n"
            + " - `WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided.\n"
            + " - if `last_image` or `image` is not provided, step will be skipped."
        )
# ====================
# 2. VAE ENCODER
# ====================
# wan2.1 I2V (first frame only)
class WanImage2VideoVaeEncoderStep(SequentialPipelineBlocks):
    """Resize the input image, VAE-encode it, and prepare the first-frame latent conditions (Wan2.1 I2V)."""
    model_name = "wan-i2v"
    block_classes = [WanImageResizeStep, WanVaeEncoderStep, WanPrepareFirstFrameLatentsStep]
    block_names = ["image_resize", "vae_encoder", "prepare_first_frame_latents"]
    @property
    def description(self):
        return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation"
# wan2.1 FLF2V (first and last frame)
class WanFLF2VVaeEncoderStep(SequentialPipelineBlocks):
    """Resize and VAE-encode the first and last frames, then prepare the FLF2V latent conditions."""

    model_name = "wan-i2v"
    block_classes = [
        WanImageResizeStep,
        WanImageCropResizeStep,
        WanFirstLastFrameVaeEncoderStep,
        WanPrepareFirstLastFrameLatentsStep,
    ]
    block_names = ["image_resize", "last_image_resize", "vae_encoder", "prepare_first_last_frame_latents"]

    @property
    def description(self):
        # Fixed duplicated words in user-facing text ("resize and encode and encode").
        return "FLF2V Vae Image Encoder step that resizes and encodes the first and last frame images to generate the latent conditions"
# wan2.1 Auto Vae Encoder
class WanAutoVaeEncoderStep(AutoPipelineBlocks):
    """Auto VAE image-encoder that dispatches on which image inputs are present."""

    model_name = "wan-i2v"
    # Checked in order: `last_image` selects FLF2V, otherwise `image` selects I2V;
    # with neither input the step is skipped.
    block_classes = [WanFLF2VVaeEncoderStep, WanImage2VideoVaeEncoderStep]
    block_names = ["flf2v_vae_encoder", "image2video_vae_encoder"]
    block_trigger_inputs = ["last_image", "image"]

    @property
    def description(self):
        # Fixed: the literals previously concatenated with no separator, producing
        # run-together text ("...latentsThis is...").
        return (
            "Vae Image Encoder step that encode the image to generate the image latents\n"
            + "This is an auto pipeline block that works for image2video tasks.\n"
            + " - `WanFLF2VVaeEncoderStep` (flf2v) is used when `last_image` is provided.\n"
            + " - `WanImage2VideoVaeEncoderStep` (image2video) is used when `image` is provided.\n"
            + " - if `last_image` or `image` is not provided, step will be skipped."
        )
# ====================
# 3. DENOISE (inputs -> set_timesteps -> prepare_latents -> denoise)
# ====================
# wan2.1 I2V core denoise (support both I2V and FLF2V)
# inputs (text + image_condition_latents) -> set_timesteps -> prepare_latents -> denoise (latents)
class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks):
    """Wan2.1 image2video core denoise (covers both I2V and FLF2V): batch-adjust inputs, set timesteps, prepare latents, denoise."""
    model_name = "wan-i2v"
    # Sub-blocks run sequentially; `block_names` gives each its key, in the same order.
    block_classes = [
        WanTextInputStep,
        WanAdditionalInputsStep(image_latent_inputs=["image_condition_latents"]),
        WanSetTimestepsStep,
        WanPrepareLatentsStep,
        WanImage2VideoDenoiseStep,
    ]
    block_names = [
        "input",
        "additional_inputs",
        "set_timesteps",
        "prepare_latents",
        "denoise",
    ]
    @property
    def description(self):
        # Human-readable summary of this composite block.
        return (
            "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
            + "This is a sequential pipeline blocks:\n"
            + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
            + " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
            + " - `WanSetTimestepsStep` is used to set the timesteps\n"
            + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
            + " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n"
        )
# ====================
# 4. BLOCKS (Wan2.1 Image2Video)
# ====================
# wan2.1 Image2Video Auto Blocks
class WanImage2VideoAutoBlocks(SequentialPipelineBlocks):
    """Auto Wan2.1 image-to-video pipeline covering both I2V and FLF2V workflows."""

    model_name = "wan-i2v"
    block_classes = [
        WanTextEncoderStep,
        WanAutoImageEncoderStep,
        WanAutoVaeEncoderStep,
        WanImage2VideoCoreDenoiseStep,
        WanVaeDecoderStep,
    ]
    block_names = [
        "text_encoder",
        "image_encoder",
        "vae_encoder",
        "denoise",
        "decode",
    ]

    @property
    def description(self):
        # Fixed: newlines added so the bullets no longer run together
        # ("...`image`- for FLF2V...").
        return (
            "Auto Modular pipeline for image-to-video using Wan.\n"
            + "- for I2V workflow, all you need to provide is `image`\n"
            + "- for FLF2V workflow, all you need to provide is `last_image` and `image`"
        )

View File

@@ -13,6 +13,8 @@
# limitations under the License.
from typing import Any, Dict, Optional
from ...loaders import WanLoraLoaderMixin
from ...pipelines.pipeline_utils import StableDiffusionMixin
from ...utils import logging
@@ -28,12 +30,19 @@ class WanModularPipeline(
WanLoraLoaderMixin,
):
"""
A ModularPipeline for Wan2.1 text2video.
A ModularPipeline for Wan.
> [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "WanBlocks"
default_blocks_name = "WanAutoBlocks"
# override the default_blocks_name in base class, which is just return self.default_blocks_name
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
if config_dict is not None and "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None:
return "Wan22AutoBlocks"
else:
return "WanAutoBlocks"
@property
def default_height(self):
@@ -109,33 +118,3 @@ class WanModularPipeline(
if hasattr(self, "scheduler") and self.scheduler is not None:
num_train_timesteps = self.scheduler.config.num_train_timesteps
return num_train_timesteps
class WanImage2VideoModularPipeline(WanModularPipeline):
"""
A ModularPipeline for Wan2.1 image2video (both I2V and FLF2V).
> [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "WanImage2VideoAutoBlocks"
class Wan22ModularPipeline(WanModularPipeline):
"""
A ModularPipeline for Wan2.2 text2video.
> [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "Wan22Blocks"
class Wan22Image2VideoModularPipeline(Wan22ModularPipeline):
"""
A ModularPipeline for Wan2.2 image2video.
> [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "Wan22Image2VideoBlocks"

View File

@@ -246,7 +246,7 @@ AUTO_TEXT2VIDEO_PIPELINES_MAPPING = OrderedDict(
AUTO_IMAGE2VIDEO_PIPELINES_MAPPING = OrderedDict(
[
("wan-i2v", WanImageToVideoPipeline),
("wan", WanImageToVideoPipeline),
]
)

View File

@@ -47,21 +47,6 @@ class Flux2KleinBaseAutoBlocks(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class Flux2KleinBaseModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Flux2KleinModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -302,7 +287,7 @@ class StableDiffusionXLModularPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class Wan22Blocks(metaclass=DummyObject):
class Wan22AutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -317,82 +302,7 @@ class Wan22Blocks(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class Wan22Image2VideoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Wan22Image2VideoModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Wan22ModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class WanBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class WanImage2VideoAutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class WanImage2VideoModularPipeline(metaclass=DummyObject):
class WanAutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):