Compare commits

...

14 Commits

Author SHA1 Message Date
yiyixuxu
d34c4e8caf update the description of StableDiffusionXLDenoiseLoopWrapper 2025-06-20 07:38:21 +02:00
yiyixuxu
b46b7c8b31 add to method to modular loader, copied from DiffusionPipeline, not tested yet 2025-06-20 07:25:20 +02:00
yiyixuxu
fc9168f429 add block mappings to modular_diffusers.stable_diffusion_xl.__init__ 2025-06-20 07:24:14 +02:00
yiyixuxu
31a31ca1c5 rename modular_pipeline_block_mappings.py to modular_block_mapping 2025-06-20 07:23:14 +02:00
yiyixuxu
8423652b35 updatee modular_pipeline.from_pretrained, modular_repo ->pretrained_model_name_or_path 2025-06-19 05:30:18 +02:00
yiyixuxu
de631947cc up 2025-06-19 04:45:20 +02:00
yiyixuxu
58e9565719 update doc format for kwargs_type 2025-06-19 02:24:51 +02:00
yiyixuxu
cb6d5fed19 refator based on dhruv's feedbacks 2025-06-18 10:11:22 +02:00
yiyixuxu
f16e9c7807 add 2025-06-10 23:10:17 +02:00
yiyixuxu
87f63d424a modular node! 2025-05-22 11:50:36 +02:00
yiyixuxu
29de29f02c add node_utils 2025-05-21 22:31:10 +02:00
yiyixuxu
72e1b74638 solve merge conflict: manually add back the remote code change to modular_pipeline 2025-05-20 20:26:51 +02:00
yiyixuxu
3471f2fb75 merge part1 2025-05-20 18:53:04 +02:00
Dhruv Nair
808dff09cb [WIP] Modular Diffusers support custom code/pipeline blocks (#11539)
* update

* update
2025-05-20 15:12:51 +05:30
9 changed files with 1134 additions and 171 deletions

View File

@@ -264,6 +264,8 @@ else:
_import_structure["modular_pipelines"].extend(
[
"ModularLoader",
"ModularPipeline",
"ModularPipelineBlocks",
"ComponentSpec",
"ComponentsManager",
]
@@ -894,6 +896,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
)
from .modular_pipelines import (
ModularLoader,
ModularPipeline,
ModularPipelineBlocks,
ComponentSpec,
ComponentsManager,
)

View File

@@ -23,7 +23,8 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_pt_objects))
else:
_import_structure["modular_pipeline"] = [
"ModularPipelineMixin",
"ModularPipelineBlocks",
"ModularPipeline",
"PipelineBlock",
"AutoPipelineBlocks",
"SequentialPipelineBlocks",
@@ -53,7 +54,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
BlockState,
LoopSequentialPipelineBlocks,
ModularLoader,
ModularPipelineMixin,
ModularPipelineBlocks,
ModularPipeline,
PipelineBlock,
PipelineState,
SequentialPipelineBlocks,

View File

@@ -11,12 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import traceback
import warnings
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple, Union, Optional, Type
from typing import Any, Dict, List, Tuple, Union, Optional
from typing_extensions import Self
from copy import deepcopy
@@ -31,11 +34,10 @@ from huggingface_hub.utils import validate_hf_hub_args
from ..configuration_utils import ConfigMixin, FrozenDict
from ..utils import (
is_accelerate_available,
is_accelerate_version,
logging,
PushToHubMixin,
)
from ..pipelines.pipeline_loading_utils import _get_pipeline_class, simple_get_class_obj, _fetch_class_library_tuple
from ..pipelines.pipeline_loading_utils import simple_get_class_obj, _fetch_class_library_tuple
from .modular_pipeline_utils import (
ComponentSpec,
ConfigSpec,
@@ -43,14 +45,12 @@ from .modular_pipeline_utils import (
OutputParam,
format_components,
format_configs,
format_input_params,
format_inputs_short,
format_intermediates_short,
format_output_params,
format_params,
make_doc_string,
)
from .components_manager import ComponentsManager
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from copy import deepcopy
if is_accelerate_available():
@@ -244,20 +244,73 @@ class BlockState:
return f"BlockState(\n{attributes}\n)"
class ModularPipelineMixin:
class ModularPipelineBlocks(ConfigMixin):
"""
Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks
"""
config_name = "config.json"
@classmethod
def _get_signature_keys(cls, obj):
parameters = inspect.signature(obj.__init__).parameters
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
expected_modules = set(required_parameters.keys()) - {"self"}
return expected_modules, optional_parameters
def setup_loader(self, modular_repo: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None):
"""
create a mouldar loader, optionally accept modular_repo to load from hub.
"""
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
trust_remote_code: Optional[bool] = None,
**kwargs,
):
hub_kwargs_names = [
"cache_dir",
"force_download",
"local_files_only",
"proxies",
"resume_download",
"revision",
"subfolder",
"token",
]
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
# Import components loader (it is model-specific class)
loader_class_name = MODULAR_LOADER_MAPPING[self.model_name]
config = cls.load_config(pretrained_model_name_or_path)
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_remote_code
)
if not (has_remote_code and trust_remote_code):
raise ValueError("TODO")
class_ref = config["auto_map"][cls.__name__]
module_file, class_name = class_ref.split(".")
module_file = module_file + ".py"
block_cls = get_class_from_dynamic_module(
pretrained_model_name_or_path,
module_file=module_file,
class_name=class_name,
is_modular=True,
**hub_kwargs,
**kwargs,
)
expected_kwargs, optional_kwargs = block_cls._get_signature_keys(block_cls)
block_kwargs = {
name: kwargs.pop(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs
}
return block_cls(**block_kwargs)
def init_pipeline(self, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None):
"""
create a ModularLoader, optionally accept modular_repo to load from hub.
"""
loader_class_name = MODULAR_LOADER_MAPPING.get(self.model_name, ModularLoader.__name__)
diffusers_module = importlib.import_module("diffusers")
loader_class = getattr(diffusers_module, loader_class_name)
@@ -267,105 +320,20 @@ class ModularPipelineMixin:
# Create the loader with the updated specs
specs = component_specs + config_specs
self.loader = loader_class(specs, modular_repo=modular_repo, component_manager=component_manager, collection=collection)
loader = loader_class(specs=specs, pretrained_model_name_or_path=pretrained_model_name_or_path, component_manager=component_manager, collection=collection)
modular_pipeline = ModularPipeline(blocks=self, loader=loader)
return modular_pipeline
@property
def default_call_parameters(self) -> Dict[str, Any]:
params = {}
for input_param in self.inputs:
params[input_param.name] = input_param.default
return params
def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs):
"""
Run one or more blocks in sequence, optionally you can pass a previous pipeline state.
"""
if state is None:
state = PipelineState()
if not hasattr(self, "loader"):
logger.info("Loader is not set, please call `setup_loader()` if you need to load checkpoints for your pipeline.")
self.loader = None
# Make a copy of the input kwargs
passed_kwargs = kwargs.copy()
# Add inputs to state, using defaults if not provided in the kwargs or the state
# if same input already in the state, will override it if provided in the kwargs
intermediates_inputs = [inp.name for inp in self.intermediates_inputs]
for expected_input_param in self.inputs:
name = expected_input_param.name
default = expected_input_param.default
kwargs_type = expected_input_param.kwargs_type
if name in passed_kwargs:
if name not in intermediates_inputs:
state.add_input(name, passed_kwargs.pop(name), kwargs_type)
else:
state.add_input(name, passed_kwargs[name], kwargs_type)
elif name not in state.inputs:
state.add_input(name, default, kwargs_type)
for expected_intermediate_param in self.intermediates_inputs:
name = expected_intermediate_param.name
kwargs_type = expected_intermediate_param.kwargs_type
if name in passed_kwargs:
state.add_intermediate(name, passed_kwargs.pop(name), kwargs_type)
# Warn about unexpected inputs
if len(passed_kwargs) > 0:
warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.")
# Run the pipeline
with torch.no_grad():
try:
pipeline, state = self(self.loader, state)
except Exception:
error_msg = f"Error in block: ({self.__class__.__name__}):\n"
logger.error(error_msg)
raise
if output is None:
return state
elif isinstance(output, str):
return state.get_intermediate(output)
elif isinstance(output, (list, tuple)):
return state.get_intermediates(output)
else:
raise ValueError(f"Output '{output}' is not a valid output type")
@torch.compiler.disable
def progress_bar(self, iterable=None, total=None):
if not hasattr(self, "_progress_bar_config"):
self._progress_bar_config = {}
elif not isinstance(self._progress_bar_config, dict):
raise ValueError(
f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
)
if iterable is not None:
return tqdm(iterable, **self._progress_bar_config)
elif total is not None:
return tqdm(total=total, **self._progress_bar_config)
else:
raise ValueError("Either `total` or `iterable` has to be defined.")
def set_progress_bar_config(self, **kwargs):
self._progress_bar_config = kwargs
class PipelineBlock(ModularPipelineMixin):
class PipelineBlock(ModularPipelineBlocks):
model_name = None
@property
def description(self) -> str:
"""Description of the block. Must be implemented by subclasses."""
raise NotImplementedError("description method must be implemented in subclasses")
# raise NotImplementedError("description method must be implemented in subclasses")
return "TODO: add a description"
@property
def expected_components(self) -> List[ComponentSpec]:
@@ -624,7 +592,7 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) ->
return list(combined_dict.values())
class AutoPipelineBlocks(ModularPipelineMixin):
class AutoPipelineBlocks(ModularPipelineBlocks):
"""
A class that automatically selects a block to run based on the inputs.
@@ -692,6 +660,8 @@ class AutoPipelineBlocks(ModularPipelineMixin):
@property
def required_inputs(self) -> List[str]:
if None not in self.block_trigger_inputs:
return []
first_block = next(iter(self.blocks.values()))
required_by_all = set(getattr(first_block, "required_inputs", set()))
@@ -706,6 +676,8 @@ class AutoPipelineBlocks(ModularPipelineMixin):
# intermediate_inputs is by default required, unless you manually handle it inside the block
@property
def required_intermediates_inputs(self) -> List[str]:
if None not in self.block_trigger_inputs:
return []
first_block = next(iter(self.blocks.values()))
required_by_all = set(getattr(first_block, "required_intermediates_inputs", set()))
@@ -909,7 +881,8 @@ class AutoPipelineBlocks(ModularPipelineMixin):
expected_configs=self.expected_configs
)
class SequentialPipelineBlocks(ModularPipelineMixin):
class SequentialPipelineBlocks(ModularPipelineBlocks):
"""
A class that combines multiple pipeline block classes into one. When called, it will call each block in sequence.
"""
@@ -949,15 +922,24 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
"""Creates a SequentialPipelineBlocks instance from a dictionary of blocks.
Args:
blocks_dict: Dictionary mapping block names to block instances
blocks_dict: Dictionary mapping block names to block classes or instances
Returns:
A new SequentialPipelineBlocks instance
"""
instance = cls()
instance.block_classes = [block.__class__ for block in blocks_dict.values()]
instance.block_names = list(blocks_dict.keys())
instance.blocks = blocks_dict
# Create instances if classes are provided
blocks = {}
for name, block in blocks_dict.items():
if inspect.isclass(block):
blocks[name] = block()
else:
blocks[name] = block
instance.block_classes = [block.__class__ for block in blocks.values()]
instance.block_names = list(blocks.keys())
instance.blocks = blocks
return instance
def __init__(self):
@@ -1270,7 +1252,7 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
)
#YiYi TODO: __repr__
class LoopSequentialPipelineBlocks(ModularPipelineMixin):
class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
"""
A class that combines multiple pipeline block classes into a For Loop. When called, it will call each block in sequence.
"""
@@ -1634,7 +1616,24 @@ class LoopSequentialPipelineBlocks(ModularPipelineMixin):
return result
@torch.compiler.disable
def progress_bar(self, iterable=None, total=None):
if not hasattr(self, "_progress_bar_config"):
self._progress_bar_config = {}
elif not isinstance(self._progress_bar_config, dict):
raise ValueError(
f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
)
if iterable is not None:
return tqdm(iterable, **self._progress_bar_config)
elif total is not None:
return tqdm(total=total, **self._progress_bar_config)
else:
raise ValueError("Either `total` or `iterable` has to be defined.")
def set_progress_bar_config(self, **kwargs):
self._progress_bar_config = kwargs
# YiYi TODO:
@@ -1750,7 +1749,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
# YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name
def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], modular_repo: Optional[str] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs):
def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], pretrained_model_name_or_path: Optional[str] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs):
"""
Initialize the loader with a list of component specs and config specs.
"""
@@ -1764,8 +1763,8 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
}
# update component_specs and config_specs from modular_repo
if modular_repo is not None:
config_dict = self.load_config(modular_repo, **kwargs)
if pretrained_model_name_or_path is not None:
config_dict = self.load_config(pretrained_model_name_or_path, **kwargs)
for name, value in config_dict.items():
# only update component_spec for from_pretrained components
@@ -1829,19 +1828,6 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
return torch.device(module._hf_hook.execution_device)
return self.device
@property
def device(self) -> torch.device:
r"""
Returns:
`torch.device`: The torch device on which the pipeline is located.
"""
modules = [m for m in self.components.values() if isinstance(m, torch.nn.Module)]
for module in modules:
return module.device
return torch.device("cpu")
@property
def dtype(self) -> torch.dtype:
@@ -2012,9 +1998,195 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
# Register all components at once
self.register_components(**components_to_register)
# YiYi TODO: should support to method
def to(self, *args, **kwargs):
pass
# Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to
def to(self, *args, **kwargs) -> Self:
r"""
Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the
arguments of `self.to(*args, **kwargs).`
<Tip>
If the pipeline already has the correct torch.dtype and torch.device, then it is returned as is. Otherwise,
the returned pipeline is a copy of self with the desired torch.dtype and torch.device.
</Tip>
Here are the ways to call `to`:
- `to(dtype, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified
[`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)
- `to(device, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified
[`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device)
- `to(device=None, dtype=None, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the
specified [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) and
[`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)
Arguments:
dtype (`torch.dtype`, *optional*):
Returns a pipeline with the specified
[`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)
device (`torch.Device`, *optional*):
Returns a pipeline with the specified
[`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device)
silence_dtype_warnings (`str`, *optional*, defaults to `False`):
Whether to omit warnings if the target `dtype` is not compatible with the target `device`.
Returns:
[`DiffusionPipeline`]: The pipeline converted to specified `dtype` and/or `dtype`.
"""
dtype = kwargs.pop("dtype", None)
device = kwargs.pop("device", None)
silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False)
dtype_arg = None
device_arg = None
if len(args) == 1:
if isinstance(args[0], torch.dtype):
dtype_arg = args[0]
else:
device_arg = torch.device(args[0]) if args[0] is not None else None
elif len(args) == 2:
if isinstance(args[0], torch.dtype):
raise ValueError(
"When passing two arguments, make sure the first corresponds to `device` and the second to `dtype`."
)
device_arg = torch.device(args[0]) if args[0] is not None else None
dtype_arg = args[1]
elif len(args) > 2:
raise ValueError("Please make sure to pass at most two arguments (`device` and `dtype`) `.to(...)`")
if dtype is not None and dtype_arg is not None:
raise ValueError(
"You have passed `dtype` both as an argument and as a keyword argument. Please only pass one of the two."
)
dtype = dtype or dtype_arg
if device is not None and device_arg is not None:
raise ValueError(
"You have passed `device` both as an argument and as a keyword argument. Please only pass one of the two."
)
device = device or device_arg
device_type = torch.device(device).type if device is not None else None
pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())
# throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
def module_is_sequentially_offloaded(module):
if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
return False
_, _, is_loaded_in_8bit_bnb = _check_bnb_status(module)
if is_loaded_in_8bit_bnb:
return False
return hasattr(module, "_hf_hook") and (
isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
or hasattr(module._hf_hook, "hooks")
and isinstance(module._hf_hook.hooks[0], accelerate.hooks.AlignDevicesHook)
)
def module_is_offloaded(module):
if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
return False
return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload)
# .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer
pipeline_is_sequentially_offloaded = any(
module_is_sequentially_offloaded(module) for _, module in self.components.items()
)
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
if is_pipeline_device_mapped:
raise ValueError(
"It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline."
)
if device_type in ["cuda", "xpu"]:
if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
raise ValueError(
"It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
)
# PR: https://github.com/huggingface/accelerate/pull/3223/
elif pipeline_has_bnb and is_accelerate_version("<", "1.1.0.dev0"):
raise ValueError(
"You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
)
# Display a warning in this case (the operation succeeds but the benefits are lost)
pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
if pipeline_is_offloaded and device_type in ["cuda", "xpu"]:
logger.warning(
f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
)
# Enable generic support for Intel Gaudi accelerator using GPU/HPU migration
if device_type == "hpu" and kwargs.pop("hpu_migration", True) and is_hpu_available():
os.environ["PT_HPU_GPU_MIGRATION"] = "1"
logger.debug("Environment variable set: PT_HPU_GPU_MIGRATION=1")
import habana_frameworks.torch # noqa: F401
# HPU hardware check
if not (hasattr(torch, "hpu") and torch.hpu.is_available()):
raise ValueError("You are trying to call `.to('hpu')` but HPU device is unavailable.")
os.environ["PT_HPU_MAX_COMPOUND_OP_SIZE"] = "1"
logger.debug("Environment variable set: PT_HPU_MAX_COMPOUND_OP_SIZE=1")
module_names, _ = self._get_signature_keys(self)
modules = [getattr(self, n, None) for n in module_names]
modules = [m for m in modules if isinstance(m, torch.nn.Module)]
is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
for module in modules:
_, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module)
is_group_offloaded = self._maybe_raise_error_if_group_offload_active(module=module)
if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None:
logger.warning(
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {'4bit' if is_loaded_in_4bit_bnb else '8bit'} and conversion to {dtype} is not supported. Module is still in {'4bit' if is_loaded_in_4bit_bnb else '8bit'} precision."
)
if is_loaded_in_8bit_bnb and device is not None:
logger.warning(
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
)
# Note: we also handle this at the ModelMixin level. The reason for doing it here too is that modeling
# components can be from outside diffusers too, but still have group offloading enabled.
if (
self._maybe_raise_error_if_group_offload_active(raise_error=False, module=module)
and device is not None
):
logger.warning(
f"The module '{module.__class__.__name__}' is group offloaded and moving it to {device} via `.to()` is not supported."
)
# This can happen for `transformer` models. CPU placement was added in
# https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
module.to(device=device)
elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb and not is_group_offloaded:
module.to(device, dtype)
if (
module.dtype == torch.float16
and str(device) in ["cpu"]
and not silence_dtype_warnings
and not is_offloaded
):
logger.warning(
"Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It"
" is not recommended to move them to `cpu` as running them will fail. Please make"
" sure to use an accelerator to run the pipeline in inference, due to the lack of"
" support for`float16` operations on this device in PyTorch. Please, remove the"
" `torch_dtype=torch.float16` argument, or use another device for inference."
)
return self
# YiYi TODO:
# 1. should support save some components too! currently only modular_model_index.json is saved
@@ -2137,4 +2309,111 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
name=name,
type_hint=type_hint,
**spec_dict,
)
)
class ModularPipeline:
"""
Base class for all Modular pipelines.
Args:
blocks: ModularPipelineBlocks, the blocks to be used in the pipeline
loader: ModularLoader, the loader to be used in the pipeline
"""
def __init__(self, blocks: ModularPipelineBlocks, loader: ModularLoader):
self.blocks = blocks
self.loader = loader
def __repr__(self):
blocks_class = self.blocks.__class__.__name__
loader_class = self.loader.__class__.__name__
return f"ModularPipeline(blocks={blocks_class}, loader={loader_class})"
@property
def default_call_parameters(self) -> Dict[str, Any]:
params = {}
for input_param in self.blocks.inputs:
params[input_param.name] = input_param.default
return params
def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs):
"""
Run one or more blocks in sequence, optionally you can pass a previous pipeline state.
"""
if state is None:
state = PipelineState()
# Make a copy of the input kwargs
passed_kwargs = kwargs.copy()
# Add inputs to state, using defaults if not provided in the kwargs or the state
# if same input already in the state, will override it if provided in the kwargs
intermediates_inputs = [inp.name for inp in self.blocks.intermediates_inputs]
for expected_input_param in self.blocks.inputs:
name = expected_input_param.name
default = expected_input_param.default
kwargs_type = expected_input_param.kwargs_type
if name in passed_kwargs:
if name not in intermediates_inputs:
state.add_input(name, passed_kwargs.pop(name), kwargs_type)
else:
state.add_input(name, passed_kwargs[name], kwargs_type)
elif name not in state.inputs:
state.add_input(name, default, kwargs_type)
for expected_intermediate_param in self.blocks.intermediates_inputs:
name = expected_intermediate_param.name
kwargs_type = expected_intermediate_param.kwargs_type
if name in passed_kwargs:
state.add_intermediate(name, passed_kwargs.pop(name), kwargs_type)
# Warn about unexpected inputs
if len(passed_kwargs) > 0:
warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.")
# Run the pipeline
with torch.no_grad():
try:
pipeline, state = self.blocks(self.loader, state)
except Exception:
error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n"
logger.error(error_msg)
raise
if output is None:
return state
elif isinstance(output, str):
return state.get_intermediate(output)
elif isinstance(output, (list, tuple)):
return state.get_intermediates(output)
else:
raise ValueError(f"Output '{output}' is not a valid output type")
def load_components(self, component_names: Optional[List[str]] = None, **kwargs):
self.loader.load(component_names=component_names, **kwargs)
def update_components(self, **kwargs):
self.loader.update(**kwargs)
@classmethod
@validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], trust_remote_code: Optional[bool] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs):
blocks = ModularPipelineBlocks.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs)
pipeline = blocks.init_pipeline(pretrained_model_name_or_path, component_manager=component_manager, collection=collection, **kwargs)
return pipeline
def save_pretrained(self, save_directory: Optional[Union[str, os.PathLike]] = None, push_to_hub: bool = False, **kwargs):
self.blocks.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
self.loader.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
@property
def doc(self):
return self.blocks.doc

View File

@@ -19,11 +19,30 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal
from ..utils.import_utils import is_torch_available
from ..configuration_utils import FrozenDict, ConfigMixin
from collections import OrderedDict
if is_torch_available():
import torch
class InsertableOrderedDict(OrderedDict):
def insert(self, key, value, index):
items = list(self.items())
# Remove key if it already exists to avoid duplicates
items = [(k, v) for k, v in items if k != key]
# Insert at the specified index
items.insert(index, (key, value))
# Clear and update self
self.clear()
self.update(items)
# Return self for method chaining
return self
# YiYi TODO:
# 1. validate the dataclass fields
# 2. add a validator for create_* methods, make sure they are valid inputs to pass to from_pretrained()
@@ -246,7 +265,7 @@ class InputParam:
default: Any = None
required: bool = False
description: str = ""
kwargs_type: str = None # YiYi Notes: experimenting with this, not sure if we should keep it
kwargs_type: str = None # YiYi Notes: remove this feature (maybe)
def __repr__(self):
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
@@ -258,7 +277,7 @@ class OutputParam:
name: str
type_hint: Any = None
description: str = ""
kwargs_type: str = None
kwargs_type: str = None # YiYi notes: remove this feature (maybe)
def __repr__(self):
return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>"
@@ -402,7 +421,9 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115):
for param in params:
# Format parameter name and type
type_str = get_type_str(param.type_hint) if param.type_hint != Any else ""
param_str = f"{param_indent}{param.name} (`{type_str}`"
# YiYi Notes: remove this line if we remove kwargs_type
name = f'**{param.kwargs_type}' if param.name is None and param.kwargs_type is not None else param.name
param_str = f"{param_indent}{name} (`{type_str}`"
# Add optional tag and default value if parameter is an InputParam and optional
if hasattr(param, "required"):

View File

@@ -0,0 +1,519 @@
from ..configuration_utils import ConfigMixin
from .modular_pipeline import SequentialPipelineBlocks, ModularPipelineBlocks
from .modular_pipeline_utils import InputParam, OutputParam
from ..image_processor import PipelineImageInput
from pathlib import Path
import json
import os
from typing import Union, List, Optional, Tuple
import torch
import PIL
import numpy as np
import logging
logger = logging.getLogger(__name__)
# YiYi Notes: this is actually for SDXL, put it here for now
SDXL_INPUTS_SCHEMA = {
"prompt": InputParam("prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"),
"prompt_2": InputParam("prompt_2", type_hint=Union[str, List[str]], description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2"),
"negative_prompt": InputParam("negative_prompt", type_hint=Union[str, List[str]], description="The prompt or prompts not to guide the image generation"),
"negative_prompt_2": InputParam("negative_prompt_2", type_hint=Union[str, List[str]], description="The negative prompt or prompts for text_encoder_2"),
"cross_attention_kwargs": InputParam("cross_attention_kwargs", type_hint=Optional[dict], description="Kwargs dictionary passed to the AttentionProcessor"),
"clip_skip": InputParam("clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"),
"image": InputParam("image", type_hint=PipelineImageInput, required=True, description="The image(s) to modify for img2img or inpainting"),
"mask_image": InputParam("mask_image", type_hint=PipelineImageInput, required=True, description="Mask image for inpainting, white pixels will be repainted"),
"generator": InputParam("generator", type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], description="Generator(s) for deterministic generation"),
"height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"),
"width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"),
"num_images_per_prompt": InputParam("num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"),
"num_inference_steps": InputParam("num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"),
"timesteps": InputParam("timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"),
"sigmas": InputParam("sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"),
"denoising_end": InputParam("denoising_end", type_hint=Optional[float], description="Fraction of denoising process to complete before termination"),
# YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999
"strength": InputParam("strength", type_hint=float, default=0.3, description="How much to transform the reference image"),
"denoising_start": InputParam("denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"),
"latents": InputParam("latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"),
"padding_mask_crop": InputParam("padding_mask_crop", type_hint=Optional[Tuple[int, int]], description="Size of margin in crop for image and mask"),
"original_size": InputParam("original_size", type_hint=Optional[Tuple[int, int]], description="Original size of the image for SDXL's micro-conditioning"),
"target_size": InputParam("target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"),
"negative_original_size": InputParam("negative_original_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on image resolution"),
"negative_target_size": InputParam("negative_target_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on target resolution"),
"crops_coords_top_left": InputParam("crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Top-left coordinates for SDXL's micro-conditioning"),
"negative_crops_coords_top_left": InputParam("negative_crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Negative conditioning crop coordinates"),
"aesthetic_score": InputParam("aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"),
"negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"),
"eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"),
"output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"),
"ip_adapter_image": InputParam("ip_adapter_image", type_hint=PipelineImageInput, required=True, description="Image(s) to be used as IP adapter"),
"control_image": InputParam("control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"),
"control_guidance_start": InputParam("control_guidance_start", type_hint=Union[float, List[float]], default=0.0, description="When ControlNet starts applying"),
"control_guidance_end": InputParam("control_guidance_end", type_hint=Union[float, List[float]], default=1.0, description="When ControlNet stops applying"),
"controlnet_conditioning_scale": InputParam("controlnet_conditioning_scale", type_hint=Union[float, List[float]], default=1.0, description="Scale factor for ControlNet outputs"),
"guess_mode": InputParam("guess_mode", type_hint=bool, default=False, description="Enables ControlNet encoder to recognize input without prompts"),
"control_mode": InputParam("control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet")
}
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
"prompt_embeds": InputParam("prompt_embeds", type_hint=torch.Tensor, required=True, description="Text embeddings used to guide image generation"),
"negative_prompt_embeds": InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"),
"pooled_prompt_embeds": InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"),
"negative_pooled_prompt_embeds": InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"),
"batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"),
"dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
"preprocess_kwargs": InputParam("preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"),
"latents": InputParam("latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"),
"timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
"num_inference_steps": InputParam("num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"),
"latent_timestep": InputParam("latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"),
"image_latents": InputParam("image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"),
"mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"),
"masked_image_latents": InputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"),
"add_time_ids": InputParam("add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"),
"negative_add_time_ids": InputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"),
"timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
"noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
"crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
"ip_adapter_embeds": InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"),
"negative_ip_adapter_embeds": InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"),
"images": InputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], required=True, description="Generated images")
}
SDXL_PARAM_SCHEMA = {**SDXL_INPUTS_SCHEMA, **SDXL_INTERMEDIATE_INPUTS_SCHEMA}
DEFAULT_PARAM_MAPS = {
"prompt": {
"label": "Prompt",
"type": "string",
"default": "a bear sitting in a chair drinking a milkshake",
"display": "textarea",
},
"negative_prompt": {
"label": "Negative Prompt",
"type": "string",
"default": "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
"display": "textarea",
},
"num_inference_steps": {
"label": "Steps",
"type": "int",
"default": 25,
"min": 1,
"max": 1000,
},
"seed": {
"label": "Seed",
"type": "int",
"default": 0,
"min": 0,
"display": "random",
},
"width": {
"label": "Width",
"type": "int",
"display": "text",
"default": 1024,
"min": 8,
"max": 8192,
"step": 8,
"group": "dimensions",
},
"height": {
"label": "Height",
"type": "int",
"display": "text",
"default": 1024,
"min": 8,
"max": 8192,
"step": 8,
"group": "dimensions",
},
"images": {
"label": "Images",
"type": "image",
"display": "output",
},
"image": {
"label": "Image",
"type": "image",
"display": "input",
},
}
DEFAULT_TYPE_MAPS ={
"int": {
"type": "int",
"default": 0,
"min": 0,
},
"float": {
"type": "float",
"default": 0.0,
"min": 0.0,
},
"str": {
"type": "string",
"default": "",
},
"bool": {
"type": "boolean",
"default": False,
},
"image": {
"type": "image",
},
}
DEFAULT_MODEL_KEYS = ["unet", "vae", "text_encoder", "tokenizer", "controlnet", "transformer", "image_encoder"]
DEFAULT_CATEGORY = "Modular Diffusers"
DEFAULT_EXCLUDE_MODEL_KEYS = ["processor", "feature_extractor", "safety_checker"]
DEFAULT_PARAMS_GROUPS_KEYS = {
"text_encoders": ["text_encoder", "tokenizer"],
"ip_adapter_embeds": ["ip_adapter_embeds"],
"prompt_embeddings": ["prompt_embeds"],
}
def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS):
"""
Get the group name for a given parameter name, if not part of a group, return None
e.g. "prompt_embeds" -> "text_embeds", "text_encoder" -> "text_encoders", "prompt" -> None
"""
if name is None:
return None
for group_name, group_keys in group_params_keys.items():
for group_key in group_keys:
if group_key in name:
return group_name
return None
class ModularNode(ConfigMixin):
config_name = "node_config.json"
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
trust_remote_code: Optional[bool] = None,
**kwargs,
):
blocks = ModularPipelineBlocks.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs)
return cls(blocks, **kwargs)
def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs):
self.blocks = blocks
if label is None:
label = self.blocks.__class__.__name__
# blocks param name -> mellon param name
self.name_mapping = {}
input_params = {}
# pass or create a default param dict for each input
# e.g. for prompt,
# prompt = {
# "name": "text_input", # the name of the input in node defination, could be different from the input name in diffusers
# "label": "Prompt",
# "type": "string",
# "default": "a bear sitting in a chair drinking a milkshake",
# "display": "textarea"}
# if type is not specified, it'll be a "custom" param of its own type
# e.g. you can pass ModularNode(scheduler = {name :"scheduler"})
# it will get this spec in node defination {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}}
# name can be a dict, in that case, it is part of a "dict" input in mellon nodes, e.g. text_encoder= {name: {"text_encoders": "text_encoder"}}
inputs = self.blocks.inputs + self.blocks.intermediates_inputs
for inp in inputs:
param = kwargs.pop(inp.name, None)
if param:
# user can pass a param dict for all inputs, e.g. ModularNode(prompt = {...})
input_params[inp.name] = param
mellon_name = param.pop("name", inp.name)
if mellon_name != inp.name:
self.name_mapping[inp.name] = mellon_name
continue
if not inp.name in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name):
continue
if inp.name in DEFAULT_PARAM_MAPS:
# first check if it's in the default param map, if so, directly use that
param = DEFAULT_PARAM_MAPS[inp.name].copy()
elif get_group_name(inp.name):
param = get_group_name(inp.name)
if inp.name not in self.name_mapping:
self.name_mapping[inp.name] = param
else:
# if not, check if it's in the SDXL input schema, if so,
# 1. use the type hint to determine the type
# 2. use the default param dict for the type e.g. if "steps" is a "int" type, {"steps": {"type": "int", "default": 0, "min": 0}}
if inp.type_hint is not None:
type_str = str(inp.type_hint).lower()
else:
inp_spec = SDXL_PARAM_SCHEMA.get(inp.name, None)
type_str = str(inp_spec.type_hint).lower() if inp_spec else ""
for type_key, type_param in DEFAULT_TYPE_MAPS.items():
if type_key in type_str:
param = type_param.copy()
param["label"] = inp.name
param["display"] = "input"
break
else:
param = inp.name
# add the param dict to the inp_params dict
input_params[inp.name] = param
component_params = {}
for comp in self.blocks.expected_components:
param = kwargs.pop(comp.name, None)
if param:
component_params[comp.name] = param
mellon_name = param.pop("name", comp.name)
if mellon_name != comp.name:
self.name_mapping[comp.name] = mellon_name
continue
to_exclude = False
for exclude_key in DEFAULT_EXCLUDE_MODEL_KEYS:
if exclude_key in comp.name:
to_exclude = True
break
if to_exclude:
continue
if get_group_name(comp.name):
param = get_group_name(comp.name)
if comp.name not in self.name_mapping:
self.name_mapping[comp.name] = param
elif comp.name in DEFAULT_MODEL_KEYS:
param = {"label": comp.name, "type": "diffusers_auto_model", "display": "input"}
else:
param = comp.name
# add the param dict to the model_params dict
component_params[comp.name] = param
output_params = {}
if isinstance(self.blocks, SequentialPipelineBlocks):
last_block_name = list(self.blocks.blocks.keys())[-1]
outputs = self.blocks.blocks[last_block_name].intermediates_outputs
else:
outputs = self.blocks.intermediates_outputs
for out in outputs:
param = kwargs.pop(out.name, None)
if param:
output_params[out.name] = param
mellon_name = param.pop("name", out.name)
if mellon_name != out.name:
self.name_mapping[out.name] = mellon_name
continue
if out.name in DEFAULT_PARAM_MAPS:
param = DEFAULT_PARAM_MAPS[out.name].copy()
param["display"] = "output"
else:
group_name = get_group_name(out.name)
if group_name:
param = group_name
if out.name not in self.name_mapping:
self.name_mapping[out.name] = param
else:
param = out.name
# add the param dict to the outputs dict
output_params[out.name] = param
if len(kwargs) > 0:
logger.warning(f"Unused kwargs: {kwargs}")
register_dict = {
"category": category,
"label": label,
"input_params": input_params,
"component_params": component_params,
"output_params": output_params,
"name_mapping": self.name_mapping,
}
self.register_to_config(**register_dict)
def setup(self, components, collection=None):
self.blocks.setup_loader(component_manager=components, collection=collection)
self._components_manager = components
@property
def mellon_config(self):
return self._convert_to_mellon_config()
def _convert_to_mellon_config(self):
node = {}
node["label"] = self.config.label
node["category"] = self.config.category
node_param = {}
for inp_name, inp_param in self.config.input_params.items():
if inp_name in self.name_mapping:
mellon_name = self.name_mapping[inp_name]
else:
mellon_name = inp_name
if isinstance(inp_param, str):
param = {
"label": inp_param,
"type": inp_param,
"display": "input",
}
else:
param = inp_param
if mellon_name not in node_param:
node_param[mellon_name] = param
else:
logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}")
for comp_name, comp_param in self.config.component_params.items():
if comp_name in self.name_mapping:
mellon_name = self.name_mapping[comp_name]
else:
mellon_name = comp_name
if isinstance(comp_param, str):
param = {
"label": comp_param,
"type": comp_param,
"display": "input",
}
else:
param = comp_param
if mellon_name not in node_param:
node_param[mellon_name] = param
else:
logger.debug(f"Component param {comp_param} already exists in node_param, skipping {comp_name}")
for out_name, out_param in self.config.output_params.items():
if out_name in self.name_mapping:
mellon_name = self.name_mapping[out_name]
else:
mellon_name = out_name
if isinstance(out_param, str):
param = {
"label": out_param,
"type": out_param,
"display": "output",
}
else:
param = out_param
if mellon_name not in node_param:
node_param[mellon_name] = param
else:
logger.debug(f"Output param {out_param} already exists in node_param, skipping {out_name}")
node["params"] = node_param
return node
def save_mellon_config(self, file_path):
"""
Save the Mellon configuration to a JSON file.
Args:
file_path (str or Path): Path where the JSON file will be saved
Returns:
Path: Path to the saved config file
"""
file_path = Path(file_path)
# Create directory if it doesn't exist
os.makedirs(file_path.parent, exist_ok=True)
# Create a combined dictionary with module definition and name mapping
config = {
"module": self.mellon_config,
"name_mapping": self.name_mapping
}
# Save the config to file
with open(file_path, 'w', encoding='utf-8') as f:
json.dump(config, f, indent=2)
logger.info(f"Mellon config and name mapping saved to {file_path}")
return file_path
@classmethod
def load_mellon_config(cls, file_path):
"""
Load a Mellon configuration from a JSON file.
Args:
file_path (str or Path): Path to the JSON file containing Mellon config
Returns:
dict: The loaded combined configuration containing 'module' and 'name_mapping'
"""
file_path = Path(file_path)
if not file_path.exists():
raise FileNotFoundError(f"Config file not found: {file_path}")
with open(file_path, 'r', encoding='utf-8') as f:
config = json.load(f)
logger.info(f"Mellon config loaded from {file_path}")
return config
def process_inputs(self, **kwargs):
params_components = {}
for comp_name, comp_param in self.config.component_params.items():
logger.debug(f"component: {comp_name}")
mellon_comp_name = self.name_mapping.get(comp_name, comp_name)
if mellon_comp_name in kwargs:
if isinstance(kwargs[mellon_comp_name], dict) and comp_name in kwargs[mellon_comp_name]:
comp = kwargs[mellon_comp_name].pop(comp_name)
else:
comp = kwargs.pop(mellon_comp_name)
if comp:
params_components[comp_name] = self._components_manager.get_one(comp["model_id"])
params_run = {}
for inp_name, inp_param in self.config.input_params.items():
logger.debug(f"input: {inp_name}")
mellon_inp_name = self.name_mapping.get(inp_name, inp_name)
if mellon_inp_name in kwargs:
if isinstance(kwargs[mellon_inp_name], dict) and inp_name in kwargs[mellon_inp_name]:
inp = kwargs[mellon_inp_name].pop(inp_name)
else:
inp = kwargs.pop(mellon_inp_name)
if inp is not None:
params_run[inp_name] = inp
return_output_names = list(self.config.output_params.keys())
return params_components, params_run, return_output_names
def execute(self, **kwargs):
params_components, params_run, return_output_names = self.process_inputs(**kwargs)
self.blocks.loader.update(**params_components)
output = self.blocks.run(**params_run, output=return_output_names)
return output

View File

@@ -25,6 +25,7 @@ else:
_import_structure["modular_loader"] = ["StableDiffusionXLModularLoader"]
_import_structure["encoders"] = ["StableDiffusionXLAutoIPAdapterStep", "StableDiffusionXLTextEncoderStep", "StableDiffusionXLAutoVaeEncoderStep"]
_import_structure["decoders"] = ["StableDiffusionXLAutoDecodeStep"]
_import_structure["modular_block_mappings"] = ["TEXT2IMAGE_BLOCKS", "IMAGE2IMAGE_BLOCKS", "INPAINT_BLOCKS", "CONTROLNET_BLOCKS", "CONTROLNET_UNION_BLOCKS", "IP_ADAPTER_BLOCKS", "AUTO_BLOCKS", "SDXL_SUPPORTED_BLOCKS"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -37,6 +38,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .modular_loader import StableDiffusionXLModularLoader
from .encoders import StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoVaeEncoderStep
from .decoders import StableDiffusionXLAutoDecodeStep
from .modular_block_mappings import SDXL_SUPPORTED_BLOCKS, TEXT2IMAGE_BLOCKS, IMAGE2IMAGE_BLOCKS, INPAINT_BLOCKS, CONTROLNET_BLOCKS, CONTROLNET_UNION_BLOCKS, IP_ADAPTER_BLOCKS, AUTO_BLOCKS
else:
import sys

View File

@@ -54,7 +54,7 @@ class StableDiffusionXLDenoiseLoopBeforeDenoiser(PipelineBlock):
@property
def description(self) -> str:
return "step within the denoising loop that prepare the latent input for the denoiser"
return "step within the denoising loop that prepare the latent input for the denoiser. Used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object, e.g. `StableDiffusionXLDenoiseLoopWrapper`"
@property
@@ -89,7 +89,7 @@ class StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser(PipelineBlock):
@property
def description(self) -> str:
return "step within the denoising loop that prepare the latent input for the denoiser"
return "step within the denoising loop that prepare the latent input for the denoiser (for inpainting workflow only). Used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object, e.g. `StableDiffusionXLDenoiseLoopWrapper`"
@property
@@ -165,7 +165,7 @@ class StableDiffusionXLDenoiseLoopDenoiser(PipelineBlock):
@property
def description(self) -> str:
return (
"Step within the denoising loop that denoise the latents with guidance"
"Step within the denoising loop that denoise the latents with guidance. Used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object, e.g. `StableDiffusionXLDenoiseLoopWrapper`"
)
@property
@@ -269,7 +269,7 @@ class StableDiffusionXLControlNetDenoiseLoopDenoiser(PipelineBlock):
@property
def description(self) -> str:
return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process"
return "step within the denoising loop that denoise the latents with guidance (with controlnet). Used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object, e.g. `StableDiffusionXLDenoiseLoopWrapper`"
@property
def inputs(self) -> List[Tuple[str, Any]]:
@@ -458,7 +458,7 @@ class StableDiffusionXLDenoiseLoopAfterDenoiser(PipelineBlock):
@property
def description(self) -> str:
return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process"
return "step within the denoising loop that update the latents. Used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object, e.g. `StableDiffusionXLDenoiseLoopWrapper`"
@property
def inputs(self) -> List[Tuple[str, Any]]:
@@ -521,7 +521,7 @@ class StableDiffusionXLInpaintDenoiseLoopAfterDenoiser(PipelineBlock):
@property
def description(self) -> str:
return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process"
return "step within the denoising loop that update the latents (for inpainting workflow only). Used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object, e.g. `StableDiffusionXLDenoiseLoopWrapper`"
@property
def inputs(self) -> List[Tuple[str, Any]]:
@@ -622,7 +622,7 @@ class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
@property
def description(self) -> str:
return (
"Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process"
"Pipeline block that iteratively denoise the latents over `timesteps`. The specific steps with each iteration can be customized with `blocks` attributes"
)
@property
@@ -683,21 +683,60 @@ class StableDiffusionXLDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper):
block_classes = [StableDiffusionXLDenoiseLoopBeforeDenoiser, StableDiffusionXLDenoiseLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. "
"Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n"
"and at each iteration, it runs blocks defined in `blocks` sequencially:\n"
" - `StableDiffusionXLDenoiseLoopBeforeDenoiser`\n"
" - `StableDiffusionXLDenoiseLoopDenoiser`\n"
" - `StableDiffusionXLDenoiseLoopAfterDenoiser`\n"
)
# control_cond
class StableDiffusionXLControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper):
block_classes = [StableDiffusionXLDenoiseLoopBeforeDenoiser, StableDiffusionXLControlNetDenoiseLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents with controlnet. \n"
"Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n"
"and at each iteration, it runs blocks defined in `blocks` sequencially:\n"
" - `StableDiffusionXLDenoiseLoopBeforeDenoiser`\n"
" - `StableDiffusionXLControlNetDenoiseLoopDenoiser`\n"
" - `StableDiffusionXLDenoiseLoopAfterDenoiser`\n"
)
# mask
class StableDiffusionXLInpaintDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper):
block_classes = [StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser, StableDiffusionXLDenoiseLoopDenoiser, StableDiffusionXLInpaintDenoiseLoopAfterDenoiser]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents(for inpainting task only). \n"
"Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n"
"and at each iteration, it runs blocks defined in `blocks` sequencially:\n"
" - `StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser`\n"
" - `StableDiffusionXLDenoiseLoopDenoiser`\n"
" - `StableDiffusionXLInpaintDenoiseLoopAfterDenoiser`\n"
)
# control_cond + mask
class StableDiffusionXLInpaintControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper):
block_classes = [StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser, StableDiffusionXLControlNetDenoiseLoopDenoiser, StableDiffusionXLInpaintDenoiseLoopAfterDenoiser]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents(for inpainting task only) with controlnet. \n"
"Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n"
"and at each iteration, it runs blocks defined in `blocks` sequencially:\n"
" - `StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser`\n"
" - `StableDiffusionXLControlNetDenoiseLoopDenoiser`\n"
" - `StableDiffusionXLInpaintDenoiseLoopAfterDenoiser`\n"
)
# all task without controlnet
@@ -706,18 +745,45 @@ class StableDiffusionXLDenoiseStep(AutoPipelineBlocks):
block_names = ["inpaint_denoise", "denoise"]
block_trigger_inputs = ["mask", None]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. "
"This is a auto pipeline block that works for text2img, img2img and inpainting tasks."
" - `StableDiffusionXLDenoiseStep` (denoise) is used when no mask is provided."
" - `StableDiffusionXLInpaintDenoiseStep` (inpaint_denoise) is used when mask is provided."
)
# all task with controlnet
class StableDiffusionXLControlNetDenoiseStep(AutoPipelineBlocks):
block_classes = [StableDiffusionXLInpaintControlNetDenoiseLoop, StableDiffusionXLControlNetDenoiseLoop]
block_names = ["inpaint_controlnet_denoise", "controlnet_denoise"]
block_trigger_inputs = ["mask", None]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents with controlnet. "
"This is a auto pipeline block that works for text2img, img2img and inpainting tasks."
" - `StableDiffusionXLControlNetDenoiseStep` (controlnet_denoise) is used when no mask is provided."
" - `StableDiffusionXLInpaintControlNetDenoiseStep` (inpaint_controlnet_denoise) is used when mask is provided."
)
# all task with or without controlnet
class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks):
block_classes = [StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDenoiseStep]
block_names = ["controlnet_denoise", "denoise"]
block_trigger_inputs = ["controlnet_cond", None]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. "
"This is a auto pipeline block that works for text2img, img2img and inpainting tasks. And can be used with or without controlnet."
" - `StableDiffusionXLDenoiseStep` (denoise) is used when no controlnet_cond is provided (work for text2img, img2img and inpainting tasks)."
" - `StableDiffusionXLControlNetDenoiseStep` (controlnet_denoise) is used when controlnet_cond is provided (work for text2img, img2img and inpainting tasks)."
)

View File

@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from ..modular_pipeline_utils import InsertableOrderedDict
# Import all the necessary block classes
from .denoise import (
StableDiffusionXLAutoDenoiseStep,
StableDiffusionXLDenoiseStep,
StableDiffusionXLControlNetDenoiseStep
StableDiffusionXLControlNetDenoiseStep,
StableDiffusionXLDenoiseLoop,
StableDiffusionXLInpaintDenoiseLoop
)
from .before_denoise import (
StableDiffusionXLAutoBeforeDenoiseStep,
@@ -50,56 +51,53 @@ from .decoders import (
# YiYi notes: comment out for now, work on this later
# block mapping
TEXT2IMAGE_BLOCKS = OrderedDict([
TEXT2IMAGE_BLOCKS = InsertableOrderedDict([
("text_encoder", StableDiffusionXLTextEncoderStep),
("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
("input", StableDiffusionXLInputStep),
("set_timesteps", StableDiffusionXLSetTimestepsStep),
("prepare_latents", StableDiffusionXLPrepareLatentsStep),
("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep),
("denoise", StableDiffusionXLDenoiseStep),
("denoise", StableDiffusionXLDenoiseLoop),
("decode", StableDiffusionXLDecodeStep)
])
IMAGE2IMAGE_BLOCKS = OrderedDict([
IMAGE2IMAGE_BLOCKS = InsertableOrderedDict([
("text_encoder", StableDiffusionXLTextEncoderStep),
("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
("image_encoder", StableDiffusionXLVaeEncoderStep),
("input", StableDiffusionXLInputStep),
("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
("denoise", StableDiffusionXLDenoiseStep),
("denoise", StableDiffusionXLDenoiseLoop),
("decode", StableDiffusionXLDecodeStep)
])
INPAINT_BLOCKS = OrderedDict([
INPAINT_BLOCKS = InsertableOrderedDict([
("text_encoder", StableDiffusionXLTextEncoderStep),
("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
("image_encoder", StableDiffusionXLInpaintVaeEncoderStep),
("input", StableDiffusionXLInputStep),
("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep),
("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
("denoise", StableDiffusionXLDenoiseStep),
("denoise", StableDiffusionXLInpaintDenoiseLoop),
("decode", StableDiffusionXLInpaintDecodeStep)
])
CONTROLNET_BLOCKS = OrderedDict([
CONTROLNET_BLOCKS = InsertableOrderedDict([
("controlnet_input", StableDiffusionXLControlNetInputStep),
("denoise", StableDiffusionXLControlNetDenoiseStep),
])
CONTROLNET_UNION_BLOCKS = OrderedDict([
CONTROLNET_UNION_BLOCKS = InsertableOrderedDict([
("controlnet_input", StableDiffusionXLControlNetUnionInputStep),
("denoise", StableDiffusionXLControlNetDenoiseStep),
])
IP_ADAPTER_BLOCKS = OrderedDict([
IP_ADAPTER_BLOCKS = InsertableOrderedDict([
("ip_adapter", StableDiffusionXLIPAdapterStep),
])
AUTO_BLOCKS = OrderedDict([
AUTO_BLOCKS = InsertableOrderedDict([
("text_encoder", StableDiffusionXLTextEncoderStep),
("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
("image_encoder", StableDiffusionXLAutoVaeEncoderStep),
@@ -108,11 +106,6 @@ AUTO_BLOCKS = OrderedDict([
("decode", StableDiffusionXLAutoDecodeStep)
])
AUTO_CORE_BLOCKS = OrderedDict([
("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep),
("denoise", StableDiffusionXLAutoDenoiseStep),
])
SDXL_SUPPORTED_BLOCKS = {
"text2img": TEXT2IMAGE_BLOCKS,

View File

@@ -15,13 +15,16 @@
"""Utilities to dynamically load objects from the Hub."""
import importlib
import signal
import inspect
import json
import os
import re
import shutil
import sys
import threading
from pathlib import Path
from types import ModuleType
from typing import Dict, Optional, Union
from urllib import request
@@ -37,6 +40,8 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# See https://huggingface.co/datasets/diffusers/community-pipelines-mirror
COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror"
TIME_OUT_REMOTE_CODE = int(os.getenv("DIFFUSERS_TIMEOUT_REMOTE_CODE", 15))
_HF_REMOTE_CODE_LOCK = threading.Lock()
def get_diffusers_versions():
@@ -154,15 +159,87 @@ def check_imports(filename):
return get_relative_imports(filename)
def get_class_in_module(class_name, module_path):
def _raise_timeout_error(signum, frame):
raise ValueError(
"Loading this model requires you to execute custom code contained in the model repository on your local "
"machine. Please set the option `trust_remote_code=True` to permit loading of this model."
)
def resolve_trust_remote_code(trust_remote_code, model_name, has_remote_code):
if trust_remote_code is None:
if has_remote_code and TIME_OUT_REMOTE_CODE > 0:
prev_sig_handler = None
try:
prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
signal.alarm(TIME_OUT_REMOTE_CODE)
while trust_remote_code is None:
answer = input(
f"The repository for {model_name} contains custom code which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
f"Do you wish to run the custom code? [y/N] "
)
if answer.lower() in ["yes", "y", "1"]:
trust_remote_code = True
elif answer.lower() in ["no", "n", "0", ""]:
trust_remote_code = False
signal.alarm(0)
except Exception:
# OS which does not support signal.SIGALRM
raise ValueError(
f"The repository for {model_name} contains custom code which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)
finally:
if prev_sig_handler is not None:
signal.signal(signal.SIGALRM, prev_sig_handler)
signal.alarm(0)
elif has_remote_code:
# For the CI which puts the timeout at 0
_raise_timeout_error(None, None)
if has_remote_code and not trust_remote_code:
raise ValueError(
f"Loading {model_name} requires you to execute the configuration file in that"
" repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
" set the option `trust_remote_code=True` to remove this error."
)
return trust_remote_code
def get_class_in_module(class_name, module_path, force_reload=False):
"""
Import a module on the cache directory for modules and extract a class from it.
"""
module_path = module_path.replace(os.path.sep, ".")
module = importlib.import_module(module_path)
name = os.path.normpath(module_path)
if name.endswith(".py"):
name = name[:-3]
name = name.replace(os.path.sep, ".")
module_file: Path = Path(HF_MODULES_CACHE) / module_path
with _HF_REMOTE_CODE_LOCK:
if force_reload:
sys.modules.pop(name, None)
importlib.invalidate_caches()
cached_module: Optional[ModuleType] = sys.modules.get(name)
module_spec = importlib.util.spec_from_file_location(name, location=module_file)
module: ModuleType
if cached_module is None:
module = importlib.util.module_from_spec(module_spec)
# insert it into sys.modules before any loading begins
sys.modules[name] = module
else:
module = cached_module
module_spec.loader.exec_module(module)
if class_name is None:
return find_pipeline_class(module)
return getattr(module, class_name)
@@ -454,4 +531,4 @@ def get_class_from_dynamic_module(
revision=revision,
local_files_only=local_files_only,
)
return get_class_in_module(class_name, final_module.replace(".py", ""))
return get_class_in_module(class_name, final_module)