Mirror of https://github.com/huggingface/diffusers.git (synced 2026-02-09 20:35:18 +08:00)

Compare commits: `wan-test-r` ... `modular-wo` (30 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | b73cc50e48 |  |
|  | 20c35da75c |  |
|  | 6a549f5f55 |  |
|  | 412e51c856 |  |
|  | 23d06423ab |  |
|  | aba551c868 |  |
|  | 1f9576a2ca |  |
|  | d75fbc43c7 |  |
|  | b7127ce7a7 |  |
|  | 7e9d2b954e |  |
|  | 94525200fd |  |
|  | f056af1fbb |  |
|  | 8d45ff5bf6 |  |
|  | fb15752d55 |  |
|  | 1f2dbc9dd2 |  |
|  | 002c3e8239 |  |
|  | de03d7f100 |  |
|  | 25c968a38f |  |
|  | aea0d046f6 |  |
|  | 1c90ce33f2 |  |
|  | 507953f415 |  |
|  | f0555af1c6 |  |
|  | 2a81f2ec54 |  |
|  | d20f413f78 |  |
|  | ff09bf1a63 |  |
|  | 34a743e2dc |  |
|  | 43ab14845d |  |
|  | fbfe5c8d6b |  |
|  | b29873dee7 |  |
|  | 7b499de6d0 |  |
The first group of hunks changes the module that defines `ModularPipelineBlocks` and its container subclasses (`ConditionalPipelineBlocks`, `AutoPipelineBlocks`, `SequentialPipelineBlocks`, `LoopSequentialPipelineBlocks`). The `combine_inputs`/`combine_outputs` helpers and the new `format_workflow` helper are now imported from `modular_pipeline_utils`:

```diff
@@ -39,8 +39,11 @@ from .modular_pipeline_utils import (
     InputParam,
     InsertableDict,
     OutputParam,
+    combine_inputs,
+    combine_outputs,
     format_components,
     format_configs,
+    format_workflow,
     make_doc_string,
 )

```
```diff
@@ -243,6 +246,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):

     config_name = "modular_config.json"
     model_name = None
+    _workflow_map = None

     @classmethod
     def _get_signature_keys(cls, obj):
```
```diff
@@ -298,6 +302,35 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
     def outputs(self) -> List[OutputParam]:
         return self._get_outputs()

+    # currently only ConditionalPipelineBlocks and SequentialPipelineBlocks support `get_execution_blocks`
+    def get_execution_blocks(self, **kwargs):
+        """
+        Get the block(s) that would execute given the inputs. Must be implemented by subclasses that support
+        conditional block selection.
+
+        Args:
+            **kwargs: Input names and values. Only trigger inputs affect block selection.
+        """
+        raise NotImplementedError(f"`get_execution_blocks` is not implemented for {self.__class__.__name__}")
+
+    # currently only SequentialPipelineBlocks support workflows
+    @property
+    def workflow_names(self):
+        """
+        Returns a list of available workflow names. Must be implemented by subclasses that define `_workflow_map`.
+        """
+        raise NotImplementedError(f"`workflow_names` is not implemented for {self.__class__.__name__}")
+
+    def get_workflow(self, workflow_name: str):
+        """
+        Get the execution blocks for a specific workflow. Must be implemented by subclasses that define
+        `_workflow_map`.
+
+        Args:
+            workflow_name: Name of the workflow to retrieve.
+        """
+        raise NotImplementedError(f"`get_workflow` is not implemented for {self.__class__.__name__}")
+
     @classmethod
     def from_pretrained(
         cls,
```
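To make the intent of these base-class stubs concrete, here is a minimal standalone sketch (plain Python, not the diffusers code) of the pattern this hunk sets up: the base class exposes the hooks and raises `NotImplementedError`, and only the container subclasses that actually support them override them.

```python
class Blocks:
    """Stand-in for ModularPipelineBlocks: exposes the hooks but does not implement them."""

    _workflow_map = None

    def get_execution_blocks(self, **kwargs):
        raise NotImplementedError(f"`get_execution_blocks` is not implemented for {self.__class__.__name__}")

    @property
    def workflow_names(self):
        raise NotImplementedError(f"`workflow_names` is not implemented for {self.__class__.__name__}")


class SequentialBlocks(Blocks):
    """Stand-in for a SequentialPipelineBlocks subclass that opts into workflows."""

    _workflow_map = {"text2image": {"prompt": True}}

    @property
    def workflow_names(self):
        return list(self._workflow_map.keys())


print(SequentialBlocks().workflow_names)  # ['text2image']
try:
    Blocks().workflow_names
except NotImplementedError as exc:
    print(exc)  # `workflow_names` is not implemented for Blocks
```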
```diff
@@ -435,72 +468,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
             if current_value is not param:  # Using identity comparison to check if object was modified
                 state.set(param_name, param, input_param.kwargs_type)

-    @staticmethod
-    def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
-        """
-        Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if
-        current default value is None and new default value is not None. Warns if multiple non-None default values
-        exist for the same input.
-
-        Args:
-            named_input_lists: List of tuples containing (block_name, input_param_list) pairs
-
-        Returns:
-            List[InputParam]: Combined list of unique InputParam objects
-        """
-        combined_dict = {}  # name -> InputParam
-        value_sources = {}  # name -> block_name
-
-        for block_name, inputs in named_input_lists:
-            for input_param in inputs:
-                if input_param.name is None and input_param.kwargs_type is not None:
-                    input_name = "*_" + input_param.kwargs_type
-                else:
-                    input_name = input_param.name
-                if input_name in combined_dict:
-                    current_param = combined_dict[input_name]
-                    if (
-                        current_param.default is not None
-                        and input_param.default is not None
-                        and current_param.default != input_param.default
-                    ):
-                        warnings.warn(
-                            f"Multiple different default values found for input '{input_name}': "
-                            f"{current_param.default} (from block '{value_sources[input_name]}') and "
-                            f"{input_param.default} (from block '{block_name}'). Using {current_param.default}."
-                        )
-                    if current_param.default is None and input_param.default is not None:
-                        combined_dict[input_name] = input_param
-                        value_sources[input_name] = block_name
-                else:
-                    combined_dict[input_name] = input_param
-                    value_sources[input_name] = block_name
-
-        return list(combined_dict.values())
-
-    @staticmethod
-    def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]:
-        """
-        Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, keeps the first
-        occurrence of each output name.
-
-        Args:
-            named_output_lists: List of tuples containing (block_name, output_param_list) pairs
-
-        Returns:
-            List[OutputParam]: Combined list of unique OutputParam objects
-        """
-        combined_dict = {}  # name -> OutputParam
-
-        for block_name, outputs in named_output_lists:
-            for output_param in outputs:
-                if (output_param.name not in combined_dict) or (
-                    combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None
-                ):
-                    combined_dict[output_param.name] = output_param
-
-        return list(combined_dict.values())
-
     @property
     def input_names(self) -> List[str]:
         return [input_param.name for input_param in self.inputs if input_param.name is not None]
```
```diff
@@ -532,7 +499,8 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
 class ConditionalPipelineBlocks(ModularPipelineBlocks):
     """
     A Pipeline Blocks that conditionally selects a block to run based on the inputs. Subclasses must implement the
-    `select_block` method to define the logic for selecting the block.
+    `select_block` method to define the logic for selecting the block. Currently, we only support selection logic based
+    on the presence or absence of inputs (i.e., whether they are `None` or not)

     This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
     library implements for all the pipeline blocks (such as loading or saving etc.)
```
```diff
@@ -540,15 +508,20 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
     > [!WARNING] > This is an experimental feature and is likely to change in the future.

     Attributes:
-        block_classes: List of block classes to be used
-        block_names: List of prefixes for each block
-        block_trigger_inputs: List of input names that select_block() uses to determine which block to run
+        block_classes: List of block classes to be used. Must have the same length as `block_names`.
+        block_names: List of names for each block. Must have the same length as `block_classes`.
+        block_trigger_inputs: List of input names that `select_block()` uses to determine which block to run.
+            For `ConditionalPipelineBlocks`, this does not need to correspond to `block_names` and `block_classes`. For
+            `AutoPipelineBlocks`, this must have the same length as `block_names` and `block_classes`, where each
+            element specifies the trigger input for the corresponding block.
+        default_block_name: Name of the default block to run when no trigger inputs match.
+            If None, this block can be skipped entirely when no trigger inputs are provided.
     """

     block_classes = []
     block_names = []
     block_trigger_inputs = []
-    default_block_name = None  # name of the default block if no trigger inputs are provided, if None, this block can be skipped if no trigger inputs are provided
+    default_block_name = None

     def __init__(self):
         sub_blocks = InsertableDict()
```
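A hedged sketch of how a subclass might wire these attributes together. Only the attribute names and the `select_block` contract come from the hunk above; the import path and the two placeholder encoder blocks are assumptions for illustration.

```python
from diffusers.modular_pipelines import ConditionalPipelineBlocks, ModularPipelineBlocks  # assumed import path


class EncodeWithMask(ModularPipelineBlocks):  # placeholder leaf block, not a real diffusers class
    ...


class EncodeImage(ModularPipelineBlocks):  # placeholder leaf block, not a real diffusers class
    ...


class MyEncoderSwitch(ConditionalPipelineBlocks):
    block_classes = [EncodeWithMask, EncodeImage]
    block_names = ["with_mask", "image_only"]
    block_trigger_inputs = ["mask_image", "image"]  # the inputs select_block gets to look at
    default_block_name = None  # None: the whole block is skipped when nothing matches

    def select_block(self, mask_image=None, image=None):
        # selection may only depend on whether the trigger inputs are None or not
        if mask_image is not None:
            return "with_mask"
        if image is not None:
            return "image_only"
        return None
```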
```diff
@@ -612,7 +585,7 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
     @property
     def inputs(self) -> List[Tuple[str, Any]]:
         named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()]
-        combined_inputs = self.combine_inputs(*named_inputs)
+        combined_inputs = combine_inputs(*named_inputs)
         # mark Required inputs only if that input is required by all the blocks
         for input_param in combined_inputs:
             if input_param.name in self.required_inputs:
```
```diff
@@ -624,15 +597,16 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
     @property
     def intermediate_outputs(self) -> List[str]:
         named_outputs = [(name, block.intermediate_outputs) for name, block in self.sub_blocks.items()]
-        combined_outputs = self.combine_outputs(*named_outputs)
+        combined_outputs = combine_outputs(*named_outputs)
         return combined_outputs

     @property
     def outputs(self) -> List[str]:
         named_outputs = [(name, block.outputs) for name, block in self.sub_blocks.items()]
-        combined_outputs = self.combine_outputs(*named_outputs)
+        combined_outputs = combine_outputs(*named_outputs)
         return combined_outputs

+    # used for `__repr__`
     def _get_trigger_inputs(self) -> set:
         """
         Returns a set of all unique trigger input values found in this block and nested blocks.
```
```diff
@@ -661,11 +635,6 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):

         return all_triggers

-    @property
-    def trigger_inputs(self):
-        """All trigger inputs including from nested blocks."""
-        return self._get_trigger_inputs()
-
     def select_block(self, **kwargs) -> Optional[str]:
         """
         Select the block to run based on the trigger inputs. Subclasses must implement this method to define the logic
```
```diff
@@ -705,6 +674,39 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
             logger.error(error_msg)
             raise

+    def get_execution_blocks(self, **kwargs) -> Optional["ModularPipelineBlocks"]:
+        """
+        Get the block(s) that would execute given the inputs.
+
+        Recursively resolves nested ConditionalPipelineBlocks until reaching either:
+        - A leaf block (no sub_blocks) → returns single `ModularPipelineBlocks`
+        - A `SequentialPipelineBlocks` → delegates to its `get_execution_blocks()` which returns
+          a `SequentialPipelineBlocks` containing the resolved execution blocks
+
+        Args:
+            **kwargs: Input names and values. Only trigger inputs affect block selection.
+
+        Returns:
+            - `ModularPipelineBlocks`: A leaf block or resolved `SequentialPipelineBlocks`
+            - `None`: If this block would be skipped (no trigger matched and no default)
+        """
+        trigger_kwargs = {name: kwargs.get(name) for name in self.block_trigger_inputs if name is not None}
+        block_name = self.select_block(**trigger_kwargs)
+
+        if block_name is None:
+            block_name = self.default_block_name
+
+        if block_name is None:
+            return None
+
+        block = self.sub_blocks[block_name]
+
+        # Recursively resolve until we hit a leaf block or a SequentialPipelineBlocks
+        if block.sub_blocks:
+            return block.get_execution_blocks(**kwargs)
+
+        return block
+
     def __repr__(self):
         class_name = self.__class__.__name__
         base_class = self.__class__.__bases__[0].__name__
```
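The fallback chain documented above (use `select_block`'s answer, else `default_block_name`, else skip the block entirely) can be illustrated with a small standalone toy, independent of diffusers:

```python
def resolve(selected_name, default_block_name):
    """Mirror the fallback chain of ConditionalPipelineBlocks.get_execution_blocks."""
    block_name = selected_name if selected_name is not None else default_block_name
    return block_name  # None means the whole conditional block is skipped


print(resolve("img2img", "text2img"))  # img2img   (select_block matched a trigger)
print(resolve(None, "text2img"))       # text2img  (fall back to the default block)
print(resolve(None, None))             # None      (block is skipped entirely)
```

When the resolved block is itself a container, the real method keeps recursing with the same kwargs until it reaches a leaf block or a `SequentialPipelineBlocks`.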
```diff
@@ -712,11 +714,11 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
             f"{class_name}(\n  Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n"
         )

-        if self.trigger_inputs:
+        if self._get_trigger_inputs():
             header += "\n"
             header += "  " + "=" * 100 + "\n"
             header += "  This pipeline contains blocks that are selected at runtime based on inputs.\n"
-            header += f"  Trigger Inputs: {sorted(self.trigger_inputs)}\n"
+            header += f"  Trigger Inputs: {sorted(self._get_trigger_inputs())}\n"
             header += "  " + "=" * 100 + "\n\n"

         # Format description with proper indentation
```
````diff
@@ -783,24 +785,56 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):

 class AutoPipelineBlocks(ConditionalPipelineBlocks):
     """
     A Pipeline Blocks that automatically selects a block to run based on the presence of trigger inputs.

+    This is a specialized version of `ConditionalPipelineBlocks` where:
+    - Each block has one corresponding trigger input (1:1 mapping)
+    - Block selection is automatic: the first block whose trigger input is present gets selected
+    - `block_trigger_inputs` must have the same length as `block_names` and `block_classes`
+    - Use `None` in `block_trigger_inputs` to specify the default block, i.e. the block that will run if no trigger
+      inputs are present
+
+    Attributes:
+        block_classes:
+            List of block classes to be used. Must have the same length as `block_names` and
+            `block_trigger_inputs`.
+        block_names:
+            List of names for each block. Must have the same length as `block_classes` and `block_trigger_inputs`.
+        block_trigger_inputs:
+            List of input names where each element specifies the trigger input for the corresponding block. Use
+            `None` to mark the default block.
+
+    Example:
+        ```python
+        class MyAutoBlock(AutoPipelineBlocks):
+            block_classes = [InpaintEncoderBlock, ImageEncoderBlock, TextEncoderBlock]
+            block_names = ["inpaint", "img2img", "text2img"]
+            block_trigger_inputs = ["mask_image", "image", None]  # text2img is the default
+        ```
+
+    With this definition:
+    - As long as `mask_image` is provided, the "inpaint" block runs (regardless of `image` being provided or not)
+    - If `mask_image` is not provided but `image` is provided, the "img2img" block runs
+    - Otherwise, the "text2img" block runs (default, trigger is `None`)
     """

     def __init__(self):
         super().__init__()

+        if self.default_block_name is not None:
+            raise ValueError(
+                f"In {self.__class__.__name__}, do not set `default_block_name` for AutoPipelineBlocks. "
+                f"Use `None` in `block_trigger_inputs` to specify the default block."
+            )
+
         if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
             raise ValueError(
                 f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same."
             )

-    @property
-    def default_block_name(self) -> Optional[str]:
-        """Derive default_block_name from block_trigger_inputs (None entry)."""
         if None in self.block_trigger_inputs:
             idx = self.block_trigger_inputs.index(None)
-            return self.block_names[idx]
-        return None
+            self.default_block_name = self.block_names[idx]
+
     def select_block(self, **kwargs) -> Optional[str]:
         """Select block based on which trigger input is present (not None)."""
````
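The docstring's `MyAutoBlock` example can be played out with a standalone toy of the selection rule (the first trigger in the list that is present wins, `None` marks the default). This is plain Python, not the diffusers implementation:

```python
block_names = ["inpaint", "img2img", "text2img"]
block_trigger_inputs = ["mask_image", "image", None]  # None marks the default block


def select_block(**kwargs):
    # AutoPipelineBlocks semantics: walk the triggers in order, first present one wins
    for name, trigger in zip(block_names, block_trigger_inputs):
        if trigger is not None and kwargs.get(trigger) is not None:
            return name
    return None  # caller then falls back to the default block ("text2img")


print(select_block(mask_image="m", image="i"))  # inpaint  (mask_image is listed first)
print(select_block(image="i"))                  # img2img
print(select_block())                           # None -> default block "text2img" runs
```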
```diff
@@ -854,6 +888,29 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
                 expected_configs.append(config)
         return expected_configs

+    @property
+    def workflow_names(self):
+        if self._workflow_map is None:
+            raise NotImplementedError(
+                f"workflows is not supported because _workflow_map is not set for {self.__class__.__name__}"
+            )
+
+        return list(self._workflow_map.keys())
+
+    def get_workflow(self, workflow_name: str):
+        if self._workflow_map is None:
+            raise NotImplementedError(
+                f"workflows is not supported because _workflow_map is not set for {self.__class__.__name__}"
+            )
+
+        if workflow_name not in self._workflow_map:
+            raise ValueError(f"Workflow {workflow_name} not found in {self.__class__.__name__}")
+
+        trigger_inputs = self._workflow_map[workflow_name]
+        workflow_blocks = self.get_execution_blocks(**trigger_inputs)
+
+        return workflow_blocks
+
     @classmethod
     def from_blocks_dict(
         cls, blocks_dict: Dict[str, Any], description: Optional[str] = None
```
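A hedged sketch of opting a `SequentialPipelineBlocks` subclass into this workflow API. The import path and the placeholder step blocks are assumptions; `_workflow_map`, `workflow_names`, and `get_workflow` follow the code added above.

```python
from diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks  # assumed import path


class TextEncoderStep(ModularPipelineBlocks):  # placeholder step, not a real diffusers class
    ...


class DenoiseStep(ModularPipelineBlocks):  # placeholder step, not a real diffusers class
    ...


class MyBlocks(SequentialPipelineBlocks):
    block_classes = [TextEncoderStep, DenoiseStep]
    block_names = ["text_encoder", "denoise"]

    # A value of True means "this input is expected to be not None" for the workflow.
    _workflow_map = {
        "text2image": {"prompt": True},
        "image2image": {"prompt": True, "image": True},
    }


# blocks = MyBlocks()
# blocks.workflow_names               # ['text2image', 'image2image']
# blocks.get_workflow("image2image")  # resolves get_execution_blocks(prompt=True, image=True)
```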
```diff
@@ -949,7 +1006,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
             # filter out them here so they do not end up as intermediate_outputs
             if name not in inp_names:
                 named_outputs.append((name, block.intermediate_outputs))
-        combined_outputs = self.combine_outputs(*named_outputs)
+        combined_outputs = combine_outputs(*named_outputs)
         return combined_outputs

     # YiYi TODO: I think we can remove the outputs property
```
```diff
@@ -973,6 +1030,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
                 raise
         return pipeline, state

+    # used for `trigger_inputs` property
     def _get_trigger_inputs(self):
         """
         Returns a set of all unique trigger input values found in the blocks.
```
```diff
@@ -996,89 +1054,50 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):

         return fn_recursive_get_trigger(self.sub_blocks)

-    @property
-    def trigger_inputs(self):
-        return self._get_trigger_inputs()
-
-    def _traverse_trigger_blocks(self, active_inputs):
+    def get_execution_blocks(self, **kwargs) -> "SequentialPipelineBlocks":
         """
-        Traverse blocks and select which ones would run given the active inputs.
+        Get the blocks that would execute given the specified inputs.

         Args:
-            active_inputs: Dict of input names to values that are "present"
+            **kwargs: Input names and values. Only trigger inputs affect block selection.

         Returns:
-            OrderedDict of block_name -> block that would execute
+            SequentialPipelineBlocks containing only the blocks that would execute
         """
+        # Copy kwargs so we can add outputs as we traverse
+        active_inputs = dict(kwargs)
+
         def fn_recursive_traverse(block, block_name, active_inputs):
             result_blocks = OrderedDict()

             # ConditionalPipelineBlocks (includes AutoPipelineBlocks)
             if isinstance(block, ConditionalPipelineBlocks):
-                trigger_kwargs = {name: active_inputs.get(name) for name in block.block_trigger_inputs}
-                selected_block_name = block.select_block(**trigger_kwargs)
-
-                if selected_block_name is None:
-                    selected_block_name = block.default_block_name
-
-                if selected_block_name is None:
+                block = block.get_execution_blocks(**active_inputs)
+                if block is None:
                     return result_blocks

-                selected_block = block.sub_blocks[selected_block_name]
-
-                if selected_block.sub_blocks:
-                    result_blocks.update(fn_recursive_traverse(selected_block, block_name, active_inputs))
-                else:
-                    result_blocks[block_name] = selected_block
-                    if hasattr(selected_block, "outputs"):
-                        for out in selected_block.outputs:
-                            active_inputs[out.name] = True
-
-                return result_blocks
-
-            # SequentialPipelineBlocks or LoopSequentialPipelineBlocks
-            if block.sub_blocks:
+            # Has sub_blocks (SequentialPipelineBlocks/ConditionalPipelineBlocks)
+            if block.sub_blocks and not isinstance(block, LoopSequentialPipelineBlocks):
                 for sub_block_name, sub_block in block.sub_blocks.items():
-                    blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_inputs)
-                    blocks_to_update = {f"{block_name}.{k}": v for k, v in blocks_to_update.items()}
-                    result_blocks.update(blocks_to_update)
+                    nested_blocks = fn_recursive_traverse(sub_block, sub_block_name, active_inputs)
+                    nested_blocks = {f"{block_name}.{k}": v for k, v in nested_blocks.items()}
+                    result_blocks.update(nested_blocks)
             else:
+                # Leaf block: single ModularPipelineBlocks or LoopSequentialPipelineBlocks
                 result_blocks[block_name] = block
-                if hasattr(block, "outputs"):
-                    for out in block.outputs:
+                # Add outputs to active_inputs so subsequent blocks can use them as triggers
+                if hasattr(block, "intermediate_outputs"):
+                    for out in block.intermediate_outputs:
                         active_inputs[out.name] = True

             return result_blocks

         all_blocks = OrderedDict()
         for block_name, block in self.sub_blocks.items():
-            blocks_to_update = fn_recursive_traverse(block, block_name, active_inputs)
-            all_blocks.update(blocks_to_update)
-        return all_blocks
-
-    def get_execution_blocks(self, **kwargs):
-        """
-        Get the blocks that would execute given the specified inputs.
-
-        Args:
-            **kwargs: Input names and values. Only trigger inputs affect block selection.
-                Pass any inputs that would be non-None at runtime.
-
-        Returns:
-            SequentialPipelineBlocks containing only the blocks that would execute
-
-        Example:
-            # Get blocks for inpainting workflow blocks = pipeline.get_execution_blocks(prompt="a cat", mask=mask,
-            image=image)
-
-            # Get blocks for text2image workflow blocks = pipeline.get_execution_blocks(prompt="a cat")
-        """
-        # Filter out None values
-        active_inputs = {k: v for k, v in kwargs.items() if v is not None}
-
-        blocks_triggered = self._traverse_trigger_blocks(active_inputs)
-        return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered)
+            nested_blocks = fn_recursive_traverse(block, block_name, active_inputs)
+            all_blocks.update(nested_blocks)
+
+        return SequentialPipelineBlocks.from_blocks_dict(all_blocks)

     def __repr__(self):
         class_name = self.__class__.__name__
```
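For reference, the usage pattern from the docstring example removed by this hunk still applies to the new implementation. A hedged sketch, assuming `QwenImageAutoBlocks` (changed later in this PR) is importable from `diffusers.modular_pipelines`; the image/mask values are dummies since only presence (not None) affects selection:

```python
from diffusers.modular_pipelines import QwenImageAutoBlocks  # assumed import path

blocks = QwenImageAutoBlocks()

# Blocks selected for a plain text-to-image call:
t2i_blocks = blocks.get_execution_blocks(prompt="a cat")

# Blocks selected for an inpainting call; any not-None value triggers the branch:
inpaint_blocks = blocks.get_execution_blocks(prompt="a cat", image="dummy", mask_image="dummy")

print(t2i_blocks)      # a SequentialPipelineBlocks containing only the selected sub-blocks
print(inpaint_blocks)
```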
```diff
@@ -1087,18 +1106,23 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
             f"{class_name}(\n  Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n"
         )

-        if self.trigger_inputs:
+        if self._workflow_map is None and self._get_trigger_inputs():
             header += "\n"
             header += "  " + "=" * 100 + "\n"
             header += "  This pipeline contains blocks that are selected at runtime based on inputs.\n"
-            header += f"  Trigger Inputs: {[inp for inp in self.trigger_inputs if inp is not None]}\n"
+            header += f"  Trigger Inputs: {[inp for inp in self._get_trigger_inputs() if inp is not None]}\n"
             # Get first trigger input as example
-            example_input = next(t for t in self.trigger_inputs if t is not None)
+            example_input = next(t for t in self._get_trigger_inputs() if t is not None)
             header += f"  Use `get_execution_blocks()` to see selected blocks (e.g. `get_execution_blocks({example_input}=...)`).\n"
             header += "  " + "=" * 100 + "\n\n"

+        description = self.description
+        if self._workflow_map is not None:
+            workflow_str = format_workflow(self._workflow_map)
+            description = f"{self.description}\n\n{workflow_str}"
+
         # Format description with proper indentation
-        desc_lines = self.description.split("\n")
+        desc_lines = description.split("\n")
         desc = []
         # First line with "Description:" label
         desc.append(f"  Description: {desc_lines[0]}")
```
```diff
@@ -1146,10 +1170,15 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):

     @property
     def doc(self):
+        description = self.description
+        if self._workflow_map is not None:
+            workflow_str = format_workflow(self._workflow_map)
+            description = f"{self.description}\n\n{workflow_str}"
+
         return make_doc_string(
             self.inputs,
             self.outputs,
-            self.description,
+            description=description,
             class_name=self.__class__.__name__,
             expected_components=self.expected_components,
             expected_configs=self.expected_configs,
```
```diff
@@ -1282,7 +1311,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
     @property
     def intermediate_outputs(self) -> List[str]:
         named_outputs = [(name, block.intermediate_outputs) for name, block in self.sub_blocks.items()]
-        combined_outputs = self.combine_outputs(*named_outputs)
+        combined_outputs = combine_outputs(*named_outputs)
         for output in self.loop_intermediate_outputs:
             if output.name not in {output.name for output in combined_outputs}:
                 combined_outputs.append(output)
```
The next group of hunks is from `modular_pipeline_utils.py`, the module the helpers are imported from:
```diff
@@ -14,9 +14,10 @@

 import inspect
 import re
+import warnings
 from collections import OrderedDict
 from dataclasses import dataclass, field, fields
-from typing import Any, Dict, List, Literal, Optional, Type, Union
+from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union

 import PIL.Image
 import torch
```
```diff
@@ -860,6 +861,30 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines
     return "\n".join(formatted_configs)


+def format_workflow(workflow_map):
+    """Format a workflow map into a readable string representation.
+
+    Args:
+        workflow_map: Dictionary mapping workflow names to trigger inputs
+
+    Returns:
+        A formatted string representing all workflows
+    """
+    if workflow_map is None:
+        return ""
+
+    lines = ["Supported workflows:"]
+    for workflow_name, trigger_inputs in workflow_map.items():
+        required_inputs = [k for k, v in trigger_inputs.items() if v]
+        if required_inputs:
+            inputs_str = ", ".join(f"`{t}`" for t in required_inputs)
+            lines.append(f"  - `{workflow_name}`: requires {inputs_str}")
+        else:
+            lines.append(f"  - `{workflow_name}`: default (no additional inputs required)")
+
+    return "\n".join(lines)
+
+
 def make_doc_string(
     inputs,
     outputs,
```
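What the helper produces for a small map, assuming it is imported from `modular_pipeline_utils` (import path assumed); the expected output simply follows the function body added above:

```python
from diffusers.modular_pipelines.modular_pipeline_utils import format_workflow  # assumed import path

workflow_map = {
    "text2image": {"prompt": True},
    "inpainting": {"prompt": True, "mask_image": True, "image": True},
}

print(format_workflow(workflow_map))
# Supported workflows:
#   - `text2image`: requires `prompt`
#   - `inpainting`: requires `prompt`, `mask_image`, `image`
```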
```diff
@@ -914,3 +939,69 @@ def make_doc_string(
     output += format_output_params(outputs, indent_level=2)

     return output
+
+
+def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
+    """
+    Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if current
+    default value is None and new default value is not None. Warns if multiple non-None default values exist for the
+    same input.
+
+    Args:
+        named_input_lists: List of tuples containing (block_name, input_param_list) pairs
+
+    Returns:
+        List[InputParam]: Combined list of unique InputParam objects
+    """
+    combined_dict = {}  # name -> InputParam
+    value_sources = {}  # name -> block_name
+
+    for block_name, inputs in named_input_lists:
+        for input_param in inputs:
+            if input_param.name is None and input_param.kwargs_type is not None:
+                input_name = "*_" + input_param.kwargs_type
+            else:
+                input_name = input_param.name
+            if input_name in combined_dict:
+                current_param = combined_dict[input_name]
+                if (
+                    current_param.default is not None
+                    and input_param.default is not None
+                    and current_param.default != input_param.default
+                ):
+                    warnings.warn(
+                        f"Multiple different default values found for input '{input_name}': "
+                        f"{current_param.default} (from block '{value_sources[input_name]}') and "
+                        f"{input_param.default} (from block '{block_name}'). Using {current_param.default}."
+                    )
+                if current_param.default is None and input_param.default is not None:
+                    combined_dict[input_name] = input_param
+                    value_sources[input_name] = block_name
+            else:
+                combined_dict[input_name] = input_param
+                value_sources[input_name] = block_name
+
+    return list(combined_dict.values())
+
+
+def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]:
+    """
+    Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, keeps the first
+    occurrence of each output name.
+
+    Args:
+        named_output_lists: List of tuples containing (block_name, output_param_list) pairs
+
+    Returns:
+        List[OutputParam]: Combined list of unique OutputParam objects
+    """
+    combined_dict = {}  # name -> OutputParam
+
+    for block_name, outputs in named_output_lists:
+        for output_param in outputs:
+            if (output_param.name not in combined_dict) or (
+                combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None
+            ):
+                combined_dict[output_param.name] = output_param
+
+    return list(combined_dict.values())
```
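A hedged sketch of calling the now module-level `combine_inputs`. The import path is an assumption, and only the `name`/`default` fields of `InputParam` are exercised since its full constructor signature is not shown in this diff:

```python
from diffusers.modular_pipelines.modular_pipeline_utils import InputParam, combine_inputs  # assumed import path

combined = combine_inputs(
    ("text_encoder", [InputParam(name="prompt"), InputParam(name="num_images_per_prompt", default=1)]),
    ("denoise", [InputParam(name="num_images_per_prompt")]),  # duplicate name, no default
)

print([(p.name, p.default) for p in combined])
# [('prompt', None), ('num_images_per_prompt', 1)]  -- duplicates collapse and the non-None default is kept
```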
The final hunks are from the QwenImage modular blocks module, which adopts the new workflow map:
```diff
@@ -1113,10 +1113,14 @@ AUTO_BLOCKS = InsertableDict(
 class QwenImageAutoBlocks(SequentialPipelineBlocks):
     """
     Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.
-    - for image-to-image generation, you need to provide `image`
-    - for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`.
-    - to run the controlnet workflow, you need to provide `control_image`
-    - for text-to-image generation, all you need to provide is `prompt`
+
+    Supported workflows:
+    - `text2image`: requires `prompt`
+    - `image2image`: requires `prompt`, `image`
+    - `inpainting`: requires `prompt`, `mask_image`, `image`
+    - `controlnet_text2image`: requires `prompt`, `control_image`
+    - `controlnet_image2image`: requires `prompt`, `image`, `control_image`
+    - `controlnet_inpainting`: requires `prompt`, `mask_image`, `image`, `control_image`

     Components:
         text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use tokenizer (`Qwen2Tokenizer`):
```
```diff
@@ -1197,15 +1201,24 @@ class QwenImageAutoBlocks(SequentialPipelineBlocks):
     block_classes = AUTO_BLOCKS.values()
     block_names = AUTO_BLOCKS.keys()

+    # Workflow map defines the trigger conditions for each workflow.
+    # How to define:
+    # - Only include required inputs and trigger inputs (inputs that determine which blocks run)
+    # - `True` means the workflow triggers when the input is not None (most common case)
+    # - Use specific values (e.g., `{"strength": 0.5}`) if your `select_block` logic depends on the value
+    _workflow_map = {
+        "text2image": {"prompt": True},
+        "image2image": {"prompt": True, "image": True},
+        "inpainting": {"prompt": True, "mask_image": True, "image": True},
+        "controlnet_text2image": {"prompt": True, "control_image": True},
+        "controlnet_image2image": {"prompt": True, "image": True, "control_image": True},
+        "controlnet_inpainting": {"prompt": True, "mask_image": True, "image": True, "control_image": True},
+    }
+
     @property
     def description(self):
-        return (
-            "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.\n"
-            + "- for image-to-image generation, you need to provide `image`\n"
-            + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`.\n"
-            + "- to run the controlnet workflow, you need to provide `control_image`\n"
-            + "- for text-to-image generation, all you need to provide is `prompt`"
-        )
+        return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage."

     @property
     def outputs(self):
```
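Taken together, a hedged end-to-end sketch of what the new `_workflow_map` enables on `QwenImageAutoBlocks` (import path assumed; the calls follow the methods added earlier in this PR):

```python
from diffusers.modular_pipelines import QwenImageAutoBlocks  # assumed import path

blocks = QwenImageAutoBlocks()

print(blocks.workflow_names)
# ['text2image', 'image2image', 'inpainting', 'controlnet_text2image',
#  'controlnet_image2image', 'controlnet_inpainting']

# Resolve one workflow into the concrete sub-blocks its trigger inputs select:
inpaint_blocks = blocks.get_workflow("inpainting")
print(inpaint_blocks)

# Because _workflow_map is set, __repr__ and doc now include the "Supported workflows:" section:
print(blocks.doc)
```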