mirror of
https://github.com/huggingface/diffusers.git
synced 2026-02-03 17:35:10 +08:00
Compare commits
30 Commits
modular-wa
...
modular-wo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b73cc50e48 | ||
|
|
20c35da75c | ||
|
|
6a549f5f55 | ||
|
|
412e51c856 | ||
|
|
23d06423ab | ||
|
|
aba551c868 | ||
|
|
1f9576a2ca | ||
|
|
d75fbc43c7 | ||
|
|
b7127ce7a7 | ||
|
|
7e9d2b954e | ||
|
|
94525200fd | ||
|
|
f056af1fbb | ||
|
|
8d45ff5bf6 | ||
|
|
fb15752d55 | ||
|
|
1f2dbc9dd2 | ||
|
|
002c3e8239 | ||
|
|
de03d7f100 | ||
|
|
25c968a38f | ||
|
|
aea0d046f6 | ||
|
|
1c90ce33f2 | ||
|
|
507953f415 | ||
|
|
f0555af1c6 | ||
|
|
2a81f2ec54 | ||
|
|
d20f413f78 | ||
|
|
ff09bf1a63 | ||
|
|
34a743e2dc | ||
|
|
43ab14845d | ||
|
|
fbfe5c8d6b | ||
|
|
b29873dee7 | ||
|
|
7b499de6d0 |
@@ -415,7 +415,6 @@ else:
|
||||
"Flux2AutoBlocks",
|
||||
"Flux2KleinAutoBlocks",
|
||||
"Flux2KleinBaseAutoBlocks",
|
||||
"Flux2KleinBaseModularPipeline",
|
||||
"Flux2KleinModularPipeline",
|
||||
"Flux2ModularPipeline",
|
||||
"FluxAutoBlocks",
|
||||
@@ -432,13 +431,8 @@ else:
|
||||
"QwenImageModularPipeline",
|
||||
"StableDiffusionXLAutoBlocks",
|
||||
"StableDiffusionXLModularPipeline",
|
||||
"Wan22Blocks",
|
||||
"Wan22Image2VideoBlocks",
|
||||
"Wan22Image2VideoModularPipeline",
|
||||
"Wan22ModularPipeline",
|
||||
"WanBlocks",
|
||||
"WanImage2VideoAutoBlocks",
|
||||
"WanImage2VideoModularPipeline",
|
||||
"Wan22AutoBlocks",
|
||||
"WanAutoBlocks",
|
||||
"WanModularPipeline",
|
||||
"ZImageAutoBlocks",
|
||||
"ZImageModularPipeline",
|
||||
@@ -1157,7 +1151,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
Flux2AutoBlocks,
|
||||
Flux2KleinAutoBlocks,
|
||||
Flux2KleinBaseAutoBlocks,
|
||||
Flux2KleinBaseModularPipeline,
|
||||
Flux2KleinModularPipeline,
|
||||
Flux2ModularPipeline,
|
||||
FluxAutoBlocks,
|
||||
@@ -1174,13 +1167,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
QwenImageModularPipeline,
|
||||
StableDiffusionXLAutoBlocks,
|
||||
StableDiffusionXLModularPipeline,
|
||||
Wan22Blocks,
|
||||
Wan22Image2VideoBlocks,
|
||||
Wan22Image2VideoModularPipeline,
|
||||
Wan22ModularPipeline,
|
||||
WanBlocks,
|
||||
WanImage2VideoAutoBlocks,
|
||||
WanImage2VideoModularPipeline,
|
||||
Wan22AutoBlocks,
|
||||
WanAutoBlocks,
|
||||
WanModularPipeline,
|
||||
ZImageAutoBlocks,
|
||||
ZImageModularPipeline,
|
||||
|
||||
@@ -45,16 +45,7 @@ else:
|
||||
"InsertableDict",
|
||||
]
|
||||
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
|
||||
_import_structure["wan"] = [
|
||||
"WanBlocks",
|
||||
"Wan22Blocks",
|
||||
"WanImage2VideoAutoBlocks",
|
||||
"Wan22Image2VideoBlocks",
|
||||
"WanModularPipeline",
|
||||
"Wan22ModularPipeline",
|
||||
"WanImage2VideoModularPipeline",
|
||||
"Wan22Image2VideoModularPipeline",
|
||||
]
|
||||
_import_structure["wan"] = ["WanAutoBlocks", "Wan22AutoBlocks", "WanModularPipeline"]
|
||||
_import_structure["flux"] = [
|
||||
"FluxAutoBlocks",
|
||||
"FluxModularPipeline",
|
||||
@@ -67,7 +58,6 @@ else:
|
||||
"Flux2KleinBaseAutoBlocks",
|
||||
"Flux2ModularPipeline",
|
||||
"Flux2KleinModularPipeline",
|
||||
"Flux2KleinBaseModularPipeline",
|
||||
]
|
||||
_import_structure["qwenimage"] = [
|
||||
"QwenImageAutoBlocks",
|
||||
@@ -98,7 +88,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
Flux2AutoBlocks,
|
||||
Flux2KleinAutoBlocks,
|
||||
Flux2KleinBaseAutoBlocks,
|
||||
Flux2KleinBaseModularPipeline,
|
||||
Flux2KleinModularPipeline,
|
||||
Flux2ModularPipeline,
|
||||
)
|
||||
@@ -123,16 +112,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
QwenImageModularPipeline,
|
||||
)
|
||||
from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
|
||||
from .wan import (
|
||||
Wan22Blocks,
|
||||
Wan22Image2VideoBlocks,
|
||||
Wan22Image2VideoModularPipeline,
|
||||
Wan22ModularPipeline,
|
||||
WanBlocks,
|
||||
WanImage2VideoAutoBlocks,
|
||||
WanImage2VideoModularPipeline,
|
||||
WanModularPipeline,
|
||||
)
|
||||
from .wan import Wan22AutoBlocks, WanAutoBlocks, WanModularPipeline
|
||||
from .z_image import ZImageAutoBlocks, ZImageModularPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
@@ -55,11 +55,7 @@ else:
|
||||
"Flux2VaeEncoderSequentialStep",
|
||||
]
|
||||
_import_structure["modular_blocks_flux2_klein"] = ["Flux2KleinAutoBlocks", "Flux2KleinBaseAutoBlocks"]
|
||||
_import_structure["modular_pipeline"] = [
|
||||
"Flux2ModularPipeline",
|
||||
"Flux2KleinModularPipeline",
|
||||
"Flux2KleinBaseModularPipeline",
|
||||
]
|
||||
_import_structure["modular_pipeline"] = ["Flux2ModularPipeline", "Flux2KleinModularPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
@@ -105,7 +101,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
Flux2KleinAutoBlocks,
|
||||
Flux2KleinBaseAutoBlocks,
|
||||
)
|
||||
from .modular_pipeline import Flux2KleinBaseModularPipeline, Flux2KleinModularPipeline, Flux2ModularPipeline
|
||||
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from ...loaders import Flux2LoraLoaderMixin
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import ModularPipeline
|
||||
@@ -57,36 +59,47 @@ class Flux2ModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
|
||||
return num_channels_latents
|
||||
|
||||
|
||||
class Flux2KleinModularPipeline(Flux2ModularPipeline):
|
||||
class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
|
||||
"""
|
||||
A ModularPipeline for Flux2-Klein (distilled model).
|
||||
|
||||
> [!WARNING] > This is an experimental feature and is likely to change in the future.
|
||||
"""
|
||||
|
||||
default_blocks_name = "Flux2KleinAutoBlocks"
|
||||
|
||||
@property
|
||||
def requires_unconditional_embeds(self):
|
||||
if hasattr(self.config, "is_distilled") and self.config.is_distilled:
|
||||
return False
|
||||
|
||||
requires_unconditional_embeds = False
|
||||
if hasattr(self, "guider") and self.guider is not None:
|
||||
requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
|
||||
|
||||
return requires_unconditional_embeds
|
||||
|
||||
|
||||
class Flux2KleinBaseModularPipeline(Flux2ModularPipeline):
|
||||
"""
|
||||
A ModularPipeline for Flux2-Klein (base model).
|
||||
A ModularPipeline for Flux2-Klein.
|
||||
|
||||
> [!WARNING] > This is an experimental feature and is likely to change in the future.
|
||||
"""
|
||||
|
||||
default_blocks_name = "Flux2KleinBaseAutoBlocks"
|
||||
|
||||
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
|
||||
if config_dict is not None and "is_distilled" in config_dict and config_dict["is_distilled"]:
|
||||
return "Flux2KleinAutoBlocks"
|
||||
else:
|
||||
return "Flux2KleinBaseAutoBlocks"
|
||||
|
||||
@property
|
||||
def default_height(self):
|
||||
return self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
@property
|
||||
def default_width(self):
|
||||
return self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
@property
|
||||
def default_sample_size(self):
|
||||
return 128
|
||||
|
||||
@property
|
||||
def vae_scale_factor(self):
|
||||
vae_scale_factor = 8
|
||||
if getattr(self, "vae", None) is not None:
|
||||
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
return vae_scale_factor
|
||||
|
||||
@property
|
||||
def num_channels_latents(self):
|
||||
num_channels_latents = 32
|
||||
if getattr(self, "transformer", None):
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
return num_channels_latents
|
||||
|
||||
@property
|
||||
def requires_unconditional_embeds(self):
|
||||
if hasattr(self.config, "is_distilled") and self.config.is_distilled:
|
||||
|
||||
@@ -39,8 +39,11 @@ from .modular_pipeline_utils import (
|
||||
InputParam,
|
||||
InsertableDict,
|
||||
OutputParam,
|
||||
combine_inputs,
|
||||
combine_outputs,
|
||||
format_components,
|
||||
format_configs,
|
||||
format_workflow,
|
||||
make_doc_string,
|
||||
)
|
||||
|
||||
@@ -52,61 +55,19 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# map regular pipeline to modular pipeline class name
|
||||
|
||||
|
||||
def _create_default_map_fn(pipeline_class_name: str):
|
||||
"""Create a mapping function that always returns the same pipeline class."""
|
||||
|
||||
def _map_fn(config_dict=None):
|
||||
return pipeline_class_name
|
||||
|
||||
return _map_fn
|
||||
|
||||
|
||||
def _flux2_klein_map_fn(config_dict=None):
|
||||
if config_dict is None:
|
||||
return "Flux2KleinModularPipeline"
|
||||
|
||||
if "is_distilled" in config_dict and config_dict["is_distilled"]:
|
||||
return "Flux2KleinModularPipeline"
|
||||
else:
|
||||
return "Flux2KleinBaseModularPipeline"
|
||||
|
||||
|
||||
def _wan_map_fn(config_dict=None):
|
||||
if config_dict is None:
|
||||
return "WanModularPipeline"
|
||||
|
||||
if "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None:
|
||||
return "Wan22ModularPipeline"
|
||||
else:
|
||||
return "WanModularPipeline"
|
||||
|
||||
|
||||
def _wan_i2v_map_fn(config_dict=None):
|
||||
if config_dict is None:
|
||||
return "WanImage2VideoModularPipeline"
|
||||
|
||||
if "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None:
|
||||
return "Wan22Image2VideoModularPipeline"
|
||||
else:
|
||||
return "WanImage2VideoModularPipeline"
|
||||
|
||||
|
||||
MODULAR_PIPELINE_MAPPING = OrderedDict(
|
||||
[
|
||||
("stable-diffusion-xl", _create_default_map_fn("StableDiffusionXLModularPipeline")),
|
||||
("wan", _wan_map_fn),
|
||||
("wan-i2v", _wan_i2v_map_fn),
|
||||
("flux", _create_default_map_fn("FluxModularPipeline")),
|
||||
("flux-kontext", _create_default_map_fn("FluxKontextModularPipeline")),
|
||||
("flux2", _create_default_map_fn("Flux2ModularPipeline")),
|
||||
("flux2-klein", _flux2_klein_map_fn),
|
||||
("qwenimage", _create_default_map_fn("QwenImageModularPipeline")),
|
||||
("qwenimage-edit", _create_default_map_fn("QwenImageEditModularPipeline")),
|
||||
("qwenimage-edit-plus", _create_default_map_fn("QwenImageEditPlusModularPipeline")),
|
||||
("qwenimage-layered", _create_default_map_fn("QwenImageLayeredModularPipeline")),
|
||||
("z-image", _create_default_map_fn("ZImageModularPipeline")),
|
||||
("stable-diffusion-xl", "StableDiffusionXLModularPipeline"),
|
||||
("wan", "WanModularPipeline"),
|
||||
("flux", "FluxModularPipeline"),
|
||||
("flux-kontext", "FluxKontextModularPipeline"),
|
||||
("flux2", "Flux2ModularPipeline"),
|
||||
("flux2-klein", "Flux2KleinModularPipeline"),
|
||||
("qwenimage", "QwenImageModularPipeline"),
|
||||
("qwenimage-edit", "QwenImageEditModularPipeline"),
|
||||
("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),
|
||||
("qwenimage-layered", "QwenImageLayeredModularPipeline"),
|
||||
("z-image", "ZImageModularPipeline"),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -285,6 +246,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
|
||||
config_name = "modular_config.json"
|
||||
model_name = None
|
||||
_workflow_map = None
|
||||
|
||||
@classmethod
|
||||
def _get_signature_keys(cls, obj):
|
||||
@@ -340,6 +302,35 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
def outputs(self) -> List[OutputParam]:
|
||||
return self._get_outputs()
|
||||
|
||||
# currentlyonly ConditionalPipelineBlocks and SequentialPipelineBlocks support `get_execution_blocks`
|
||||
def get_execution_blocks(self, **kwargs):
|
||||
"""
|
||||
Get the block(s) that would execute given the inputs. Must be implemented by subclasses that support
|
||||
conditional block selection.
|
||||
|
||||
Args:
|
||||
**kwargs: Input names and values. Only trigger inputs affect block selection.
|
||||
"""
|
||||
raise NotImplementedError(f"`get_execution_blocks` is not implemented for {self.__class__.__name__}")
|
||||
|
||||
# currently only SequentialPipelineBlocks support workflows
|
||||
@property
|
||||
def workflow_names(self):
|
||||
"""
|
||||
Returns a list of available workflow names. Must be implemented by subclasses that define `_workflow_map`.
|
||||
"""
|
||||
raise NotImplementedError(f"`workflow_names` is not implemented for {self.__class__.__name__}")
|
||||
|
||||
def get_workflow(self, workflow_name: str):
|
||||
"""
|
||||
Get the execution blocks for a specific workflow. Must be implemented by subclasses that define
|
||||
`_workflow_map`.
|
||||
|
||||
Args:
|
||||
workflow_name: Name of the workflow to retrieve.
|
||||
"""
|
||||
raise NotImplementedError(f"`get_workflow` is not implemented for {self.__class__.__name__}")
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
@@ -408,8 +399,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
"""
|
||||
create a ModularPipeline, optionally accept pretrained_model_name_or_path to load from hub.
|
||||
"""
|
||||
map_fn = MODULAR_PIPELINE_MAPPING.get(self.model_name, _create_default_map_fn("ModularPipeline"))
|
||||
pipeline_class_name = map_fn()
|
||||
pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(self.model_name, ModularPipeline.__name__)
|
||||
diffusers_module = importlib.import_module("diffusers")
|
||||
pipeline_class = getattr(diffusers_module, pipeline_class_name)
|
||||
|
||||
@@ -478,72 +468,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
if current_value is not param: # Using identity comparison to check if object was modified
|
||||
state.set(param_name, param, input_param.kwargs_type)
|
||||
|
||||
@staticmethod
|
||||
def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
|
||||
"""
|
||||
Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if
|
||||
current default value is None and new default value is not None. Warns if multiple non-None default values
|
||||
exist for the same input.
|
||||
|
||||
Args:
|
||||
named_input_lists: List of tuples containing (block_name, input_param_list) pairs
|
||||
|
||||
Returns:
|
||||
List[InputParam]: Combined list of unique InputParam objects
|
||||
"""
|
||||
combined_dict = {} # name -> InputParam
|
||||
value_sources = {} # name -> block_name
|
||||
|
||||
for block_name, inputs in named_input_lists:
|
||||
for input_param in inputs:
|
||||
if input_param.name is None and input_param.kwargs_type is not None:
|
||||
input_name = "*_" + input_param.kwargs_type
|
||||
else:
|
||||
input_name = input_param.name
|
||||
if input_name in combined_dict:
|
||||
current_param = combined_dict[input_name]
|
||||
if (
|
||||
current_param.default is not None
|
||||
and input_param.default is not None
|
||||
and current_param.default != input_param.default
|
||||
):
|
||||
warnings.warn(
|
||||
f"Multiple different default values found for input '{input_name}': "
|
||||
f"{current_param.default} (from block '{value_sources[input_name]}') and "
|
||||
f"{input_param.default} (from block '{block_name}'). Using {current_param.default}."
|
||||
)
|
||||
if current_param.default is None and input_param.default is not None:
|
||||
combined_dict[input_name] = input_param
|
||||
value_sources[input_name] = block_name
|
||||
else:
|
||||
combined_dict[input_name] = input_param
|
||||
value_sources[input_name] = block_name
|
||||
|
||||
return list(combined_dict.values())
|
||||
|
||||
@staticmethod
|
||||
def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]:
|
||||
"""
|
||||
Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, keeps the first
|
||||
occurrence of each output name.
|
||||
|
||||
Args:
|
||||
named_output_lists: List of tuples containing (block_name, output_param_list) pairs
|
||||
|
||||
Returns:
|
||||
List[OutputParam]: Combined list of unique OutputParam objects
|
||||
"""
|
||||
combined_dict = {} # name -> OutputParam
|
||||
|
||||
for block_name, outputs in named_output_lists:
|
||||
for output_param in outputs:
|
||||
if (output_param.name not in combined_dict) or (
|
||||
combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None
|
||||
):
|
||||
combined_dict[output_param.name] = output_param
|
||||
|
||||
return list(combined_dict.values())
|
||||
|
||||
@property
|
||||
def input_names(self) -> List[str]:
|
||||
return [input_param.name for input_param in self.inputs if input_param.name is not None]
|
||||
@@ -575,7 +499,8 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
class ConditionalPipelineBlocks(ModularPipelineBlocks):
|
||||
"""
|
||||
A Pipeline Blocks that conditionally selects a block to run based on the inputs. Subclasses must implement the
|
||||
`select_block` method to define the logic for selecting the block.
|
||||
`select_block` method to define the logic for selecting the block. Currently, we only support selection logic based
|
||||
on the presence or absence of inputs (i.e., whether they are `None` or not)
|
||||
|
||||
This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipeline blocks (such as loading or saving etc.)
|
||||
@@ -583,15 +508,20 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
|
||||
> [!WARNING] > This is an experimental feature and is likely to change in the future.
|
||||
|
||||
Attributes:
|
||||
block_classes: List of block classes to be used
|
||||
block_names: List of prefixes for each block
|
||||
block_trigger_inputs: List of input names that select_block() uses to determine which block to run
|
||||
block_classes: List of block classes to be used. Must have the same length as `block_names`.
|
||||
block_names: List of names for each block. Must have the same length as `block_classes`.
|
||||
block_trigger_inputs: List of input names that `select_block()` uses to determine which block to run.
|
||||
For `ConditionalPipelineBlocks`, this does not need to correspond to `block_names` and `block_classes`. For
|
||||
`AutoPipelineBlocks`, this must have the same length as `block_names` and `block_classes`, where each
|
||||
element specifies the trigger input for the corresponding block.
|
||||
default_block_name: Name of the default block to run when no trigger inputs match.
|
||||
If None, this block can be skipped entirely when no trigger inputs are provided.
|
||||
"""
|
||||
|
||||
block_classes = []
|
||||
block_names = []
|
||||
block_trigger_inputs = []
|
||||
default_block_name = None # name of the default block if no trigger inputs are provided, if None, this block can be skipped if no trigger inputs are provided
|
||||
default_block_name = None
|
||||
|
||||
def __init__(self):
|
||||
sub_blocks = InsertableDict()
|
||||
@@ -655,7 +585,7 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()]
|
||||
combined_inputs = self.combine_inputs(*named_inputs)
|
||||
combined_inputs = combine_inputs(*named_inputs)
|
||||
# mark Required inputs only if that input is required by all the blocks
|
||||
for input_param in combined_inputs:
|
||||
if input_param.name in self.required_inputs:
|
||||
@@ -667,15 +597,16 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[str]:
|
||||
named_outputs = [(name, block.intermediate_outputs) for name, block in self.sub_blocks.items()]
|
||||
combined_outputs = self.combine_outputs(*named_outputs)
|
||||
combined_outputs = combine_outputs(*named_outputs)
|
||||
return combined_outputs
|
||||
|
||||
@property
|
||||
def outputs(self) -> List[str]:
|
||||
named_outputs = [(name, block.outputs) for name, block in self.sub_blocks.items()]
|
||||
combined_outputs = self.combine_outputs(*named_outputs)
|
||||
combined_outputs = combine_outputs(*named_outputs)
|
||||
return combined_outputs
|
||||
|
||||
# used for `__repr__`
|
||||
def _get_trigger_inputs(self) -> set:
|
||||
"""
|
||||
Returns a set of all unique trigger input values found in this block and nested blocks.
|
||||
@@ -704,11 +635,6 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
return all_triggers
|
||||
|
||||
@property
|
||||
def trigger_inputs(self):
|
||||
"""All trigger inputs including from nested blocks."""
|
||||
return self._get_trigger_inputs()
|
||||
|
||||
def select_block(self, **kwargs) -> Optional[str]:
|
||||
"""
|
||||
Select the block to run based on the trigger inputs. Subclasses must implement this method to define the logic
|
||||
@@ -748,6 +674,39 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
|
||||
def get_execution_blocks(self, **kwargs) -> Optional["ModularPipelineBlocks"]:
|
||||
"""
|
||||
Get the block(s) that would execute given the inputs.
|
||||
|
||||
Recursively resolves nested ConditionalPipelineBlocks until reaching either:
|
||||
- A leaf block (no sub_blocks) → returns single `ModularPipelineBlocks`
|
||||
- A `SequentialPipelineBlocks` → delegates to its `get_execution_blocks()` which returns
|
||||
a `SequentialPipelineBlocks` containing the resolved execution blocks
|
||||
|
||||
Args:
|
||||
**kwargs: Input names and values. Only trigger inputs affect block selection.
|
||||
|
||||
Returns:
|
||||
- `ModularPipelineBlocks`: A leaf block or resolved `SequentialPipelineBlocks`
|
||||
- `None`: If this block would be skipped (no trigger matched and no default)
|
||||
"""
|
||||
trigger_kwargs = {name: kwargs.get(name) for name in self.block_trigger_inputs if name is not None}
|
||||
block_name = self.select_block(**trigger_kwargs)
|
||||
|
||||
if block_name is None:
|
||||
block_name = self.default_block_name
|
||||
|
||||
if block_name is None:
|
||||
return None
|
||||
|
||||
block = self.sub_blocks[block_name]
|
||||
|
||||
# Recursively resolve until we hit a leaf block or a SequentialPipelineBlocks
|
||||
if block.sub_blocks:
|
||||
return block.get_execution_blocks(**kwargs)
|
||||
|
||||
return block
|
||||
|
||||
def __repr__(self):
|
||||
class_name = self.__class__.__name__
|
||||
base_class = self.__class__.__bases__[0].__name__
|
||||
@@ -755,11 +714,11 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
|
||||
f"{class_name}(\n Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n"
|
||||
)
|
||||
|
||||
if self.trigger_inputs:
|
||||
if self._get_trigger_inputs():
|
||||
header += "\n"
|
||||
header += " " + "=" * 100 + "\n"
|
||||
header += " This pipeline contains blocks that are selected at runtime based on inputs.\n"
|
||||
header += f" Trigger Inputs: {sorted(self.trigger_inputs)}\n"
|
||||
header += f" Trigger Inputs: {sorted(self._get_trigger_inputs())}\n"
|
||||
header += " " + "=" * 100 + "\n\n"
|
||||
|
||||
# Format description with proper indentation
|
||||
@@ -826,24 +785,56 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
class AutoPipelineBlocks(ConditionalPipelineBlocks):
|
||||
"""
|
||||
A Pipeline Blocks that automatically selects a block to run based on the presence of trigger inputs.
|
||||
A Pipeline Blocks that automatically selects a block to run based on the presence of trigger inputs.
|
||||
|
||||
This is a specialized version of `ConditionalPipelineBlocks` where:
|
||||
- Each block has one corresponding trigger input (1:1 mapping)
|
||||
- Block selection is automatic: the first block whose trigger input is present gets selected
|
||||
- `block_trigger_inputs` must have the same length as `block_names` and `block_classes`
|
||||
- Use `None` in `block_trigger_inputs` to specify the default block, i.e the block that will run if no trigger
|
||||
inputs are present
|
||||
|
||||
Attributes:
|
||||
block_classes:
|
||||
List of block classes to be used. Must have the same length as `block_names` and
|
||||
`block_trigger_inputs`.
|
||||
block_names:
|
||||
List of names for each block. Must have the same length as `block_classes` and `block_trigger_inputs`.
|
||||
block_trigger_inputs:
|
||||
List of input names where each element specifies the trigger input for the corresponding block. Use
|
||||
`None` to mark the default block.
|
||||
|
||||
Example:
|
||||
```python
|
||||
class MyAutoBlock(AutoPipelineBlocks):
|
||||
block_classes = [InpaintEncoderBlock, ImageEncoderBlock, TextEncoderBlock]
|
||||
block_names = ["inpaint", "img2img", "text2img"]
|
||||
block_trigger_inputs = ["mask_image", "image", None] # text2img is the default
|
||||
```
|
||||
|
||||
With this definition:
|
||||
- As long as `mask_image` is provided, "inpaint" block runs (regardless of `image` being provided or not)
|
||||
- If `mask_image` is not provided but `image` is provided, "img2img" block runs
|
||||
- Otherwise, "text2img" block runs (default, trigger is `None`)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
if self.default_block_name is not None:
|
||||
raise ValueError(
|
||||
f"In {self.__class__.__name__}, do not set `default_block_name` for AutoPipelineBlocks. "
|
||||
f"Use `None` in `block_trigger_inputs` to specify the default block."
|
||||
)
|
||||
|
||||
if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
|
||||
raise ValueError(
|
||||
f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same."
|
||||
)
|
||||
|
||||
@property
|
||||
def default_block_name(self) -> Optional[str]:
|
||||
"""Derive default_block_name from block_trigger_inputs (None entry)."""
|
||||
if None in self.block_trigger_inputs:
|
||||
idx = self.block_trigger_inputs.index(None)
|
||||
return self.block_names[idx]
|
||||
return None
|
||||
self.default_block_name = self.block_names[idx]
|
||||
|
||||
def select_block(self, **kwargs) -> Optional[str]:
|
||||
"""Select block based on which trigger input is present (not None)."""
|
||||
@@ -897,6 +888,29 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
expected_configs.append(config)
|
||||
return expected_configs
|
||||
|
||||
@property
|
||||
def workflow_names(self):
|
||||
if self._workflow_map is None:
|
||||
raise NotImplementedError(
|
||||
f"workflows is not supported because _workflow_map is not set for {self.__class__.__name__}"
|
||||
)
|
||||
|
||||
return list(self._workflow_map.keys())
|
||||
|
||||
def get_workflow(self, workflow_name: str):
|
||||
if self._workflow_map is None:
|
||||
raise NotImplementedError(
|
||||
f"workflows is not supported because _workflow_map is not set for {self.__class__.__name__}"
|
||||
)
|
||||
|
||||
if workflow_name not in self._workflow_map:
|
||||
raise ValueError(f"Workflow {workflow_name} not found in {self.__class__.__name__}")
|
||||
|
||||
trigger_inputs = self._workflow_map[workflow_name]
|
||||
workflow_blocks = self.get_execution_blocks(**trigger_inputs)
|
||||
|
||||
return workflow_blocks
|
||||
|
||||
@classmethod
|
||||
def from_blocks_dict(
|
||||
cls, blocks_dict: Dict[str, Any], description: Optional[str] = None
|
||||
@@ -992,7 +1006,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
# filter out them here so they do not end up as intermediate_outputs
|
||||
if name not in inp_names:
|
||||
named_outputs.append((name, block.intermediate_outputs))
|
||||
combined_outputs = self.combine_outputs(*named_outputs)
|
||||
combined_outputs = combine_outputs(*named_outputs)
|
||||
return combined_outputs
|
||||
|
||||
# YiYi TODO: I think we can remove the outputs property
|
||||
@@ -1016,6 +1030,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
raise
|
||||
return pipeline, state
|
||||
|
||||
# used for `trigger_inputs` property
|
||||
def _get_trigger_inputs(self):
|
||||
"""
|
||||
Returns a set of all unique trigger input values found in the blocks.
|
||||
@@ -1039,89 +1054,50 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
return fn_recursive_get_trigger(self.sub_blocks)
|
||||
|
||||
@property
|
||||
def trigger_inputs(self):
|
||||
return self._get_trigger_inputs()
|
||||
|
||||
def _traverse_trigger_blocks(self, active_inputs):
|
||||
def get_execution_blocks(self, **kwargs) -> "SequentialPipelineBlocks":
|
||||
"""
|
||||
Traverse blocks and select which ones would run given the active inputs.
|
||||
Get the blocks that would execute given the specified inputs.
|
||||
|
||||
Args:
|
||||
active_inputs: Dict of input names to values that are "present"
|
||||
**kwargs: Input names and values. Only trigger inputs affect block selection.
|
||||
|
||||
Returns:
|
||||
OrderedDict of block_name -> block that would execute
|
||||
SequentialPipelineBlocks containing only the blocks that would execute
|
||||
"""
|
||||
# Copy kwargs so we can add outputs as we traverse
|
||||
active_inputs = dict(kwargs)
|
||||
|
||||
def fn_recursive_traverse(block, block_name, active_inputs):
|
||||
result_blocks = OrderedDict()
|
||||
|
||||
# ConditionalPipelineBlocks (includes AutoPipelineBlocks)
|
||||
if isinstance(block, ConditionalPipelineBlocks):
|
||||
trigger_kwargs = {name: active_inputs.get(name) for name in block.block_trigger_inputs}
|
||||
selected_block_name = block.select_block(**trigger_kwargs)
|
||||
|
||||
if selected_block_name is None:
|
||||
selected_block_name = block.default_block_name
|
||||
|
||||
if selected_block_name is None:
|
||||
block = block.get_execution_blocks(**active_inputs)
|
||||
if block is None:
|
||||
return result_blocks
|
||||
|
||||
selected_block = block.sub_blocks[selected_block_name]
|
||||
|
||||
if selected_block.sub_blocks:
|
||||
result_blocks.update(fn_recursive_traverse(selected_block, block_name, active_inputs))
|
||||
else:
|
||||
result_blocks[block_name] = selected_block
|
||||
if hasattr(selected_block, "outputs"):
|
||||
for out in selected_block.outputs:
|
||||
active_inputs[out.name] = True
|
||||
|
||||
return result_blocks
|
||||
|
||||
# SequentialPipelineBlocks or LoopSequentialPipelineBlocks
|
||||
if block.sub_blocks:
|
||||
# Has sub_blocks (SequentialPipelineBlocks/ConditionalPipelineBlocks)
|
||||
if block.sub_blocks and not isinstance(block, LoopSequentialPipelineBlocks):
|
||||
for sub_block_name, sub_block in block.sub_blocks.items():
|
||||
blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_inputs)
|
||||
blocks_to_update = {f"{block_name}.{k}": v for k, v in blocks_to_update.items()}
|
||||
result_blocks.update(blocks_to_update)
|
||||
nested_blocks = fn_recursive_traverse(sub_block, sub_block_name, active_inputs)
|
||||
nested_blocks = {f"{block_name}.{k}": v for k, v in nested_blocks.items()}
|
||||
result_blocks.update(nested_blocks)
|
||||
else:
|
||||
# Leaf block: single ModularPipelineBlocks or LoopSequentialPipelineBlocks
|
||||
result_blocks[block_name] = block
|
||||
if hasattr(block, "outputs"):
|
||||
for out in block.outputs:
|
||||
# Add outputs to active_inputs so subsequent blocks can use them as triggers
|
||||
if hasattr(block, "intermediate_outputs"):
|
||||
for out in block.intermediate_outputs:
|
||||
active_inputs[out.name] = True
|
||||
|
||||
return result_blocks
|
||||
|
||||
all_blocks = OrderedDict()
|
||||
for block_name, block in self.sub_blocks.items():
|
||||
blocks_to_update = fn_recursive_traverse(block, block_name, active_inputs)
|
||||
all_blocks.update(blocks_to_update)
|
||||
return all_blocks
|
||||
nested_blocks = fn_recursive_traverse(block, block_name, active_inputs)
|
||||
all_blocks.update(nested_blocks)
|
||||
|
||||
def get_execution_blocks(self, **kwargs):
|
||||
"""
|
||||
Get the blocks that would execute given the specified inputs.
|
||||
|
||||
Args:
|
||||
**kwargs: Input names and values. Only trigger inputs affect block selection.
|
||||
Pass any inputs that would be non-None at runtime.
|
||||
|
||||
Returns:
|
||||
SequentialPipelineBlocks containing only the blocks that would execute
|
||||
|
||||
Example:
|
||||
# Get blocks for inpainting workflow blocks = pipeline.get_execution_blocks(prompt="a cat", mask=mask,
|
||||
image=image)
|
||||
|
||||
# Get blocks for text2image workflow blocks = pipeline.get_execution_blocks(prompt="a cat")
|
||||
"""
|
||||
# Filter out None values
|
||||
active_inputs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
blocks_triggered = self._traverse_trigger_blocks(active_inputs)
|
||||
return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered)
|
||||
return SequentialPipelineBlocks.from_blocks_dict(all_blocks)
|
||||
|
||||
def __repr__(self):
|
||||
class_name = self.__class__.__name__
|
||||
@@ -1130,18 +1106,23 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
f"{class_name}(\n Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n"
|
||||
)
|
||||
|
||||
if self.trigger_inputs:
|
||||
if self._workflow_map is None and self._get_trigger_inputs():
|
||||
header += "\n"
|
||||
header += " " + "=" * 100 + "\n"
|
||||
header += " This pipeline contains blocks that are selected at runtime based on inputs.\n"
|
||||
header += f" Trigger Inputs: {[inp for inp in self.trigger_inputs if inp is not None]}\n"
|
||||
header += f" Trigger Inputs: {[inp for inp in self._get_trigger_inputs() if inp is not None]}\n"
|
||||
# Get first trigger input as example
|
||||
example_input = next(t for t in self.trigger_inputs if t is not None)
|
||||
example_input = next(t for t in self._get_trigger_inputs() if t is not None)
|
||||
header += f" Use `get_execution_blocks()` to see selected blocks (e.g. `get_execution_blocks({example_input}=...)`).\n"
|
||||
header += " " + "=" * 100 + "\n\n"
|
||||
|
||||
description = self.description
|
||||
if self._workflow_map is not None:
|
||||
workflow_str = format_workflow(self._workflow_map)
|
||||
description = f"{self.description}\n\n{workflow_str}"
|
||||
|
||||
# Format description with proper indentation
|
||||
desc_lines = self.description.split("\n")
|
||||
desc_lines = description.split("\n")
|
||||
desc = []
|
||||
# First line with "Description:" label
|
||||
desc.append(f" Description: {desc_lines[0]}")
|
||||
@@ -1189,10 +1170,15 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
@property
|
||||
def doc(self):
|
||||
description = self.description
|
||||
if self._workflow_map is not None:
|
||||
workflow_str = format_workflow(self._workflow_map)
|
||||
description = f"{self.description}\n\n{workflow_str}"
|
||||
|
||||
return make_doc_string(
|
||||
self.inputs,
|
||||
self.outputs,
|
||||
self.description,
|
||||
description=description,
|
||||
class_name=self.__class__.__name__,
|
||||
expected_components=self.expected_components,
|
||||
expected_configs=self.expected_configs,
|
||||
@@ -1325,7 +1311,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[str]:
|
||||
named_outputs = [(name, block.intermediate_outputs) for name, block in self.sub_blocks.items()]
|
||||
combined_outputs = self.combine_outputs(*named_outputs)
|
||||
combined_outputs = combine_outputs(*named_outputs)
|
||||
for output in self.loop_intermediate_outputs:
|
||||
if output.name not in {output.name for output in combined_outputs}:
|
||||
combined_outputs.append(output)
|
||||
@@ -1588,7 +1574,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
if modular_config_dict is not None:
|
||||
blocks_class_name = modular_config_dict.get("_blocks_class_name")
|
||||
else:
|
||||
blocks_class_name = self.default_blocks_name
|
||||
blocks_class_name = self.get_default_blocks_name(config_dict)
|
||||
if blocks_class_name is not None:
|
||||
diffusers_module = importlib.import_module("diffusers")
|
||||
blocks_class = getattr(diffusers_module, blocks_class_name)
|
||||
@@ -1660,6 +1646,9 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
params[input_param.name] = input_param.default
|
||||
return params
|
||||
|
||||
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
|
||||
return self.default_blocks_name
|
||||
|
||||
@classmethod
|
||||
def _load_pipeline_config(
|
||||
cls,
|
||||
@@ -1755,8 +1744,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
logger.debug(" try to determine the modular pipeline class from model_index.json")
|
||||
standard_pipeline_class = _get_pipeline_class(cls, config=config_dict)
|
||||
model_name = _get_model(standard_pipeline_class.__name__)
|
||||
map_fn = MODULAR_PIPELINE_MAPPING.get(model_name, _create_default_map_fn("ModularPipeline"))
|
||||
pipeline_class_name = map_fn(config_dict)
|
||||
pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(model_name, ModularPipeline.__name__)
|
||||
diffusers_module = importlib.import_module("diffusers")
|
||||
pipeline_class = getattr(diffusers_module, pipeline_class_name)
|
||||
else:
|
||||
|
||||
@@ -14,9 +14,10 @@
|
||||
|
||||
import inspect
|
||||
import re
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field, fields
|
||||
from typing import Any, Dict, List, Literal, Optional, Type, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
|
||||
|
||||
import PIL.Image
|
||||
import torch
|
||||
@@ -860,6 +861,30 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines
|
||||
return "\n".join(formatted_configs)
|
||||
|
||||
|
||||
def format_workflow(workflow_map):
|
||||
"""Format a workflow map into a readable string representation.
|
||||
|
||||
Args:
|
||||
workflow_map: Dictionary mapping workflow names to trigger inputs
|
||||
|
||||
Returns:
|
||||
A formatted string representing all workflows
|
||||
"""
|
||||
if workflow_map is None:
|
||||
return ""
|
||||
|
||||
lines = ["Supported workflows:"]
|
||||
for workflow_name, trigger_inputs in workflow_map.items():
|
||||
required_inputs = [k for k, v in trigger_inputs.items() if v]
|
||||
if required_inputs:
|
||||
inputs_str = ", ".join(f"`{t}`" for t in required_inputs)
|
||||
lines.append(f" - `{workflow_name}`: requires {inputs_str}")
|
||||
else:
|
||||
lines.append(f" - `{workflow_name}`: default (no additional inputs required)")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def make_doc_string(
|
||||
inputs,
|
||||
outputs,
|
||||
@@ -914,3 +939,69 @@ def make_doc_string(
|
||||
output += format_output_params(outputs, indent_level=2)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
|
||||
"""
|
||||
Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if current
|
||||
default value is None and new default value is not None. Warns if multiple non-None default values exist for the
|
||||
same input.
|
||||
|
||||
Args:
|
||||
named_input_lists: List of tuples containing (block_name, input_param_list) pairs
|
||||
|
||||
Returns:
|
||||
List[InputParam]: Combined list of unique InputParam objects
|
||||
"""
|
||||
combined_dict = {} # name -> InputParam
|
||||
value_sources = {} # name -> block_name
|
||||
|
||||
for block_name, inputs in named_input_lists:
|
||||
for input_param in inputs:
|
||||
if input_param.name is None and input_param.kwargs_type is not None:
|
||||
input_name = "*_" + input_param.kwargs_type
|
||||
else:
|
||||
input_name = input_param.name
|
||||
if input_name in combined_dict:
|
||||
current_param = combined_dict[input_name]
|
||||
if (
|
||||
current_param.default is not None
|
||||
and input_param.default is not None
|
||||
and current_param.default != input_param.default
|
||||
):
|
||||
warnings.warn(
|
||||
f"Multiple different default values found for input '{input_name}': "
|
||||
f"{current_param.default} (from block '{value_sources[input_name]}') and "
|
||||
f"{input_param.default} (from block '{block_name}'). Using {current_param.default}."
|
||||
)
|
||||
if current_param.default is None and input_param.default is not None:
|
||||
combined_dict[input_name] = input_param
|
||||
value_sources[input_name] = block_name
|
||||
else:
|
||||
combined_dict[input_name] = input_param
|
||||
value_sources[input_name] = block_name
|
||||
|
||||
return list(combined_dict.values())
|
||||
|
||||
|
||||
def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]:
|
||||
"""
|
||||
Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, keeps the first
|
||||
occurrence of each output name.
|
||||
|
||||
Args:
|
||||
named_output_lists: List of tuples containing (block_name, output_param_list) pairs
|
||||
|
||||
Returns:
|
||||
List[OutputParam]: Combined list of unique OutputParam objects
|
||||
"""
|
||||
combined_dict = {} # name -> OutputParam
|
||||
|
||||
for block_name, outputs in named_output_lists:
|
||||
for output_param in outputs:
|
||||
if (output_param.name not in combined_dict) or (
|
||||
combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None
|
||||
):
|
||||
combined_dict[output_param.name] = output_param
|
||||
|
||||
return list(combined_dict.values())
|
||||
|
||||
@@ -1113,10 +1113,14 @@ AUTO_BLOCKS = InsertableDict(
|
||||
class QwenImageAutoBlocks(SequentialPipelineBlocks):
|
||||
"""
|
||||
Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.
|
||||
- for image-to-image generation, you need to provide `image`
|
||||
- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`.
|
||||
- to run the controlnet workflow, you need to provide `control_image`
|
||||
- for text-to-image generation, all you need to provide is `prompt`
|
||||
|
||||
Supported workflows:
|
||||
- `text2image`: requires `prompt`
|
||||
- `image2image`: requires `prompt`, `image`
|
||||
- `inpainting`: requires `prompt`, `mask_image`, `image`
|
||||
- `controlnet_text2image`: requires `prompt`, `control_image`
|
||||
- `controlnet_image2image`: requires `prompt`, `image`, `control_image`
|
||||
- `controlnet_inpainting`: requires `prompt`, `mask_image`, `image`, `control_image`
|
||||
|
||||
Components:
|
||||
text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use tokenizer (`Qwen2Tokenizer`):
|
||||
@@ -1197,15 +1201,24 @@ class QwenImageAutoBlocks(SequentialPipelineBlocks):
|
||||
block_classes = AUTO_BLOCKS.values()
|
||||
block_names = AUTO_BLOCKS.keys()
|
||||
|
||||
# Workflow map defines the trigger conditions for each workflow.
|
||||
# How to define:
|
||||
# - Only include required inputs and trigger inputs (inputs that determine which blocks run)
|
||||
# - `True` means the workflow triggers when the input is not None (most common case)
|
||||
# - Use specific values (e.g., `{"strength": 0.5}`) if your `select_block` logic depends on the value
|
||||
|
||||
_workflow_map = {
|
||||
"text2image": {"prompt": True},
|
||||
"image2image": {"prompt": True, "image": True},
|
||||
"inpainting": {"prompt": True, "mask_image": True, "image": True},
|
||||
"controlnet_text2image": {"prompt": True, "control_image": True},
|
||||
"controlnet_image2image": {"prompt": True, "image": True, "control_image": True},
|
||||
"controlnet_inpainting": {"prompt": True, "mask_image": True, "image": True, "control_image": True},
|
||||
}
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.\n"
|
||||
+ "- for image-to-image generation, you need to provide `image`\n"
|
||||
+ "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`.\n"
|
||||
+ "- to run the controlnet workflow, you need to provide `control_image`\n"
|
||||
+ "- for text-to-image generation, all you need to provide is `prompt`"
|
||||
)
|
||||
return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
|
||||
@@ -21,16 +21,16 @@ except OptionalDependencyNotAvailable:
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["modular_blocks_wan"] = ["WanBlocks"]
|
||||
_import_structure["modular_blocks_wan22"] = ["Wan22Blocks"]
|
||||
_import_structure["modular_blocks_wan22_i2v"] = ["Wan22Image2VideoBlocks"]
|
||||
_import_structure["modular_blocks_wan_i2v"] = ["WanImage2VideoAutoBlocks"]
|
||||
_import_structure["modular_pipeline"] = [
|
||||
"Wan22Image2VideoModularPipeline",
|
||||
"Wan22ModularPipeline",
|
||||
"WanImage2VideoModularPipeline",
|
||||
"WanModularPipeline",
|
||||
_import_structure["decoders"] = ["WanImageVaeDecoderStep"]
|
||||
_import_structure["encoders"] = ["WanTextEncoderStep"]
|
||||
_import_structure["modular_blocks"] = [
|
||||
"ALL_BLOCKS",
|
||||
"Wan22AutoBlocks",
|
||||
"WanAutoBlocks",
|
||||
"WanAutoImageEncoderStep",
|
||||
"WanAutoVaeImageEncoderStep",
|
||||
]
|
||||
_import_structure["modular_pipeline"] = ["WanModularPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
@@ -39,16 +39,16 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .modular_blocks_wan import WanBlocks
|
||||
from .modular_blocks_wan22 import Wan22Blocks
|
||||
from .modular_blocks_wan22_i2v import Wan22Image2VideoBlocks
|
||||
from .modular_blocks_wan_i2v import WanImage2VideoAutoBlocks
|
||||
from .modular_pipeline import (
|
||||
Wan22Image2VideoModularPipeline,
|
||||
Wan22ModularPipeline,
|
||||
WanImage2VideoModularPipeline,
|
||||
WanModularPipeline,
|
||||
from .decoders import WanImageVaeDecoderStep
|
||||
from .encoders import WanTextEncoderStep
|
||||
from .modular_blocks import (
|
||||
ALL_BLOCKS,
|
||||
Wan22AutoBlocks,
|
||||
WanAutoBlocks,
|
||||
WanAutoImageEncoderStep,
|
||||
WanAutoVaeImageEncoderStep,
|
||||
)
|
||||
from .modular_pipeline import WanModularPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
@@ -280,7 +280,7 @@ class WanAdditionalInputsStep(ModularPipelineBlocks):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_latent_inputs: List[str] = ["image_condition_latents"],
|
||||
image_latent_inputs: List[str] = ["first_frame_latents"],
|
||||
additional_batch_inputs: List[str] = [],
|
||||
):
|
||||
"""Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
|
||||
@@ -294,16 +294,20 @@ class WanAdditionalInputsStep(ModularPipelineBlocks):
|
||||
Args:
|
||||
image_latent_inputs (List[str], optional): Names of image latent tensors to process.
|
||||
In additional to adjust batch size of these inputs, they will be used to determine height/width. Can be
|
||||
a single string or list of strings. Defaults to ["image_condition_latents"].
|
||||
a single string or list of strings. Defaults to ["first_frame_latents"].
|
||||
additional_batch_inputs (List[str], optional):
|
||||
Names of additional conditional input tensors to expand batch size. These tensors will only have their
|
||||
batch dimensions adjusted to match the final batch size. Can be a single string or list of strings.
|
||||
Defaults to [].
|
||||
|
||||
Examples:
|
||||
# Configure to process image_condition_latents (default behavior) WanAdditionalInputsStep() # Configure to
|
||||
process image latents and additional batch inputs WanAdditionalInputsStep(
|
||||
image_latent_inputs=["image_condition_latents"], additional_batch_inputs=["image_embeds"]
|
||||
# Configure to process first_frame_latents (default behavior) WanAdditionalInputsStep()
|
||||
|
||||
# Configure to process multiple image latent inputs
|
||||
WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents", "last_frame_latents"])
|
||||
|
||||
# Configure to process image latents and additional batch inputs WanAdditionalInputsStep(
|
||||
image_latent_inputs=["first_frame_latents"], additional_batch_inputs=["image_embeds"]
|
||||
)
|
||||
"""
|
||||
if not isinstance(image_latent_inputs, list):
|
||||
@@ -553,3 +557,81 @@ class WanPrepareLatentsStep(ModularPipelineBlocks):
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class WanPrepareFirstFrameLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "step that prepares the masked first frame latents and add it to the latent condition"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("first_frame_latents", type_hint=Optional[torch.Tensor]),
|
||||
InputParam("num_frames", type_hint=int),
|
||||
]
|
||||
|
||||
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
batch_size, _, _, latent_height, latent_width = block_state.first_frame_latents.shape
|
||||
|
||||
mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
|
||||
mask_lat_size[:, :, list(range(1, block_state.num_frames))] = 0
|
||||
|
||||
first_frame_mask = mask_lat_size[:, :, 0:1]
|
||||
first_frame_mask = torch.repeat_interleave(
|
||||
first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
|
||||
)
|
||||
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
|
||||
mask_lat_size = mask_lat_size.view(
|
||||
batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
|
||||
)
|
||||
mask_lat_size = mask_lat_size.transpose(1, 2)
|
||||
mask_lat_size = mask_lat_size.to(block_state.first_frame_latents.device)
|
||||
block_state.first_frame_latents = torch.concat([mask_lat_size, block_state.first_frame_latents], dim=1)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class WanPrepareFirstLastFrameLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "step that prepares the masked latents with first and last frames and add it to the latent condition"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("first_last_frame_latents", type_hint=Optional[torch.Tensor]),
|
||||
InputParam("num_frames", type_hint=int),
|
||||
]
|
||||
|
||||
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
batch_size, _, _, latent_height, latent_width = block_state.first_last_frame_latents.shape
|
||||
|
||||
mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
|
||||
mask_lat_size[:, :, list(range(1, block_state.num_frames - 1))] = 0
|
||||
|
||||
first_frame_mask = mask_lat_size[:, :, 0:1]
|
||||
first_frame_mask = torch.repeat_interleave(
|
||||
first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
|
||||
)
|
||||
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
|
||||
mask_lat_size = mask_lat_size.view(
|
||||
batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
|
||||
)
|
||||
mask_lat_size = mask_lat_size.transpose(1, 2)
|
||||
mask_lat_size = mask_lat_size.to(block_state.first_last_frame_latents.device)
|
||||
block_state.first_last_frame_latents = torch.concat(
|
||||
[mask_lat_size, block_state.first_last_frame_latents], dim=1
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
@@ -29,7 +29,7 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class WanVaeDecoderStep(ModularPipelineBlocks):
|
||||
class WanImageVaeDecoderStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
|
||||
@@ -89,10 +89,52 @@ class WanImage2VideoLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
|
||||
),
|
||||
InputParam(
|
||||
"image_condition_latents",
|
||||
"first_frame_latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The image condition latents to use for the denoising process. Can be generated in prepare_first_frame_latents/prepare_first_last_frame_latents step.",
|
||||
description="The first frame latents to use for the denoising process. Can be generated in prepare_first_frame_latents step.",
|
||||
),
|
||||
InputParam(
|
||||
"dtype",
|
||||
required=True,
|
||||
type_hint=torch.dtype,
|
||||
description="The dtype of the model inputs. Can be generated in input step.",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
block_state.latent_model_input = torch.cat([block_state.latents, block_state.first_frame_latents], dim=1).to(
|
||||
block_state.dtype
|
||||
)
|
||||
return components, block_state
|
||||
|
||||
|
||||
class WanFLF2VLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"step within the denoising loop that prepares the latent input for the denoiser. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `WanDenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
|
||||
),
|
||||
InputParam(
|
||||
"first_last_frame_latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The first and last frame latents to use for the denoising process. Can be generated in prepare_first_last_frame_latents step.",
|
||||
),
|
||||
InputParam(
|
||||
"dtype",
|
||||
@@ -105,7 +147,7 @@ class WanImage2VideoLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
block_state.latent_model_input = torch.cat(
|
||||
[block_state.latents, block_state.image_condition_latents], dim=1
|
||||
[block_state.latents, block_state.first_last_frame_latents], dim=1
|
||||
).to(block_state.dtype)
|
||||
return components, block_state
|
||||
|
||||
@@ -542,3 +584,29 @@ class Wan22Image2VideoDenoiseStep(WanDenoiseLoopWrapper):
|
||||
" - `WanLoopAfterDenoiser`\n"
|
||||
"This block supports image-to-video tasks for Wan2.2."
|
||||
)
|
||||
|
||||
|
||||
class WanFLF2VDenoiseStep(WanDenoiseLoopWrapper):
|
||||
block_classes = [
|
||||
WanFLF2VLoopBeforeDenoiser,
|
||||
WanLoopDenoiser(
|
||||
guider_input_fields={
|
||||
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
|
||||
"encoder_hidden_states_image": "image_embeds",
|
||||
}
|
||||
),
|
||||
WanLoopAfterDenoiser,
|
||||
]
|
||||
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. \n"
|
||||
"Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
|
||||
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
|
||||
" - `WanFLF2VLoopBeforeDenoiser`\n"
|
||||
" - `WanLoopDenoiser`\n"
|
||||
" - `WanLoopAfterDenoiser`\n"
|
||||
"This block supports FLF2V tasks for wan2.1."
|
||||
)
|
||||
|
||||
@@ -468,7 +468,7 @@ class WanFirstLastFrameImageEncoderStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class WanVaeEncoderStep(ModularPipelineBlocks):
|
||||
class WanVaeImageEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
@@ -493,7 +493,7 @@ class WanVaeEncoderStep(ModularPipelineBlocks):
|
||||
InputParam("resized_image", type_hint=PIL.Image.Image, required=True),
|
||||
InputParam("height"),
|
||||
InputParam("width"),
|
||||
InputParam("num_frames", type_hint=int, default=81),
|
||||
InputParam("num_frames"),
|
||||
InputParam("generator"),
|
||||
]
|
||||
|
||||
@@ -564,51 +564,7 @@ class WanVaeEncoderStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class WanPrepareFirstFrameLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "step that prepares the masked first frame latents and add it to the latent condition"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("first_frame_latents", type_hint=Optional[torch.Tensor]),
|
||||
InputParam("num_frames", required=True),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam("image_condition_latents", type_hint=Optional[torch.Tensor]),
|
||||
]
|
||||
|
||||
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
batch_size, _, _, latent_height, latent_width = block_state.first_frame_latents.shape
|
||||
|
||||
mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
|
||||
mask_lat_size[:, :, list(range(1, block_state.num_frames))] = 0
|
||||
|
||||
first_frame_mask = mask_lat_size[:, :, 0:1]
|
||||
first_frame_mask = torch.repeat_interleave(
|
||||
first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
|
||||
)
|
||||
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
|
||||
mask_lat_size = mask_lat_size.view(
|
||||
batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
|
||||
)
|
||||
mask_lat_size = mask_lat_size.transpose(1, 2)
|
||||
mask_lat_size = mask_lat_size.to(block_state.first_frame_latents.device)
|
||||
block_state.image_condition_latents = torch.concat([mask_lat_size, block_state.first_frame_latents], dim=1)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class WanFirstLastFrameVaeEncoderStep(ModularPipelineBlocks):
|
||||
class WanFirstLastFrameVaeImageEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
@@ -634,7 +590,7 @@ class WanFirstLastFrameVaeEncoderStep(ModularPipelineBlocks):
|
||||
InputParam("resized_last_image", type_hint=PIL.Image.Image, required=True),
|
||||
InputParam("height"),
|
||||
InputParam("width"),
|
||||
InputParam("num_frames", type_hint=int, default=81),
|
||||
InputParam("num_frames"),
|
||||
InputParam("generator"),
|
||||
]
|
||||
|
||||
@@ -711,49 +667,3 @@ class WanFirstLastFrameVaeEncoderStep(ModularPipelineBlocks):
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class WanPrepareFirstLastFrameLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "step that prepares the masked latents with first and last frames and add it to the latent condition"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("first_last_frame_latents", type_hint=Optional[torch.Tensor]),
|
||||
InputParam("num_frames", type_hint=int, required=True),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam("image_condition_latents", type_hint=Optional[torch.Tensor]),
|
||||
]
|
||||
|
||||
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
batch_size, _, _, latent_height, latent_width = block_state.first_last_frame_latents.shape
|
||||
|
||||
mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
|
||||
mask_lat_size[:, :, list(range(1, block_state.num_frames - 1))] = 0
|
||||
|
||||
first_frame_mask = mask_lat_size[:, :, 0:1]
|
||||
first_frame_mask = torch.repeat_interleave(
|
||||
first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
|
||||
)
|
||||
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
|
||||
mask_lat_size = mask_lat_size.view(
|
||||
batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
|
||||
)
|
||||
mask_lat_size = mask_lat_size.transpose(1, 2)
|
||||
mask_lat_size = mask_lat_size.to(block_state.first_last_frame_latents.device)
|
||||
block_state.image_condition_latents = torch.concat(
|
||||
[mask_lat_size, block_state.first_last_frame_latents], dim=1
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
474
src/diffusers/modular_pipelines/wan/modular_blocks.py
Normal file
474
src/diffusers/modular_pipelines/wan/modular_blocks.py
Normal file
@@ -0,0 +1,474 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InsertableDict
|
||||
from .before_denoise import (
|
||||
WanAdditionalInputsStep,
|
||||
WanPrepareFirstFrameLatentsStep,
|
||||
WanPrepareFirstLastFrameLatentsStep,
|
||||
WanPrepareLatentsStep,
|
||||
WanSetTimestepsStep,
|
||||
WanTextInputStep,
|
||||
)
|
||||
from .decoders import WanImageVaeDecoderStep
|
||||
from .denoise import (
|
||||
Wan22DenoiseStep,
|
||||
Wan22Image2VideoDenoiseStep,
|
||||
WanDenoiseStep,
|
||||
WanFLF2VDenoiseStep,
|
||||
WanImage2VideoDenoiseStep,
|
||||
)
|
||||
from .encoders import (
|
||||
WanFirstLastFrameImageEncoderStep,
|
||||
WanFirstLastFrameVaeImageEncoderStep,
|
||||
WanImageCropResizeStep,
|
||||
WanImageEncoderStep,
|
||||
WanImageResizeStep,
|
||||
WanTextEncoderStep,
|
||||
WanVaeImageEncoderStep,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# wan2.1
|
||||
# wan2.1: text2vid
|
||||
class WanCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
WanTextInputStep,
|
||||
WanSetTimestepsStep,
|
||||
WanPrepareLatentsStep,
|
||||
WanDenoiseStep,
|
||||
]
|
||||
block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `WanDenoiseStep` is used to denoise the latents\n"
|
||||
)
|
||||
|
||||
|
||||
# wan2.1: image2video
|
||||
## image encoder
|
||||
class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan"
|
||||
block_classes = [WanImageResizeStep, WanImageEncoderStep]
|
||||
block_names = ["image_resize", "image_encoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Image2Video Image Encoder step that resize the image and encode the image to generate the image embeddings"
|
||||
|
||||
|
||||
## vae encoder
|
||||
class WanImage2VideoVaeImageEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan"
|
||||
block_classes = [WanImageResizeStep, WanVaeImageEncoderStep]
|
||||
block_names = ["image_resize", "vae_encoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation"
|
||||
|
||||
|
||||
## denoise
|
||||
class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
WanTextInputStep,
|
||||
WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"]),
|
||||
WanSetTimestepsStep,
|
||||
WanPrepareLatentsStep,
|
||||
WanPrepareFirstFrameLatentsStep,
|
||||
WanImage2VideoDenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"additional_inputs",
|
||||
"set_timesteps",
|
||||
"prepare_latents",
|
||||
"prepare_first_frame_latents",
|
||||
"denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `WanPrepareFirstFrameLatentsStep` is used to prepare the first frame latent conditions\n"
|
||||
+ " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n"
|
||||
)
|
||||
|
||||
|
||||
# wan2.1: FLF2v
|
||||
|
||||
|
||||
## image encoder
|
||||
class WanFLF2VImageEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan"
|
||||
block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameImageEncoderStep]
|
||||
block_names = ["image_resize", "last_image_resize", "image_encoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "FLF2V Image Encoder step that resize and encode and encode the first and last frame images to generate the image embeddings"
|
||||
|
||||
|
||||
## vae encoder
|
||||
class WanFLF2VVaeImageEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan"
|
||||
block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameVaeImageEncoderStep]
|
||||
block_names = ["image_resize", "last_image_resize", "vae_encoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "FLF2V Vae Image Encoder step that resize and encode and encode the first and last frame images to generate the latent conditions"
|
||||
|
||||
|
||||
## denoise
|
||||
class WanFLF2VCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
WanTextInputStep,
|
||||
WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"]),
|
||||
WanSetTimestepsStep,
|
||||
WanPrepareLatentsStep,
|
||||
WanPrepareFirstLastFrameLatentsStep,
|
||||
WanFLF2VDenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"additional_inputs",
|
||||
"set_timesteps",
|
||||
"prepare_latents",
|
||||
"prepare_first_last_frame_latents",
|
||||
"denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `WanPrepareFirstLastFrameLatentsStep` is used to prepare the latent conditions\n"
|
||||
+ " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n"
|
||||
)
|
||||
|
||||
|
||||
# wan2.1: auto blocks
|
||||
## image encoder
|
||||
class WanAutoImageEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [WanFLF2VImageEncoderStep, WanImage2VideoImageEncoderStep]
|
||||
block_names = ["flf2v_image_encoder", "image2video_image_encoder"]
|
||||
block_trigger_inputs = ["last_image", "image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Image Encoder step that encode the image to generate the image embeddings"
|
||||
+ "This is an auto pipeline block that works for image2video tasks."
|
||||
+ " - `WanFLF2VImageEncoderStep` (flf2v) is used when `last_image` is provided."
|
||||
+ " - `WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided."
|
||||
+ " - if `last_image` or `image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
## vae encoder
|
||||
class WanAutoVaeImageEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [WanFLF2VVaeImageEncoderStep, WanImage2VideoVaeImageEncoderStep]
|
||||
block_names = ["flf2v_vae_encoder", "image2video_vae_encoder"]
|
||||
block_trigger_inputs = ["last_image", "image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Vae Image Encoder step that encode the image to generate the image latents"
|
||||
+ "This is an auto pipeline block that works for image2video tasks."
|
||||
+ " - `WanFLF2VVaeImageEncoderStep` (flf2v) is used when `last_image` is provided."
|
||||
+ " - `WanImage2VideoVaeImageEncoderStep` (image2video) is used when `image` is provided."
|
||||
+ " - if `last_image` or `image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
## denoise
|
||||
class WanAutoDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [
|
||||
WanFLF2VCoreDenoiseStep,
|
||||
WanImage2VideoCoreDenoiseStep,
|
||||
WanCoreDenoiseStep,
|
||||
]
|
||||
block_names = ["flf2v", "image2video", "text2video"]
|
||||
block_trigger_inputs = ["first_last_frame_latents", "first_frame_latents", None]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. "
|
||||
"This is a auto pipeline block that works for text2video and image2video tasks."
|
||||
" - `WanCoreDenoiseStep` (text2video) for text2vid tasks."
|
||||
" - `WanCoreImage2VideoCoreDenoiseStep` (image2video) for image2video tasks."
|
||||
+ " - if `first_frame_latents` is provided, `WanCoreImage2VideoDenoiseStep` will be used.\n"
|
||||
+ " - if `first_frame_latents` is not provided, `WanCoreDenoiseStep` will be used.\n"
|
||||
)
|
||||
|
||||
|
||||
# auto pipeline blocks
|
||||
class WanAutoBlocks(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
WanTextEncoderStep,
|
||||
WanAutoImageEncoderStep,
|
||||
WanAutoVaeImageEncoderStep,
|
||||
WanAutoDenoiseStep,
|
||||
WanImageVaeDecoderStep,
|
||||
]
|
||||
block_names = [
|
||||
"text_encoder",
|
||||
"image_encoder",
|
||||
"vae_encoder",
|
||||
"denoise",
|
||||
"decode",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto Modular pipeline for text-to-video using Wan.\n"
|
||||
+ "- for text-to-video generation, all you need to provide is `prompt`"
|
||||
)
|
||||
|
||||
|
||||
# wan22
|
||||
# wan2.2: text2vid
|
||||
|
||||
|
||||
## denoise
|
||||
class Wan22CoreDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
WanTextInputStep,
|
||||
WanSetTimestepsStep,
|
||||
WanPrepareLatentsStep,
|
||||
Wan22DenoiseStep,
|
||||
]
|
||||
block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `Wan22DenoiseStep` is used to denoise the latents in wan2.2\n"
|
||||
)
|
||||
|
||||
|
||||
# wan2.2: image2video
|
||||
## denoise
|
||||
class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
WanTextInputStep,
|
||||
WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"]),
|
||||
WanSetTimestepsStep,
|
||||
WanPrepareLatentsStep,
|
||||
WanPrepareFirstFrameLatentsStep,
|
||||
Wan22Image2VideoDenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"additional_inputs",
|
||||
"set_timesteps",
|
||||
"prepare_latents",
|
||||
"prepare_first_frame_latents",
|
||||
"denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `WanPrepareFirstFrameLatentsStep` is used to prepare the first frame latent conditions\n"
|
||||
+ " - `Wan22Image2VideoDenoiseStep` is used to denoise the latents in wan2.2\n"
|
||||
)
|
||||
|
||||
|
||||
class Wan22AutoDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [
|
||||
Wan22Image2VideoCoreDenoiseStep,
|
||||
Wan22CoreDenoiseStep,
|
||||
]
|
||||
block_names = ["image2video", "text2video"]
|
||||
block_trigger_inputs = ["first_frame_latents", None]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. "
|
||||
"This is a auto pipeline block that works for text2video and image2video tasks."
|
||||
" - `Wan22Image2VideoCoreDenoiseStep` (image2video) for image2video tasks."
|
||||
" - `Wan22CoreDenoiseStep` (text2video) for text2vid tasks."
|
||||
+ " - if `first_frame_latents` is provided, `Wan22Image2VideoCoreDenoiseStep` will be used.\n"
|
||||
+ " - if `first_frame_latents` is not provided, `Wan22CoreDenoiseStep` will be used.\n"
|
||||
)
|
||||
|
||||
|
||||
class Wan22AutoBlocks(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
WanTextEncoderStep,
|
||||
WanAutoVaeImageEncoderStep,
|
||||
Wan22AutoDenoiseStep,
|
||||
WanImageVaeDecoderStep,
|
||||
]
|
||||
block_names = [
|
||||
"text_encoder",
|
||||
"vae_encoder",
|
||||
"denoise",
|
||||
"decode",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto Modular pipeline for text-to-video using Wan2.2.\n"
|
||||
+ "- for text-to-video generation, all you need to provide is `prompt`"
|
||||
)
|
||||
|
||||
|
||||
# presets for wan2.1 and wan2.2
|
||||
# YiYi Notes: should we move these to doc?
|
||||
# wan2.1
|
||||
TEXT2VIDEO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", WanTextEncoderStep),
|
||||
("input", WanTextInputStep),
|
||||
("set_timesteps", WanSetTimestepsStep),
|
||||
("prepare_latents", WanPrepareLatentsStep),
|
||||
("denoise", WanDenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE2VIDEO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("image_resize", WanImageResizeStep),
|
||||
("image_encoder", WanImage2VideoImageEncoderStep),
|
||||
("vae_encoder", WanImage2VideoVaeImageEncoderStep),
|
||||
("input", WanTextInputStep),
|
||||
("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"])),
|
||||
("set_timesteps", WanSetTimestepsStep),
|
||||
("prepare_latents", WanPrepareLatentsStep),
|
||||
("prepare_first_frame_latents", WanPrepareFirstFrameLatentsStep),
|
||||
("denoise", WanImage2VideoDenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
FLF2V_BLOCKS = InsertableDict(
|
||||
[
|
||||
("image_resize", WanImageResizeStep),
|
||||
("last_image_resize", WanImageCropResizeStep),
|
||||
("image_encoder", WanFLF2VImageEncoderStep),
|
||||
("vae_encoder", WanFLF2VVaeImageEncoderStep),
|
||||
("input", WanTextInputStep),
|
||||
("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"])),
|
||||
("set_timesteps", WanSetTimestepsStep),
|
||||
("prepare_latents", WanPrepareLatentsStep),
|
||||
("prepare_first_last_frame_latents", WanPrepareFirstLastFrameLatentsStep),
|
||||
("denoise", WanFLF2VDenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
)
|
||||
|
||||
AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", WanTextEncoderStep),
|
||||
("image_encoder", WanAutoImageEncoderStep),
|
||||
("vae_encoder", WanAutoVaeImageEncoderStep),
|
||||
("denoise", WanAutoDenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
)
|
||||
|
||||
# wan2.2 presets
|
||||
|
||||
TEXT2VIDEO_BLOCKS_WAN22 = InsertableDict(
|
||||
[
|
||||
("text_encoder", WanTextEncoderStep),
|
||||
("input", WanTextInputStep),
|
||||
("set_timesteps", WanSetTimestepsStep),
|
||||
("prepare_latents", WanPrepareLatentsStep),
|
||||
("denoise", Wan22DenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE2VIDEO_BLOCKS_WAN22 = InsertableDict(
|
||||
[
|
||||
("image_resize", WanImageResizeStep),
|
||||
("vae_encoder", WanImage2VideoVaeImageEncoderStep),
|
||||
("input", WanTextInputStep),
|
||||
("set_timesteps", WanSetTimestepsStep),
|
||||
("prepare_latents", WanPrepareLatentsStep),
|
||||
("denoise", Wan22DenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
)
|
||||
|
||||
AUTO_BLOCKS_WAN22 = InsertableDict(
|
||||
[
|
||||
("text_encoder", WanTextEncoderStep),
|
||||
("vae_encoder", WanAutoVaeImageEncoderStep),
|
||||
("denoise", Wan22AutoDenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
)
|
||||
|
||||
# presets all blocks (wan and wan22)
|
||||
|
||||
|
||||
ALL_BLOCKS = {
|
||||
"wan2.1": {
|
||||
"text2video": TEXT2VIDEO_BLOCKS,
|
||||
"image2video": IMAGE2VIDEO_BLOCKS,
|
||||
"flf2v": FLF2V_BLOCKS,
|
||||
"auto": AUTO_BLOCKS,
|
||||
},
|
||||
"wan2.2": {
|
||||
"text2video": TEXT2VIDEO_BLOCKS_WAN22,
|
||||
"image2video": IMAGE2VIDEO_BLOCKS_WAN22,
|
||||
"auto": AUTO_BLOCKS_WAN22,
|
||||
},
|
||||
}
|
||||
@@ -1,83 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import SequentialPipelineBlocks
|
||||
from .before_denoise import (
|
||||
WanPrepareLatentsStep,
|
||||
WanSetTimestepsStep,
|
||||
WanTextInputStep,
|
||||
)
|
||||
from .decoders import WanVaeDecoderStep
|
||||
from .denoise import (
|
||||
WanDenoiseStep,
|
||||
)
|
||||
from .encoders import (
|
||||
WanTextEncoderStep,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# ====================
|
||||
# 1. DENOISE
|
||||
# ====================
|
||||
|
||||
|
||||
# inputs(text) -> set_timesteps -> prepare_latents -> denoise
|
||||
class WanCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "wan"
|
||||
block_classes = [
|
||||
WanTextInputStep,
|
||||
WanSetTimestepsStep,
|
||||
WanPrepareLatentsStep,
|
||||
WanDenoiseStep,
|
||||
]
|
||||
block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `WanDenoiseStep` is used to denoise the latents\n"
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 2. BLOCKS (Wan2.1 text2video)
|
||||
# ====================
|
||||
|
||||
|
||||
class WanBlocks(SequentialPipelineBlocks):
|
||||
model_name = "wan"
|
||||
block_classes = [
|
||||
WanTextEncoderStep,
|
||||
WanCoreDenoiseStep,
|
||||
WanVaeDecoderStep,
|
||||
]
|
||||
block_names = ["text_encoder", "denoise", "decode"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Modular pipeline blocks for Wan2.1.\n"
|
||||
+ "- `WanTextEncoderStep` is used to encode the text\n"
|
||||
+ "- `WanCoreDenoiseStep` is used to denoise the latents\n"
|
||||
+ "- `WanVaeDecoderStep` is used to decode the latents to images"
|
||||
)
|
||||
@@ -1,88 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import SequentialPipelineBlocks
|
||||
from .before_denoise import (
|
||||
WanPrepareLatentsStep,
|
||||
WanSetTimestepsStep,
|
||||
WanTextInputStep,
|
||||
)
|
||||
from .decoders import WanVaeDecoderStep
|
||||
from .denoise import (
|
||||
Wan22DenoiseStep,
|
||||
)
|
||||
from .encoders import (
|
||||
WanTextEncoderStep,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# ====================
|
||||
# 1. DENOISE
|
||||
# ====================
|
||||
|
||||
# inputs(text) -> set_timesteps -> prepare_latents -> denoise
|
||||
|
||||
|
||||
class Wan22CoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "wan"
|
||||
block_classes = [
|
||||
WanTextInputStep,
|
||||
WanSetTimestepsStep,
|
||||
WanPrepareLatentsStep,
|
||||
Wan22DenoiseStep,
|
||||
]
|
||||
block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `Wan22DenoiseStep` is used to denoise the latents in wan2.2\n"
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 2. BLOCKS (Wan2.2 text2video)
|
||||
# ====================
|
||||
|
||||
|
||||
class Wan22Blocks(SequentialPipelineBlocks):
|
||||
model_name = "wan"
|
||||
block_classes = [
|
||||
WanTextEncoderStep,
|
||||
Wan22CoreDenoiseStep,
|
||||
WanVaeDecoderStep,
|
||||
]
|
||||
block_names = [
|
||||
"text_encoder",
|
||||
"denoise",
|
||||
"decode",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Modular pipeline for text-to-video using Wan2.2.\n"
|
||||
+ " - `WanTextEncoderStep` encodes the text\n"
|
||||
+ " - `Wan22CoreDenoiseStep` denoes the latents\n"
|
||||
+ " - `WanVaeDecoderStep` decodes the latents to video frames\n"
|
||||
)
|
||||
@@ -1,117 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import SequentialPipelineBlocks
|
||||
from .before_denoise import (
|
||||
WanAdditionalInputsStep,
|
||||
WanPrepareLatentsStep,
|
||||
WanSetTimestepsStep,
|
||||
WanTextInputStep,
|
||||
)
|
||||
from .decoders import WanVaeDecoderStep
|
||||
from .denoise import (
|
||||
Wan22Image2VideoDenoiseStep,
|
||||
)
|
||||
from .encoders import (
|
||||
WanImageResizeStep,
|
||||
WanPrepareFirstFrameLatentsStep,
|
||||
WanTextEncoderStep,
|
||||
WanVaeEncoderStep,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# ====================
|
||||
# 1. VAE ENCODER
|
||||
# ====================
|
||||
|
||||
|
||||
class WanImage2VideoVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [WanImageResizeStep, WanVaeEncoderStep, WanPrepareFirstFrameLatentsStep]
|
||||
block_names = ["image_resize", "vae_encoder", "prepare_first_frame_latents"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation"
|
||||
|
||||
|
||||
# ====================
|
||||
# 2. DENOISE
|
||||
# ====================
|
||||
|
||||
|
||||
# inputs (text + image_condition_latents) -> set_timesteps -> prepare_latents -> denoise (latents)
|
||||
class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [
|
||||
WanTextInputStep,
|
||||
WanAdditionalInputsStep(image_latent_inputs=["image_condition_latents"]),
|
||||
WanSetTimestepsStep,
|
||||
WanPrepareLatentsStep,
|
||||
Wan22Image2VideoDenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"additional_inputs",
|
||||
"set_timesteps",
|
||||
"prepare_latents",
|
||||
"denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `Wan22Image2VideoDenoiseStep` is used to denoise the latents in wan2.2\n"
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 3. BLOCKS (Wan2.2 Image2Video)
|
||||
# ====================
|
||||
|
||||
|
||||
class Wan22Image2VideoBlocks(SequentialPipelineBlocks):
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [
|
||||
WanTextEncoderStep,
|
||||
WanImage2VideoVaeEncoderStep,
|
||||
Wan22Image2VideoCoreDenoiseStep,
|
||||
WanVaeDecoderStep,
|
||||
]
|
||||
block_names = [
|
||||
"text_encoder",
|
||||
"vae_encoder",
|
||||
"denoise",
|
||||
"decode",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Modular pipeline for image-to-video using Wan2.2.\n"
|
||||
+ " - `WanTextEncoderStep` encodes the text\n"
|
||||
+ " - `WanImage2VideoVaeEncoderStep` encodes the image\n"
|
||||
+ " - `Wan22Image2VideoCoreDenoiseStep` denoes the latents\n"
|
||||
+ " - `WanVaeDecoderStep` decodes the latents to video frames\n"
|
||||
)
|
||||
@@ -1,203 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
from .before_denoise import (
|
||||
WanAdditionalInputsStep,
|
||||
WanPrepareLatentsStep,
|
||||
WanSetTimestepsStep,
|
||||
WanTextInputStep,
|
||||
)
|
||||
from .decoders import WanVaeDecoderStep
|
||||
from .denoise import (
|
||||
WanImage2VideoDenoiseStep,
|
||||
)
|
||||
from .encoders import (
|
||||
WanFirstLastFrameImageEncoderStep,
|
||||
WanFirstLastFrameVaeEncoderStep,
|
||||
WanImageCropResizeStep,
|
||||
WanImageEncoderStep,
|
||||
WanImageResizeStep,
|
||||
WanPrepareFirstFrameLatentsStep,
|
||||
WanPrepareFirstLastFrameLatentsStep,
|
||||
WanTextEncoderStep,
|
||||
WanVaeEncoderStep,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# ====================
|
||||
# 1. IMAGE ENCODER
|
||||
# ====================
|
||||
|
||||
|
||||
# wan2.1 I2V (first frame only)
|
||||
class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [WanImageResizeStep, WanImageEncoderStep]
|
||||
block_names = ["image_resize", "image_encoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Image2Video Image Encoder step that resize the image and encode the image to generate the image embeddings"
|
||||
|
||||
|
||||
# wan2.1 FLF2V (first and last frame)
|
||||
class WanFLF2VImageEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameImageEncoderStep]
|
||||
block_names = ["image_resize", "last_image_resize", "image_encoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "FLF2V Image Encoder step that resize and encode and encode the first and last frame images to generate the image embeddings"
|
||||
|
||||
|
||||
# wan2.1 Auto Image Encoder
|
||||
class WanAutoImageEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [WanFLF2VImageEncoderStep, WanImage2VideoImageEncoderStep]
|
||||
block_names = ["flf2v_image_encoder", "image2video_image_encoder"]
|
||||
block_trigger_inputs = ["last_image", "image"]
|
||||
model_name = "wan-i2v"
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Image Encoder step that encode the image to generate the image embeddings"
|
||||
+ "This is an auto pipeline block that works for image2video tasks."
|
||||
+ " - `WanFLF2VImageEncoderStep` (flf2v) is used when `last_image` is provided."
|
||||
+ " - `WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided."
|
||||
+ " - if `last_image` or `image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 2. VAE ENCODER
|
||||
# ====================
|
||||
|
||||
|
||||
# wan2.1 I2V (first frame only)
|
||||
class WanImage2VideoVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [WanImageResizeStep, WanVaeEncoderStep, WanPrepareFirstFrameLatentsStep]
|
||||
block_names = ["image_resize", "vae_encoder", "prepare_first_frame_latents"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation"
|
||||
|
||||
|
||||
# wan2.1 FLF2V (first and last frame)
|
||||
class WanFLF2VVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [
|
||||
WanImageResizeStep,
|
||||
WanImageCropResizeStep,
|
||||
WanFirstLastFrameVaeEncoderStep,
|
||||
WanPrepareFirstLastFrameLatentsStep,
|
||||
]
|
||||
block_names = ["image_resize", "last_image_resize", "vae_encoder", "prepare_first_last_frame_latents"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "FLF2V Vae Image Encoder step that resize and encode and encode the first and last frame images to generate the latent conditions"
|
||||
|
||||
|
||||
# wan2.1 Auto Vae Encoder
|
||||
class WanAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [WanFLF2VVaeEncoderStep, WanImage2VideoVaeEncoderStep]
|
||||
block_names = ["flf2v_vae_encoder", "image2video_vae_encoder"]
|
||||
block_trigger_inputs = ["last_image", "image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Vae Image Encoder step that encode the image to generate the image latents"
|
||||
+ "This is an auto pipeline block that works for image2video tasks."
|
||||
+ " - `WanFLF2VVaeEncoderStep` (flf2v) is used when `last_image` is provided."
|
||||
+ " - `WanImage2VideoVaeEncoderStep` (image2video) is used when `image` is provided."
|
||||
+ " - if `last_image` or `image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 3. DENOISE (inputs -> set_timesteps -> prepare_latents -> denoise)
|
||||
# ====================
|
||||
|
||||
|
||||
# wan2.1 I2V core denoise (support both I2V and FLF2V)
|
||||
# inputs (text + image_condition_latents) -> set_timesteps -> prepare_latents -> denoise (latents)
|
||||
class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [
|
||||
WanTextInputStep,
|
||||
WanAdditionalInputsStep(image_latent_inputs=["image_condition_latents"]),
|
||||
WanSetTimestepsStep,
|
||||
WanPrepareLatentsStep,
|
||||
WanImage2VideoDenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"additional_inputs",
|
||||
"set_timesteps",
|
||||
"prepare_latents",
|
||||
"denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n"
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 4. BLOCKS (Wan2.1 Image2Video)
|
||||
# ====================
|
||||
|
||||
|
||||
# wan2.1 Image2Video Auto Blocks
|
||||
class WanImage2VideoAutoBlocks(SequentialPipelineBlocks):
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [
|
||||
WanTextEncoderStep,
|
||||
WanAutoImageEncoderStep,
|
||||
WanAutoVaeEncoderStep,
|
||||
WanImage2VideoCoreDenoiseStep,
|
||||
WanVaeDecoderStep,
|
||||
]
|
||||
block_names = [
|
||||
"text_encoder",
|
||||
"image_encoder",
|
||||
"vae_encoder",
|
||||
"denoise",
|
||||
"decode",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto Modular pipeline for image-to-video using Wan.\n"
|
||||
+ "- for I2V workflow, all you need to provide is `image`"
|
||||
+ "- for FLF2V workflow, all you need to provide is `last_image` and `image`"
|
||||
)
|
||||
@@ -13,6 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from ...loaders import WanLoraLoaderMixin
|
||||
from ...pipelines.pipeline_utils import StableDiffusionMixin
|
||||
from ...utils import logging
|
||||
@@ -28,12 +30,19 @@ class WanModularPipeline(
|
||||
WanLoraLoaderMixin,
|
||||
):
|
||||
"""
|
||||
A ModularPipeline for Wan2.1 text2video.
|
||||
A ModularPipeline for Wan.
|
||||
|
||||
> [!WARNING] > This is an experimental feature and is likely to change in the future.
|
||||
"""
|
||||
|
||||
default_blocks_name = "WanBlocks"
|
||||
default_blocks_name = "WanAutoBlocks"
|
||||
|
||||
# override the default_blocks_name in base class, which is just return self.default_blocks_name
|
||||
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
|
||||
if config_dict is not None and "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None:
|
||||
return "Wan22AutoBlocks"
|
||||
else:
|
||||
return "WanAutoBlocks"
|
||||
|
||||
@property
|
||||
def default_height(self):
|
||||
@@ -109,33 +118,3 @@ class WanModularPipeline(
|
||||
if hasattr(self, "scheduler") and self.scheduler is not None:
|
||||
num_train_timesteps = self.scheduler.config.num_train_timesteps
|
||||
return num_train_timesteps
|
||||
|
||||
|
||||
class WanImage2VideoModularPipeline(WanModularPipeline):
|
||||
"""
|
||||
A ModularPipeline for Wan2.1 image2video (both I2V and FLF2V).
|
||||
|
||||
> [!WARNING] > This is an experimental feature and is likely to change in the future.
|
||||
"""
|
||||
|
||||
default_blocks_name = "WanImage2VideoAutoBlocks"
|
||||
|
||||
|
||||
class Wan22ModularPipeline(WanModularPipeline):
|
||||
"""
|
||||
A ModularPipeline for Wan2.2 text2video.
|
||||
|
||||
> [!WARNING] > This is an experimental feature and is likely to change in the future.
|
||||
"""
|
||||
|
||||
default_blocks_name = "Wan22Blocks"
|
||||
|
||||
|
||||
class Wan22Image2VideoModularPipeline(Wan22ModularPipeline):
|
||||
"""
|
||||
A ModularPipeline for Wan2.2 image2video.
|
||||
|
||||
> [!WARNING] > This is an experimental feature and is likely to change in the future.
|
||||
"""
|
||||
|
||||
default_blocks_name = "Wan22Image2VideoBlocks"
|
||||
|
||||
@@ -246,7 +246,7 @@ AUTO_TEXT2VIDEO_PIPELINES_MAPPING = OrderedDict(
|
||||
|
||||
AUTO_IMAGE2VIDEO_PIPELINES_MAPPING = OrderedDict(
|
||||
[
|
||||
("wan-i2v", WanImageToVideoPipeline),
|
||||
("wan", WanImageToVideoPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -47,21 +47,6 @@ class Flux2KleinBaseAutoBlocks(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Flux2KleinBaseModularPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Flux2KleinModularPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -302,7 +287,7 @@ class StableDiffusionXLModularPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Wan22Blocks(metaclass=DummyObject):
|
||||
class Wan22AutoBlocks(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -317,82 +302,7 @@ class Wan22Blocks(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Wan22Image2VideoBlocks(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Wan22Image2VideoModularPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Wan22ModularPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class WanBlocks(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class WanImage2VideoAutoBlocks(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class WanImage2VideoModularPipeline(metaclass=DummyObject):
|
||||
class WanAutoBlocks(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user