mirror of
https://github.com/huggingface/diffusers.git
synced 2026-02-24 03:40:37 +08:00
Compare commits
14 Commits
modular-re
...
modular-re
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d34c4e8caf | ||
|
|
b46b7c8b31 | ||
|
|
fc9168f429 | ||
|
|
31a31ca1c5 | ||
|
|
8423652b35 | ||
|
|
de631947cc | ||
|
|
58e9565719 | ||
|
|
cb6d5fed19 | ||
|
|
f16e9c7807 | ||
|
|
87f63d424a | ||
|
|
29de29f02c | ||
|
|
72e1b74638 | ||
|
|
3471f2fb75 | ||
|
|
808dff09cb |
@@ -264,6 +264,8 @@ else:
|
||||
_import_structure["modular_pipelines"].extend(
|
||||
[
|
||||
"ModularLoader",
|
||||
"ModularPipeline",
|
||||
"ModularPipelineBlocks",
|
||||
"ComponentSpec",
|
||||
"ComponentsManager",
|
||||
]
|
||||
@@ -894,6 +896,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
)
|
||||
from .modular_pipelines import (
|
||||
ModularLoader,
|
||||
ModularPipeline,
|
||||
ModularPipelineBlocks,
|
||||
ComponentSpec,
|
||||
ComponentsManager,
|
||||
)
|
||||
|
||||
@@ -23,7 +23,8 @@ except OptionalDependencyNotAvailable:
|
||||
_dummy_objects.update(get_objects_from_module(dummy_pt_objects))
|
||||
else:
|
||||
_import_structure["modular_pipeline"] = [
|
||||
"ModularPipelineMixin",
|
||||
"ModularPipelineBlocks",
|
||||
"ModularPipeline",
|
||||
"PipelineBlock",
|
||||
"AutoPipelineBlocks",
|
||||
"SequentialPipelineBlocks",
|
||||
@@ -53,7 +54,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
BlockState,
|
||||
LoopSequentialPipelineBlocks,
|
||||
ModularLoader,
|
||||
ModularPipelineMixin,
|
||||
ModularPipelineBlocks,
|
||||
ModularPipeline,
|
||||
PipelineBlock,
|
||||
PipelineState,
|
||||
SequentialPipelineBlocks,
|
||||
|
||||
@@ -11,12 +11,15 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
|
||||
|
||||
import traceback
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Tuple, Union, Optional, Type
|
||||
from typing import Any, Dict, List, Tuple, Union, Optional
|
||||
from typing_extensions import Self
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
@@ -31,11 +34,10 @@ from huggingface_hub.utils import validate_hf_hub_args
|
||||
from ..configuration_utils import ConfigMixin, FrozenDict
|
||||
from ..utils import (
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
logging,
|
||||
PushToHubMixin,
|
||||
)
|
||||
from ..pipelines.pipeline_loading_utils import _get_pipeline_class, simple_get_class_obj, _fetch_class_library_tuple
|
||||
from ..pipelines.pipeline_loading_utils import simple_get_class_obj, _fetch_class_library_tuple
|
||||
from .modular_pipeline_utils import (
|
||||
ComponentSpec,
|
||||
ConfigSpec,
|
||||
@@ -43,14 +45,12 @@ from .modular_pipeline_utils import (
|
||||
OutputParam,
|
||||
format_components,
|
||||
format_configs,
|
||||
format_input_params,
|
||||
format_inputs_short,
|
||||
format_intermediates_short,
|
||||
format_output_params,
|
||||
format_params,
|
||||
make_doc_string,
|
||||
)
|
||||
from .components_manager import ComponentsManager
|
||||
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
|
||||
|
||||
from copy import deepcopy
|
||||
if is_accelerate_available():
|
||||
@@ -244,20 +244,73 @@ class BlockState:
|
||||
return f"BlockState(\n{attributes}\n)"
|
||||
|
||||
|
||||
|
||||
class ModularPipelineMixin:
|
||||
class ModularPipelineBlocks(ConfigMixin):
|
||||
"""
|
||||
Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
"""
|
||||
|
||||
config_name = "config.json"
|
||||
|
||||
@classmethod
|
||||
def _get_signature_keys(cls, obj):
|
||||
parameters = inspect.signature(obj.__init__).parameters
|
||||
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
|
||||
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
|
||||
expected_modules = set(required_parameters.keys()) - {"self"}
|
||||
|
||||
return expected_modules, optional_parameters
|
||||
|
||||
|
||||
def setup_loader(self, modular_repo: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None):
|
||||
"""
|
||||
create a mouldar loader, optionally accept modular_repo to load from hub.
|
||||
"""
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_name_or_path: str,
|
||||
trust_remote_code: Optional[bool] = None,
|
||||
**kwargs,
|
||||
):
|
||||
hub_kwargs_names = [
|
||||
"cache_dir",
|
||||
"force_download",
|
||||
"local_files_only",
|
||||
"proxies",
|
||||
"resume_download",
|
||||
"revision",
|
||||
"subfolder",
|
||||
"token",
|
||||
]
|
||||
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
|
||||
|
||||
# Import components loader (it is model-specific class)
|
||||
loader_class_name = MODULAR_LOADER_MAPPING[self.model_name]
|
||||
config = cls.load_config(pretrained_model_name_or_path)
|
||||
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
|
||||
trust_remote_code = resolve_trust_remote_code(
|
||||
trust_remote_code, pretrained_model_name_or_path, has_remote_code
|
||||
)
|
||||
if not (has_remote_code and trust_remote_code):
|
||||
raise ValueError("TODO")
|
||||
|
||||
class_ref = config["auto_map"][cls.__name__]
|
||||
module_file, class_name = class_ref.split(".")
|
||||
module_file = module_file + ".py"
|
||||
block_cls = get_class_from_dynamic_module(
|
||||
pretrained_model_name_or_path,
|
||||
module_file=module_file,
|
||||
class_name=class_name,
|
||||
is_modular=True,
|
||||
**hub_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
expected_kwargs, optional_kwargs = block_cls._get_signature_keys(block_cls)
|
||||
block_kwargs = {
|
||||
name: kwargs.pop(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs
|
||||
}
|
||||
|
||||
return block_cls(**block_kwargs)
|
||||
|
||||
def init_pipeline(self, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None):
|
||||
"""
|
||||
create a ModularLoader, optionally accept modular_repo to load from hub.
|
||||
"""
|
||||
loader_class_name = MODULAR_LOADER_MAPPING.get(self.model_name, ModularLoader.__name__)
|
||||
diffusers_module = importlib.import_module("diffusers")
|
||||
loader_class = getattr(diffusers_module, loader_class_name)
|
||||
|
||||
@@ -267,105 +320,20 @@ class ModularPipelineMixin:
|
||||
# Create the loader with the updated specs
|
||||
specs = component_specs + config_specs
|
||||
|
||||
self.loader = loader_class(specs, modular_repo=modular_repo, component_manager=component_manager, collection=collection)
|
||||
loader = loader_class(specs=specs, pretrained_model_name_or_path=pretrained_model_name_or_path, component_manager=component_manager, collection=collection)
|
||||
modular_pipeline = ModularPipeline(blocks=self, loader=loader)
|
||||
return modular_pipeline
|
||||
|
||||
|
||||
@property
|
||||
def default_call_parameters(self) -> Dict[str, Any]:
|
||||
params = {}
|
||||
for input_param in self.inputs:
|
||||
params[input_param.name] = input_param.default
|
||||
return params
|
||||
|
||||
def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs):
|
||||
"""
|
||||
Run one or more blocks in sequence, optionally you can pass a previous pipeline state.
|
||||
"""
|
||||
if state is None:
|
||||
state = PipelineState()
|
||||
|
||||
if not hasattr(self, "loader"):
|
||||
logger.info("Loader is not set, please call `setup_loader()` if you need to load checkpoints for your pipeline.")
|
||||
self.loader = None
|
||||
|
||||
# Make a copy of the input kwargs
|
||||
passed_kwargs = kwargs.copy()
|
||||
|
||||
|
||||
# Add inputs to state, using defaults if not provided in the kwargs or the state
|
||||
# if same input already in the state, will override it if provided in the kwargs
|
||||
|
||||
intermediates_inputs = [inp.name for inp in self.intermediates_inputs]
|
||||
for expected_input_param in self.inputs:
|
||||
name = expected_input_param.name
|
||||
default = expected_input_param.default
|
||||
kwargs_type = expected_input_param.kwargs_type
|
||||
if name in passed_kwargs:
|
||||
if name not in intermediates_inputs:
|
||||
state.add_input(name, passed_kwargs.pop(name), kwargs_type)
|
||||
else:
|
||||
state.add_input(name, passed_kwargs[name], kwargs_type)
|
||||
elif name not in state.inputs:
|
||||
state.add_input(name, default, kwargs_type)
|
||||
|
||||
for expected_intermediate_param in self.intermediates_inputs:
|
||||
name = expected_intermediate_param.name
|
||||
kwargs_type = expected_intermediate_param.kwargs_type
|
||||
if name in passed_kwargs:
|
||||
state.add_intermediate(name, passed_kwargs.pop(name), kwargs_type)
|
||||
|
||||
# Warn about unexpected inputs
|
||||
if len(passed_kwargs) > 0:
|
||||
warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.")
|
||||
# Run the pipeline
|
||||
with torch.no_grad():
|
||||
try:
|
||||
pipeline, state = self(self.loader, state)
|
||||
except Exception:
|
||||
error_msg = f"Error in block: ({self.__class__.__name__}):\n"
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
|
||||
if output is None:
|
||||
return state
|
||||
|
||||
|
||||
elif isinstance(output, str):
|
||||
return state.get_intermediate(output)
|
||||
|
||||
elif isinstance(output, (list, tuple)):
|
||||
return state.get_intermediates(output)
|
||||
else:
|
||||
raise ValueError(f"Output '{output}' is not a valid output type")
|
||||
|
||||
@torch.compiler.disable
|
||||
def progress_bar(self, iterable=None, total=None):
|
||||
if not hasattr(self, "_progress_bar_config"):
|
||||
self._progress_bar_config = {}
|
||||
elif not isinstance(self._progress_bar_config, dict):
|
||||
raise ValueError(
|
||||
f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
|
||||
)
|
||||
|
||||
if iterable is not None:
|
||||
return tqdm(iterable, **self._progress_bar_config)
|
||||
elif total is not None:
|
||||
return tqdm(total=total, **self._progress_bar_config)
|
||||
else:
|
||||
raise ValueError("Either `total` or `iterable` has to be defined.")
|
||||
|
||||
def set_progress_bar_config(self, **kwargs):
|
||||
self._progress_bar_config = kwargs
|
||||
|
||||
|
||||
class PipelineBlock(ModularPipelineMixin):
|
||||
class PipelineBlock(ModularPipelineBlocks):
|
||||
|
||||
model_name = None
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""Description of the block. Must be implemented by subclasses."""
|
||||
raise NotImplementedError("description method must be implemented in subclasses")
|
||||
# raise NotImplementedError("description method must be implemented in subclasses")
|
||||
return "TODO: add a description"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
@@ -624,7 +592,7 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) ->
|
||||
return list(combined_dict.values())
|
||||
|
||||
|
||||
class AutoPipelineBlocks(ModularPipelineMixin):
|
||||
class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
"""
|
||||
A class that automatically selects a block to run based on the inputs.
|
||||
|
||||
@@ -692,6 +660,8 @@ class AutoPipelineBlocks(ModularPipelineMixin):
|
||||
|
||||
@property
|
||||
def required_inputs(self) -> List[str]:
|
||||
if None not in self.block_trigger_inputs:
|
||||
return []
|
||||
first_block = next(iter(self.blocks.values()))
|
||||
required_by_all = set(getattr(first_block, "required_inputs", set()))
|
||||
|
||||
@@ -706,6 +676,8 @@ class AutoPipelineBlocks(ModularPipelineMixin):
|
||||
# intermediate_inputs is by default required, unless you manually handle it inside the block
|
||||
@property
|
||||
def required_intermediates_inputs(self) -> List[str]:
|
||||
if None not in self.block_trigger_inputs:
|
||||
return []
|
||||
first_block = next(iter(self.blocks.values()))
|
||||
required_by_all = set(getattr(first_block, "required_intermediates_inputs", set()))
|
||||
|
||||
@@ -909,7 +881,8 @@ class AutoPipelineBlocks(ModularPipelineMixin):
|
||||
expected_configs=self.expected_configs
|
||||
)
|
||||
|
||||
class SequentialPipelineBlocks(ModularPipelineMixin):
|
||||
|
||||
class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
"""
|
||||
A class that combines multiple pipeline block classes into one. When called, it will call each block in sequence.
|
||||
"""
|
||||
@@ -949,15 +922,24 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
|
||||
"""Creates a SequentialPipelineBlocks instance from a dictionary of blocks.
|
||||
|
||||
Args:
|
||||
blocks_dict: Dictionary mapping block names to block instances
|
||||
blocks_dict: Dictionary mapping block names to block classes or instances
|
||||
|
||||
Returns:
|
||||
A new SequentialPipelineBlocks instance
|
||||
"""
|
||||
instance = cls()
|
||||
instance.block_classes = [block.__class__ for block in blocks_dict.values()]
|
||||
instance.block_names = list(blocks_dict.keys())
|
||||
instance.blocks = blocks_dict
|
||||
|
||||
# Create instances if classes are provided
|
||||
blocks = {}
|
||||
for name, block in blocks_dict.items():
|
||||
if inspect.isclass(block):
|
||||
blocks[name] = block()
|
||||
else:
|
||||
blocks[name] = block
|
||||
|
||||
instance.block_classes = [block.__class__ for block in blocks.values()]
|
||||
instance.block_names = list(blocks.keys())
|
||||
instance.blocks = blocks
|
||||
return instance
|
||||
|
||||
def __init__(self):
|
||||
@@ -1270,7 +1252,7 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
|
||||
)
|
||||
|
||||
#YiYi TODO: __repr__
|
||||
class LoopSequentialPipelineBlocks(ModularPipelineMixin):
|
||||
class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
"""
|
||||
A class that combines multiple pipeline block classes into a For Loop. When called, it will call each block in sequence.
|
||||
"""
|
||||
@@ -1634,7 +1616,24 @@ class LoopSequentialPipelineBlocks(ModularPipelineMixin):
|
||||
|
||||
return result
|
||||
|
||||
@torch.compiler.disable
|
||||
def progress_bar(self, iterable=None, total=None):
|
||||
if not hasattr(self, "_progress_bar_config"):
|
||||
self._progress_bar_config = {}
|
||||
elif not isinstance(self._progress_bar_config, dict):
|
||||
raise ValueError(
|
||||
f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
|
||||
)
|
||||
|
||||
if iterable is not None:
|
||||
return tqdm(iterable, **self._progress_bar_config)
|
||||
elif total is not None:
|
||||
return tqdm(total=total, **self._progress_bar_config)
|
||||
else:
|
||||
raise ValueError("Either `total` or `iterable` has to be defined.")
|
||||
|
||||
def set_progress_bar_config(self, **kwargs):
|
||||
self._progress_bar_config = kwargs
|
||||
|
||||
|
||||
# YiYi TODO:
|
||||
@@ -1750,7 +1749,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
|
||||
|
||||
# YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name
|
||||
def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], modular_repo: Optional[str] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs):
|
||||
def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], pretrained_model_name_or_path: Optional[str] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs):
|
||||
"""
|
||||
Initialize the loader with a list of component specs and config specs.
|
||||
"""
|
||||
@@ -1764,8 +1763,8 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
}
|
||||
|
||||
# update component_specs and config_specs from modular_repo
|
||||
if modular_repo is not None:
|
||||
config_dict = self.load_config(modular_repo, **kwargs)
|
||||
if pretrained_model_name_or_path is not None:
|
||||
config_dict = self.load_config(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
for name, value in config_dict.items():
|
||||
# only update component_spec for from_pretrained components
|
||||
@@ -1829,19 +1828,6 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
r"""
|
||||
Returns:
|
||||
`torch.device`: The torch device on which the pipeline is located.
|
||||
"""
|
||||
|
||||
modules = [m for m in self.components.values() if isinstance(m, torch.nn.Module)]
|
||||
|
||||
for module in modules:
|
||||
return module.device
|
||||
|
||||
return torch.device("cpu")
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
@@ -2012,9 +1998,195 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
# Register all components at once
|
||||
self.register_components(**components_to_register)
|
||||
|
||||
# YiYi TODO: should support to method
|
||||
def to(self, *args, **kwargs):
|
||||
pass
|
||||
# Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to
|
||||
def to(self, *args, **kwargs) -> Self:
|
||||
r"""
|
||||
Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the
|
||||
arguments of `self.to(*args, **kwargs).`
|
||||
|
||||
<Tip>
|
||||
|
||||
If the pipeline already has the correct torch.dtype and torch.device, then it is returned as is. Otherwise,
|
||||
the returned pipeline is a copy of self with the desired torch.dtype and torch.device.
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
Here are the ways to call `to`:
|
||||
|
||||
- `to(dtype, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified
|
||||
[`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)
|
||||
- `to(device, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified
|
||||
[`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device)
|
||||
- `to(device=None, dtype=None, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the
|
||||
specified [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) and
|
||||
[`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)
|
||||
|
||||
Arguments:
|
||||
dtype (`torch.dtype`, *optional*):
|
||||
Returns a pipeline with the specified
|
||||
[`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)
|
||||
device (`torch.Device`, *optional*):
|
||||
Returns a pipeline with the specified
|
||||
[`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device)
|
||||
silence_dtype_warnings (`str`, *optional*, defaults to `False`):
|
||||
Whether to omit warnings if the target `dtype` is not compatible with the target `device`.
|
||||
|
||||
Returns:
|
||||
[`DiffusionPipeline`]: The pipeline converted to specified `dtype` and/or `dtype`.
|
||||
"""
|
||||
dtype = kwargs.pop("dtype", None)
|
||||
device = kwargs.pop("device", None)
|
||||
silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False)
|
||||
|
||||
dtype_arg = None
|
||||
device_arg = None
|
||||
if len(args) == 1:
|
||||
if isinstance(args[0], torch.dtype):
|
||||
dtype_arg = args[0]
|
||||
else:
|
||||
device_arg = torch.device(args[0]) if args[0] is not None else None
|
||||
elif len(args) == 2:
|
||||
if isinstance(args[0], torch.dtype):
|
||||
raise ValueError(
|
||||
"When passing two arguments, make sure the first corresponds to `device` and the second to `dtype`."
|
||||
)
|
||||
device_arg = torch.device(args[0]) if args[0] is not None else None
|
||||
dtype_arg = args[1]
|
||||
elif len(args) > 2:
|
||||
raise ValueError("Please make sure to pass at most two arguments (`device` and `dtype`) `.to(...)`")
|
||||
|
||||
if dtype is not None and dtype_arg is not None:
|
||||
raise ValueError(
|
||||
"You have passed `dtype` both as an argument and as a keyword argument. Please only pass one of the two."
|
||||
)
|
||||
|
||||
dtype = dtype or dtype_arg
|
||||
|
||||
if device is not None and device_arg is not None:
|
||||
raise ValueError(
|
||||
"You have passed `device` both as an argument and as a keyword argument. Please only pass one of the two."
|
||||
)
|
||||
|
||||
device = device or device_arg
|
||||
device_type = torch.device(device).type if device is not None else None
|
||||
pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())
|
||||
|
||||
# throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
|
||||
def module_is_sequentially_offloaded(module):
|
||||
if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
|
||||
return False
|
||||
|
||||
_, _, is_loaded_in_8bit_bnb = _check_bnb_status(module)
|
||||
|
||||
if is_loaded_in_8bit_bnb:
|
||||
return False
|
||||
|
||||
return hasattr(module, "_hf_hook") and (
|
||||
isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
|
||||
or hasattr(module._hf_hook, "hooks")
|
||||
and isinstance(module._hf_hook.hooks[0], accelerate.hooks.AlignDevicesHook)
|
||||
)
|
||||
|
||||
def module_is_offloaded(module):
|
||||
if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
|
||||
return False
|
||||
|
||||
return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload)
|
||||
|
||||
# .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer
|
||||
pipeline_is_sequentially_offloaded = any(
|
||||
module_is_sequentially_offloaded(module) for _, module in self.components.items()
|
||||
)
|
||||
|
||||
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
|
||||
if is_pipeline_device_mapped:
|
||||
raise ValueError(
|
||||
"It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline."
|
||||
)
|
||||
|
||||
if device_type in ["cuda", "xpu"]:
|
||||
if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
|
||||
raise ValueError(
|
||||
"It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
|
||||
)
|
||||
# PR: https://github.com/huggingface/accelerate/pull/3223/
|
||||
elif pipeline_has_bnb and is_accelerate_version("<", "1.1.0.dev0"):
|
||||
raise ValueError(
|
||||
"You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
|
||||
)
|
||||
|
||||
# Display a warning in this case (the operation succeeds but the benefits are lost)
|
||||
pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
|
||||
if pipeline_is_offloaded and device_type in ["cuda", "xpu"]:
|
||||
logger.warning(
|
||||
f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
|
||||
)
|
||||
|
||||
# Enable generic support for Intel Gaudi accelerator using GPU/HPU migration
|
||||
if device_type == "hpu" and kwargs.pop("hpu_migration", True) and is_hpu_available():
|
||||
os.environ["PT_HPU_GPU_MIGRATION"] = "1"
|
||||
logger.debug("Environment variable set: PT_HPU_GPU_MIGRATION=1")
|
||||
|
||||
import habana_frameworks.torch # noqa: F401
|
||||
|
||||
# HPU hardware check
|
||||
if not (hasattr(torch, "hpu") and torch.hpu.is_available()):
|
||||
raise ValueError("You are trying to call `.to('hpu')` but HPU device is unavailable.")
|
||||
|
||||
os.environ["PT_HPU_MAX_COMPOUND_OP_SIZE"] = "1"
|
||||
logger.debug("Environment variable set: PT_HPU_MAX_COMPOUND_OP_SIZE=1")
|
||||
|
||||
module_names, _ = self._get_signature_keys(self)
|
||||
modules = [getattr(self, n, None) for n in module_names]
|
||||
modules = [m for m in modules if isinstance(m, torch.nn.Module)]
|
||||
|
||||
is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
|
||||
for module in modules:
|
||||
_, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module)
|
||||
is_group_offloaded = self._maybe_raise_error_if_group_offload_active(module=module)
|
||||
|
||||
if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None:
|
||||
logger.warning(
|
||||
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {'4bit' if is_loaded_in_4bit_bnb else '8bit'} and conversion to {dtype} is not supported. Module is still in {'4bit' if is_loaded_in_4bit_bnb else '8bit'} precision."
|
||||
)
|
||||
|
||||
if is_loaded_in_8bit_bnb and device is not None:
|
||||
logger.warning(
|
||||
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
|
||||
)
|
||||
|
||||
# Note: we also handle this at the ModelMixin level. The reason for doing it here too is that modeling
|
||||
# components can be from outside diffusers too, but still have group offloading enabled.
|
||||
if (
|
||||
self._maybe_raise_error_if_group_offload_active(raise_error=False, module=module)
|
||||
and device is not None
|
||||
):
|
||||
logger.warning(
|
||||
f"The module '{module.__class__.__name__}' is group offloaded and moving it to {device} via `.to()` is not supported."
|
||||
)
|
||||
|
||||
# This can happen for `transformer` models. CPU placement was added in
|
||||
# https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
|
||||
if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
|
||||
module.to(device=device)
|
||||
elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb and not is_group_offloaded:
|
||||
module.to(device, dtype)
|
||||
|
||||
if (
|
||||
module.dtype == torch.float16
|
||||
and str(device) in ["cpu"]
|
||||
and not silence_dtype_warnings
|
||||
and not is_offloaded
|
||||
):
|
||||
logger.warning(
|
||||
"Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It"
|
||||
" is not recommended to move them to `cpu` as running them will fail. Please make"
|
||||
" sure to use an accelerator to run the pipeline in inference, due to the lack of"
|
||||
" support for`float16` operations on this device in PyTorch. Please, remove the"
|
||||
" `torch_dtype=torch.float16` argument, or use another device for inference."
|
||||
)
|
||||
return self
|
||||
|
||||
# YiYi TODO:
|
||||
# 1. should support save some components too! currently only modular_model_index.json is saved
|
||||
@@ -2137,4 +2309,111 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
name=name,
|
||||
type_hint=type_hint,
|
||||
**spec_dict,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class ModularPipeline:
|
||||
"""
|
||||
Base class for all Modular pipelines.
|
||||
|
||||
Args:
|
||||
blocks: ModularPipelineBlocks, the blocks to be used in the pipeline
|
||||
loader: ModularLoader, the loader to be used in the pipeline
|
||||
"""
|
||||
|
||||
def __init__(self, blocks: ModularPipelineBlocks, loader: ModularLoader):
|
||||
self.blocks = blocks
|
||||
self.loader = loader
|
||||
|
||||
def __repr__(self):
|
||||
blocks_class = self.blocks.__class__.__name__
|
||||
loader_class = self.loader.__class__.__name__
|
||||
return f"ModularPipeline(blocks={blocks_class}, loader={loader_class})"
|
||||
|
||||
@property
|
||||
def default_call_parameters(self) -> Dict[str, Any]:
|
||||
params = {}
|
||||
for input_param in self.blocks.inputs:
|
||||
params[input_param.name] = input_param.default
|
||||
return params
|
||||
|
||||
def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs):
|
||||
"""
|
||||
Run one or more blocks in sequence, optionally you can pass a previous pipeline state.
|
||||
"""
|
||||
if state is None:
|
||||
state = PipelineState()
|
||||
|
||||
|
||||
# Make a copy of the input kwargs
|
||||
passed_kwargs = kwargs.copy()
|
||||
|
||||
|
||||
# Add inputs to state, using defaults if not provided in the kwargs or the state
|
||||
# if same input already in the state, will override it if provided in the kwargs
|
||||
|
||||
intermediates_inputs = [inp.name for inp in self.blocks.intermediates_inputs]
|
||||
for expected_input_param in self.blocks.inputs:
|
||||
name = expected_input_param.name
|
||||
default = expected_input_param.default
|
||||
kwargs_type = expected_input_param.kwargs_type
|
||||
if name in passed_kwargs:
|
||||
if name not in intermediates_inputs:
|
||||
state.add_input(name, passed_kwargs.pop(name), kwargs_type)
|
||||
else:
|
||||
state.add_input(name, passed_kwargs[name], kwargs_type)
|
||||
elif name not in state.inputs:
|
||||
state.add_input(name, default, kwargs_type)
|
||||
|
||||
for expected_intermediate_param in self.blocks.intermediates_inputs:
|
||||
name = expected_intermediate_param.name
|
||||
kwargs_type = expected_intermediate_param.kwargs_type
|
||||
if name in passed_kwargs:
|
||||
state.add_intermediate(name, passed_kwargs.pop(name), kwargs_type)
|
||||
|
||||
# Warn about unexpected inputs
|
||||
if len(passed_kwargs) > 0:
|
||||
warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.")
|
||||
# Run the pipeline
|
||||
with torch.no_grad():
|
||||
try:
|
||||
pipeline, state = self.blocks(self.loader, state)
|
||||
except Exception:
|
||||
error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n"
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
|
||||
if output is None:
|
||||
return state
|
||||
|
||||
|
||||
elif isinstance(output, str):
|
||||
return state.get_intermediate(output)
|
||||
|
||||
elif isinstance(output, (list, tuple)):
|
||||
return state.get_intermediates(output)
|
||||
else:
|
||||
raise ValueError(f"Output '{output}' is not a valid output type")
|
||||
|
||||
|
||||
def load_components(self, component_names: Optional[List[str]] = None, **kwargs):
|
||||
self.loader.load(component_names=component_names, **kwargs)
|
||||
|
||||
def update_components(self, **kwargs):
|
||||
self.loader.update(**kwargs)
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], trust_remote_code: Optional[bool] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs):
|
||||
blocks = ModularPipelineBlocks.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs)
|
||||
pipeline = blocks.init_pipeline(pretrained_model_name_or_path, component_manager=component_manager, collection=collection, **kwargs)
|
||||
return pipeline
|
||||
|
||||
def save_pretrained(self, save_directory: Optional[Union[str, os.PathLike]] = None, push_to_hub: bool = False, **kwargs):
    """Save both the blocks and the loader to the same directory.

    Args:
        save_directory: Target directory (created by the callees as needed).
        push_to_hub: Whether to also push the saved artifacts to the Hub.
        **kwargs: Forwarded unchanged to both ``save_pretrained`` calls.
            NOTE(review): the two calls write into the same directory —
            confirm their file names do not collide.
    """
    self.blocks.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
    self.loader.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
|
||||
|
||||
|
||||
@property
def doc(self):
    """Documentation string for this pipeline, delegated to its blocks."""
    return self.blocks.doc
|
||||
@@ -19,11 +19,30 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal
|
||||
|
||||
from ..utils.import_utils import is_torch_available
|
||||
from ..configuration_utils import FrozenDict, ConfigMixin
|
||||
from collections import OrderedDict
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
class InsertableOrderedDict(OrderedDict):
    """``OrderedDict`` subclass that supports inserting a key at a position."""

    def insert(self, key, value, index):
        """Place ``key: value`` at position ``index`` and return ``self``.

        If ``key`` is already present, its old entry is dropped first, so
        ``index`` is interpreted against the mapping *without* that key.
        Returning ``self`` allows method chaining.
        """
        # Snapshot current entries, excluding any stale entry for `key`.
        remaining = [(k, v) for k, v in self.items() if k != key]
        remaining.insert(index, (key, value))

        # Rebuild in the new order.
        self.clear()
        for k, v in remaining:
            self[k] = v

        return self
|
||||
|
||||
|
||||
# YiYi TODO:
|
||||
# 1. validate the dataclass fields
|
||||
# 2. add a validator for create_* methods, make sure they are valid inputs to pass to from_pretrained()
|
||||
@@ -246,7 +265,7 @@ class InputParam:
|
||||
default: Any = None
|
||||
required: bool = False
|
||||
description: str = ""
|
||||
kwargs_type: str = None # YiYi Notes: experimenting with this, not sure if we should keep it
|
||||
kwargs_type: str = None # YiYi Notes: remove this feature (maybe)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
|
||||
@@ -258,7 +277,7 @@ class OutputParam:
|
||||
name: str
|
||||
type_hint: Any = None
|
||||
description: str = ""
|
||||
kwargs_type: str = None
|
||||
kwargs_type: str = None # YiYi notes: remove this feature (maybe)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>"
|
||||
@@ -402,7 +421,9 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115):
|
||||
for param in params:
|
||||
# Format parameter name and type
|
||||
type_str = get_type_str(param.type_hint) if param.type_hint != Any else ""
|
||||
param_str = f"{param_indent}{param.name} (`{type_str}`"
|
||||
# YiYi Notes: remove this line if we remove kwargs_type
|
||||
name = f'**{param.kwargs_type}' if param.name is None and param.kwargs_type is not None else param.name
|
||||
param_str = f"{param_indent}{name} (`{type_str}`"
|
||||
|
||||
# Add optional tag and default value if parameter is an InputParam and optional
|
||||
if hasattr(param, "required"):
|
||||
|
||||
519
src/diffusers/modular_pipelines/node_utils.py
Normal file
519
src/diffusers/modular_pipelines/node_utils.py
Normal file
@@ -0,0 +1,519 @@
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from .modular_pipeline import SequentialPipelineBlocks, ModularPipelineBlocks
|
||||
from .modular_pipeline_utils import InputParam, OutputParam
|
||||
from ..image_processor import PipelineImageInput
|
||||
from pathlib import Path
|
||||
import json
|
||||
import os
|
||||
|
||||
from typing import Union, List, Optional, Tuple
|
||||
import torch
|
||||
import PIL
|
||||
import numpy as np
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# YiYi Notes: this is actually for SDXL, put it here for now
|
||||
SDXL_INPUTS_SCHEMA = {
|
||||
"prompt": InputParam("prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"),
|
||||
"prompt_2": InputParam("prompt_2", type_hint=Union[str, List[str]], description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2"),
|
||||
"negative_prompt": InputParam("negative_prompt", type_hint=Union[str, List[str]], description="The prompt or prompts not to guide the image generation"),
|
||||
"negative_prompt_2": InputParam("negative_prompt_2", type_hint=Union[str, List[str]], description="The negative prompt or prompts for text_encoder_2"),
|
||||
"cross_attention_kwargs": InputParam("cross_attention_kwargs", type_hint=Optional[dict], description="Kwargs dictionary passed to the AttentionProcessor"),
|
||||
"clip_skip": InputParam("clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"),
|
||||
"image": InputParam("image", type_hint=PipelineImageInput, required=True, description="The image(s) to modify for img2img or inpainting"),
|
||||
"mask_image": InputParam("mask_image", type_hint=PipelineImageInput, required=True, description="Mask image for inpainting, white pixels will be repainted"),
|
||||
"generator": InputParam("generator", type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], description="Generator(s) for deterministic generation"),
|
||||
"height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"),
|
||||
"width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"),
|
||||
"num_images_per_prompt": InputParam("num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"),
|
||||
"num_inference_steps": InputParam("num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"),
|
||||
"timesteps": InputParam("timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"),
|
||||
"sigmas": InputParam("sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"),
|
||||
"denoising_end": InputParam("denoising_end", type_hint=Optional[float], description="Fraction of denoising process to complete before termination"),
|
||||
# YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999
|
||||
"strength": InputParam("strength", type_hint=float, default=0.3, description="How much to transform the reference image"),
|
||||
"denoising_start": InputParam("denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"),
|
||||
"latents": InputParam("latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"),
|
||||
"padding_mask_crop": InputParam("padding_mask_crop", type_hint=Optional[Tuple[int, int]], description="Size of margin in crop for image and mask"),
|
||||
"original_size": InputParam("original_size", type_hint=Optional[Tuple[int, int]], description="Original size of the image for SDXL's micro-conditioning"),
|
||||
"target_size": InputParam("target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"),
|
||||
"negative_original_size": InputParam("negative_original_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on image resolution"),
|
||||
"negative_target_size": InputParam("negative_target_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on target resolution"),
|
||||
"crops_coords_top_left": InputParam("crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Top-left coordinates for SDXL's micro-conditioning"),
|
||||
"negative_crops_coords_top_left": InputParam("negative_crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Negative conditioning crop coordinates"),
|
||||
"aesthetic_score": InputParam("aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"),
|
||||
"negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"),
|
||||
"eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"),
|
||||
"output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"),
|
||||
"ip_adapter_image": InputParam("ip_adapter_image", type_hint=PipelineImageInput, required=True, description="Image(s) to be used as IP adapter"),
|
||||
"control_image": InputParam("control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"),
|
||||
"control_guidance_start": InputParam("control_guidance_start", type_hint=Union[float, List[float]], default=0.0, description="When ControlNet starts applying"),
|
||||
"control_guidance_end": InputParam("control_guidance_end", type_hint=Union[float, List[float]], default=1.0, description="When ControlNet stops applying"),
|
||||
"controlnet_conditioning_scale": InputParam("controlnet_conditioning_scale", type_hint=Union[float, List[float]], default=1.0, description="Scale factor for ControlNet outputs"),
|
||||
"guess_mode": InputParam("guess_mode", type_hint=bool, default=False, description="Enables ControlNet encoder to recognize input without prompts"),
|
||||
"control_mode": InputParam("control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet")
|
||||
}
|
||||
|
||||
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
|
||||
"prompt_embeds": InputParam("prompt_embeds", type_hint=torch.Tensor, required=True, description="Text embeddings used to guide image generation"),
|
||||
"negative_prompt_embeds": InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"),
|
||||
"pooled_prompt_embeds": InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"),
|
||||
"negative_pooled_prompt_embeds": InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"),
|
||||
"batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"),
|
||||
"dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
|
||||
"preprocess_kwargs": InputParam("preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"),
|
||||
"latents": InputParam("latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"),
|
||||
"timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
|
||||
"num_inference_steps": InputParam("num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"),
|
||||
"latent_timestep": InputParam("latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"),
|
||||
"image_latents": InputParam("image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"),
|
||||
"mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"),
|
||||
"masked_image_latents": InputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"),
|
||||
"add_time_ids": InputParam("add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"),
|
||||
"negative_add_time_ids": InputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"),
|
||||
"timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
|
||||
"noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
|
||||
"crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
|
||||
"ip_adapter_embeds": InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"),
|
||||
"negative_ip_adapter_embeds": InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"),
|
||||
"images": InputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], required=True, description="Generated images")
|
||||
}
|
||||
|
||||
SDXL_PARAM_SCHEMA = {**SDXL_INPUTS_SCHEMA, **SDXL_INTERMEDIATE_INPUTS_SCHEMA}
|
||||
|
||||
|
||||
DEFAULT_PARAM_MAPS = {
|
||||
"prompt": {
|
||||
"label": "Prompt",
|
||||
"type": "string",
|
||||
"default": "a bear sitting in a chair drinking a milkshake",
|
||||
"display": "textarea",
|
||||
},
|
||||
"negative_prompt": {
|
||||
"label": "Negative Prompt",
|
||||
"type": "string",
|
||||
"default": "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
|
||||
"display": "textarea",
|
||||
},
|
||||
|
||||
"num_inference_steps": {
|
||||
"label": "Steps",
|
||||
"type": "int",
|
||||
"default": 25,
|
||||
"min": 1,
|
||||
"max": 1000,
|
||||
},
|
||||
"seed": {
|
||||
"label": "Seed",
|
||||
"type": "int",
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"display": "random",
|
||||
},
|
||||
"width": {
|
||||
"label": "Width",
|
||||
"type": "int",
|
||||
"display": "text",
|
||||
"default": 1024,
|
||||
"min": 8,
|
||||
"max": 8192,
|
||||
"step": 8,
|
||||
"group": "dimensions",
|
||||
},
|
||||
"height": {
|
||||
"label": "Height",
|
||||
"type": "int",
|
||||
"display": "text",
|
||||
"default": 1024,
|
||||
"min": 8,
|
||||
"max": 8192,
|
||||
"step": 8,
|
||||
"group": "dimensions",
|
||||
},
|
||||
"images": {
|
||||
"label": "Images",
|
||||
"type": "image",
|
||||
"display": "output",
|
||||
},
|
||||
"image": {
|
||||
"label": "Image",
|
||||
"type": "image",
|
||||
"display": "input",
|
||||
},
|
||||
}
|
||||
|
||||
DEFAULT_TYPE_MAPS ={
|
||||
"int": {
|
||||
"type": "int",
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
},
|
||||
"float": {
|
||||
"type": "float",
|
||||
"default": 0.0,
|
||||
"min": 0.0,
|
||||
},
|
||||
"str": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
},
|
||||
"bool": {
|
||||
"type": "boolean",
|
||||
"default": False,
|
||||
},
|
||||
"image": {
|
||||
"type": "image",
|
||||
},
|
||||
}
|
||||
|
||||
DEFAULT_MODEL_KEYS = ["unet", "vae", "text_encoder", "tokenizer", "controlnet", "transformer", "image_encoder"]
DEFAULT_CATEGORY = "Modular Diffusers"
DEFAULT_EXCLUDE_MODEL_KEYS = ["processor", "feature_extractor", "safety_checker"]
DEFAULT_PARAMS_GROUPS_KEYS = {
    "text_encoders": ["text_encoder", "tokenizer"],
    "ip_adapter_embeds": ["ip_adapter_embeds"],
    "prompt_embeddings": ["prompt_embeds"],
}


def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS):
    """Return the param group that ``name`` belongs to, or ``None``.

    Membership is substring based: ``name`` belongs to the first group (in
    dict order) whose key list contains a substring of ``name``.
    e.g. "prompt_embeds" -> "prompt_embeddings", "text_encoder" ->
    "text_encoders", "prompt" -> None.
    """
    if name is None:
        return None
    for group_name, group_keys in group_params_keys.items():
        if any(group_key in name for group_key in group_keys):
            return group_name
    return None
|
||||
|
||||
|
||||
class ModularNode(ConfigMixin):
    """Wrap a `ModularPipelineBlocks` object as a UI node definition.

    The constructor inspects the blocks' inputs, expected components and
    outputs, builds three param tables (``input_params``, ``component_params``,
    ``output_params``) plus a ``name_mapping`` (blocks param name -> mellon
    param name), and stores everything in the config. `mellon_config` renders
    those tables into a Mellon node definition dict; `process_inputs` /
    `execute` translate Mellon call kwargs back into a blocks run.
    """

    # Filename used by ConfigMixin when saving/loading this node's config.
    config_name = "node_config.json"

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        trust_remote_code: Optional[bool] = None,
        **kwargs,
    ):
        """Load blocks from a hub repo / local path and wrap them in a node.

        NOTE(review): ``kwargs`` is forwarded to both
        ``ModularPipelineBlocks.from_pretrained`` and ``cls(blocks, **kwargs)``
        — confirm the two consumers accept disjoint keys.
        """
        blocks = ModularPipelineBlocks.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs)
        return cls(blocks, **kwargs)

    def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs):
        """Build the node's param tables from ``blocks`` and register them.

        Args:
            blocks: The `ModularPipelineBlocks` to expose as a node.
            category: Node category shown in the UI.
            label: Node label; defaults to the blocks' class name.
            **kwargs: Optional per-param override dicts keyed by the
                input/component/output name (see comment below). Any keys left
                over after all three loops are logged as unused.
        """
        self.blocks = blocks

        if label is None:
            label = self.blocks.__class__.__name__
        # blocks param name -> mellon param name
        self.name_mapping = {}

        input_params = {}
        # pass or create a default param dict for each input
        # e.g. for prompt,
        # prompt = {
        #     "name": "text_input",  # the name of the input in the node definition; may differ from the diffusers input name
        #     "label": "Prompt",
        #     "type": "string",
        #     "default": "a bear sitting in a chair drinking a milkshake",
        #     "display": "textarea"}
        # if type is not specified, it'll be a "custom" param of its own type
        # e.g. you can pass ModularNode(scheduler = {name: "scheduler"})
        # it will get this spec in the node definition: {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}}
        # name can be a dict; in that case it is part of a "dict" input in mellon nodes, e.g. text_encoder = {name: {"text_encoders": "text_encoder"}}
        inputs = self.blocks.inputs + self.blocks.intermediates_inputs
        for inp in inputs:
            param = kwargs.pop(inp.name, None)
            if param:
                # user can pass a param dict for any input, e.g. ModularNode(prompt = {...});
                # it is used verbatim, with an optional "name" key remapping the mellon name
                input_params[inp.name] = param
                mellon_name = param.pop("name", inp.name)
                if mellon_name != inp.name:
                    self.name_mapping[inp.name] = mellon_name
                continue

            # Skip inputs that are neither known defaults, required, nor part of a group.
            if not inp.name in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name):
                continue

            if inp.name in DEFAULT_PARAM_MAPS:
                # first check if it's in the default param map, if so, directly use that
                param = DEFAULT_PARAM_MAPS[inp.name].copy()
            elif get_group_name(inp.name):
                # grouped params are stored as a bare group-name string; the
                # group name doubles as the mellon name in name_mapping
                param = get_group_name(inp.name)
                if inp.name not in self.name_mapping:
                    self.name_mapping[inp.name] = param
            else:
                # if not, check the type hint (falling back to the SDXL input schema):
                # 1. use the type hint to determine the type
                # 2. use the default param dict for the type, e.g. if "steps" is an "int" type, {"steps": {"type": "int", "default": 0, "min": 0}}
                if inp.type_hint is not None:
                    type_str = str(inp.type_hint).lower()
                else:
                    inp_spec = SDXL_PARAM_SCHEMA.get(inp.name, None)
                    type_str = str(inp_spec.type_hint).lower() if inp_spec else ""
                # for-else: if no known type matches, fall back to a custom
                # param represented by the bare input name
                for type_key, type_param in DEFAULT_TYPE_MAPS.items():
                    if type_key in type_str:
                        param = type_param.copy()
                        param["label"] = inp.name
                        param["display"] = "input"
                        break
                else:
                    param = inp.name
            # add the param dict to the input_params dict
            input_params[inp.name] = param

        component_params = {}
        for comp in self.blocks.expected_components:
            param = kwargs.pop(comp.name, None)
            if param:
                # user-provided override, same convention as for inputs
                component_params[comp.name] = param
                mellon_name = param.pop("name", comp.name)
                if mellon_name != comp.name:
                    self.name_mapping[comp.name] = mellon_name
                continue

            # Drop components that should never appear in the UI (processors, safety checkers, ...).
            to_exclude = False
            for exclude_key in DEFAULT_EXCLUDE_MODEL_KEYS:
                if exclude_key in comp.name:
                    to_exclude = True
                    break
            if to_exclude:
                continue

            if get_group_name(comp.name):
                param = get_group_name(comp.name)
                if comp.name not in self.name_mapping:
                    self.name_mapping[comp.name] = param
            elif comp.name in DEFAULT_MODEL_KEYS:
                # standard diffusers model components get an auto-model input slot
                param = {"label": comp.name, "type": "diffusers_auto_model", "display": "input"}
            else:
                param = comp.name
            # add the param dict to the component_params dict
            component_params[comp.name] = param

        output_params = {}
        # For sequential blocks, only the last block's intermediate outputs are
        # node outputs; otherwise use the blocks' own intermediate outputs.
        if isinstance(self.blocks, SequentialPipelineBlocks):
            last_block_name = list(self.blocks.blocks.keys())[-1]
            outputs = self.blocks.blocks[last_block_name].intermediates_outputs
        else:
            outputs = self.blocks.intermediates_outputs

        for out in outputs:
            param = kwargs.pop(out.name, None)
            if param:
                # user-provided override, same convention as for inputs
                output_params[out.name] = param
                mellon_name = param.pop("name", out.name)
                if mellon_name != out.name:
                    self.name_mapping[out.name] = mellon_name
                continue

            if out.name in DEFAULT_PARAM_MAPS:
                param = DEFAULT_PARAM_MAPS[out.name].copy()
                param["display"] = "output"
            else:
                group_name = get_group_name(out.name)
                if group_name:
                    param = group_name
                    if out.name not in self.name_mapping:
                        self.name_mapping[out.name] = param
                else:
                    param = out.name
            # add the param dict to the output_params dict
            output_params[out.name] = param

        # Anything still in kwargs did not match an input/component/output name.
        if len(kwargs) > 0:
            logger.warning(f"Unused kwargs: {kwargs}")

        register_dict = {
            "category": category,
            "label": label,
            "input_params": input_params,
            "component_params": component_params,
            "output_params": output_params,
            "name_mapping": self.name_mapping,
        }
        self.register_to_config(**register_dict)

    def setup(self, components, collection=None):
        """Attach a components manager (and optional collection) to the blocks' loader."""
        self.blocks.setup_loader(component_manager=components, collection=collection)
        self._components_manager = components

    @property
    def mellon_config(self):
        """Mellon node definition dict rendered from this node's config."""
        return self._convert_to_mellon_config()

    def _convert_to_mellon_config(self):
        """Render the registered param tables into a Mellon node dict.

        Bare-string params are expanded into ``{"label", "type", "display"}``
        dicts; dict params are used as-is. When several params map to the same
        mellon name, the first one wins and later ones are skipped with a
        debug log.
        """
        node = {}
        node["label"] = self.config.label
        node["category"] = self.config.category

        node_param = {}
        for inp_name, inp_param in self.config.input_params.items():
            if inp_name in self.name_mapping:
                mellon_name = self.name_mapping[inp_name]
            else:
                mellon_name = inp_name
            if isinstance(inp_param, str):
                # bare string => custom param whose type is its own name
                param = {
                    "label": inp_param,
                    "type": inp_param,
                    "display": "input",
                }
            else:
                param = inp_param

            if mellon_name not in node_param:
                node_param[mellon_name] = param
            else:
                logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}")

        for comp_name, comp_param in self.config.component_params.items():
            if comp_name in self.name_mapping:
                mellon_name = self.name_mapping[comp_name]
            else:
                mellon_name = comp_name
            if isinstance(comp_param, str):
                param = {
                    "label": comp_param,
                    "type": comp_param,
                    "display": "input",
                }
            else:
                param = comp_param

            if mellon_name not in node_param:
                node_param[mellon_name] = param
            else:
                logger.debug(f"Component param {comp_param} already exists in node_param, skipping {comp_name}")

        for out_name, out_param in self.config.output_params.items():
            if out_name in self.name_mapping:
                mellon_name = self.name_mapping[out_name]
            else:
                mellon_name = out_name
            if isinstance(out_param, str):
                param = {
                    "label": out_param,
                    "type": out_param,
                    "display": "output",
                }
            else:
                param = out_param

            if mellon_name not in node_param:
                node_param[mellon_name] = param
            else:
                logger.debug(f"Output param {out_param} already exists in node_param, skipping {out_name}")
        node["params"] = node_param
        return node

    def save_mellon_config(self, file_path):
        """
        Save the Mellon configuration to a JSON file.

        Args:
            file_path (str or Path): Path where the JSON file will be saved

        Returns:
            Path: Path to the saved config file
        """
        file_path = Path(file_path)

        # Create directory if it doesn't exist
        os.makedirs(file_path.parent, exist_ok=True)

        # Create a combined dictionary with module definition and name mapping
        config = {
            "module": self.mellon_config,
            "name_mapping": self.name_mapping
        }

        # Save the config to file
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2)

        logger.info(f"Mellon config and name mapping saved to {file_path}")

        return file_path

    @classmethod
    def load_mellon_config(cls, file_path):
        """
        Load a Mellon configuration from a JSON file.

        Args:
            file_path (str or Path): Path to the JSON file containing Mellon config

        Returns:
            dict: The loaded combined configuration containing 'module' and 'name_mapping'
        """
        file_path = Path(file_path)

        if not file_path.exists():
            raise FileNotFoundError(f"Config file not found: {file_path}")

        with open(file_path, 'r', encoding='utf-8') as f:
            config = json.load(f)

        logger.info(f"Mellon config loaded from {file_path}")

        return config

    def process_inputs(self, **kwargs):
        """Split Mellon call kwargs into components, run params, and output names.

        Components referenced by ``{"model_id": ...}`` dicts are resolved via
        the components manager (requires `setup` to have been called first).
        Grouped params may arrive as a dict keyed by the blocks param name, in
        which case the inner value is extracted. ``kwargs`` is consumed
        (popped) as params are matched.

        Returns:
            Tuple of (params_components, params_run, return_output_names).
        """
        params_components = {}
        for comp_name, comp_param in self.config.component_params.items():
            logger.debug(f"component: {comp_name}")
            mellon_comp_name = self.name_mapping.get(comp_name, comp_name)
            if mellon_comp_name in kwargs:
                if isinstance(kwargs[mellon_comp_name], dict) and comp_name in kwargs[mellon_comp_name]:
                    # grouped input: pull this component out of the group dict
                    comp = kwargs[mellon_comp_name].pop(comp_name)
                else:
                    comp = kwargs.pop(mellon_comp_name)
                if comp:
                    # assumes comp carries a "model_id" key understood by the
                    # components manager — TODO confirm against Mellon's payloads
                    params_components[comp_name] = self._components_manager.get_one(comp["model_id"])

        params_run = {}
        for inp_name, inp_param in self.config.input_params.items():
            logger.debug(f"input: {inp_name}")
            mellon_inp_name = self.name_mapping.get(inp_name, inp_name)
            if mellon_inp_name in kwargs:
                if isinstance(kwargs[mellon_inp_name], dict) and inp_name in kwargs[mellon_inp_name]:
                    inp = kwargs[mellon_inp_name].pop(inp_name)
                else:
                    inp = kwargs.pop(mellon_inp_name)
                if inp is not None:
                    params_run[inp_name] = inp

        # All registered outputs are requested from the blocks run.
        return_output_names = list(self.config.output_params.keys())

        return params_components, params_run, return_output_names

    def execute(self, **kwargs):
        """Resolve kwargs, update the loader with components, and run the blocks."""
        params_components, params_run, return_output_names = self.process_inputs(**kwargs)

        self.blocks.loader.update(**params_components)
        output = self.blocks.run(**params_run, output=return_output_names)
        return output
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ else:
|
||||
_import_structure["modular_loader"] = ["StableDiffusionXLModularLoader"]
|
||||
_import_structure["encoders"] = ["StableDiffusionXLAutoIPAdapterStep", "StableDiffusionXLTextEncoderStep", "StableDiffusionXLAutoVaeEncoderStep"]
|
||||
_import_structure["decoders"] = ["StableDiffusionXLAutoDecodeStep"]
|
||||
_import_structure["modular_block_mappings"] = ["TEXT2IMAGE_BLOCKS", "IMAGE2IMAGE_BLOCKS", "INPAINT_BLOCKS", "CONTROLNET_BLOCKS", "CONTROLNET_UNION_BLOCKS", "IP_ADAPTER_BLOCKS", "AUTO_BLOCKS", "SDXL_SUPPORTED_BLOCKS"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
@@ -37,6 +38,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .modular_loader import StableDiffusionXLModularLoader
|
||||
from .encoders import StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoVaeEncoderStep
|
||||
from .decoders import StableDiffusionXLAutoDecodeStep
|
||||
from .modular_block_mappings import SDXL_SUPPORTED_BLOCKS, TEXT2IMAGE_BLOCKS, IMAGE2IMAGE_BLOCKS, INPAINT_BLOCKS, CONTROLNET_BLOCKS, CONTROLNET_UNION_BLOCKS, IP_ADAPTER_BLOCKS, AUTO_BLOCKS
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ class StableDiffusionXLDenoiseLoopBeforeDenoiser(PipelineBlock):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "step within the denoising loop that prepare the latent input for the denoiser"
|
||||
return "step within the denoising loop that prepare the latent input for the denoiser. Used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object, e.g. `StableDiffusionXLDenoiseLoopWrapper`"
|
||||
|
||||
|
||||
@property
|
||||
@@ -89,7 +89,7 @@ class StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser(PipelineBlock):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "step within the denoising loop that prepare the latent input for the denoiser"
|
||||
return "step within the denoising loop that prepare the latent input for the denoiser (for inpainting workflow only). Used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object, e.g. `StableDiffusionXLDenoiseLoopWrapper`"
|
||||
|
||||
|
||||
@property
|
||||
@@ -165,7 +165,7 @@ class StableDiffusionXLDenoiseLoopDenoiser(PipelineBlock):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Step within the denoising loop that denoise the latents with guidance"
|
||||
"Step within the denoising loop that denoise the latents with guidance. Used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object, e.g. `StableDiffusionXLDenoiseLoopWrapper`"
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -269,7 +269,7 @@ class StableDiffusionXLControlNetDenoiseLoopDenoiser(PipelineBlock):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process"
|
||||
return "step within the denoising loop that denoise the latents with guidance (with controlnet). Used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object, e.g. `StableDiffusionXLDenoiseLoopWrapper`"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
@@ -458,7 +458,7 @@ class StableDiffusionXLDenoiseLoopAfterDenoiser(PipelineBlock):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process"
|
||||
return "step within the denoising loop that update the latents. Used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object, e.g. `StableDiffusionXLDenoiseLoopWrapper`"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
@@ -521,7 +521,7 @@ class StableDiffusionXLInpaintDenoiseLoopAfterDenoiser(PipelineBlock):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process"
|
||||
return "step within the denoising loop that update the latents (for inpainting workflow only). Used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object, e.g. `StableDiffusionXLDenoiseLoopWrapper`"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
@@ -622,7 +622,7 @@ class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process"
|
||||
"Pipeline block that iteratively denoise the latents over `timesteps`. The specific steps with each iteration can be customized with `blocks` attributes"
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -683,21 +683,60 @@ class StableDiffusionXLDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper):
|
||||
block_classes = [StableDiffusionXLDenoiseLoopBeforeDenoiser, StableDiffusionXLDenoiseLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser]
|
||||
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. "
|
||||
"Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n"
|
||||
"and at each iteration, it runs blocks defined in `blocks` sequencially:\n"
|
||||
" - `StableDiffusionXLDenoiseLoopBeforeDenoiser`\n"
|
||||
" - `StableDiffusionXLDenoiseLoopDenoiser`\n"
|
||||
" - `StableDiffusionXLDenoiseLoopAfterDenoiser`\n"
|
||||
)
|
||||
|
||||
# control_cond
|
||||
class StableDiffusionXLControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper):
|
||||
block_classes = [StableDiffusionXLDenoiseLoopBeforeDenoiser, StableDiffusionXLControlNetDenoiseLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser]
|
||||
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents with controlnet. \n"
|
||||
"Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n"
|
||||
"and at each iteration, it runs blocks defined in `blocks` sequencially:\n"
|
||||
" - `StableDiffusionXLDenoiseLoopBeforeDenoiser`\n"
|
||||
" - `StableDiffusionXLControlNetDenoiseLoopDenoiser`\n"
|
||||
" - `StableDiffusionXLDenoiseLoopAfterDenoiser`\n"
|
||||
)
|
||||
|
||||
# mask
|
||||
class StableDiffusionXLInpaintDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper):
|
||||
block_classes = [StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser, StableDiffusionXLDenoiseLoopDenoiser, StableDiffusionXLInpaintDenoiseLoopAfterDenoiser]
|
||||
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents(for inpainting task only). \n"
|
||||
"Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n"
|
||||
"and at each iteration, it runs blocks defined in `blocks` sequencially:\n"
|
||||
" - `StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser`\n"
|
||||
" - `StableDiffusionXLDenoiseLoopDenoiser`\n"
|
||||
" - `StableDiffusionXLInpaintDenoiseLoopAfterDenoiser`\n"
|
||||
)
|
||||
# control_cond + mask
|
||||
class StableDiffusionXLInpaintControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper):
|
||||
block_classes = [StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser, StableDiffusionXLControlNetDenoiseLoopDenoiser, StableDiffusionXLInpaintDenoiseLoopAfterDenoiser]
|
||||
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents(for inpainting task only) with controlnet. \n"
|
||||
"Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n"
|
||||
"and at each iteration, it runs blocks defined in `blocks` sequencially:\n"
|
||||
" - `StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser`\n"
|
||||
" - `StableDiffusionXLControlNetDenoiseLoopDenoiser`\n"
|
||||
" - `StableDiffusionXLInpaintDenoiseLoopAfterDenoiser`\n"
|
||||
)
|
||||
|
||||
|
||||
# all task without controlnet
|
||||
@@ -706,18 +745,45 @@ class StableDiffusionXLDenoiseStep(AutoPipelineBlocks):
|
||||
block_names = ["inpaint_denoise", "denoise"]
|
||||
block_trigger_inputs = ["mask", None]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. "
|
||||
"This is a auto pipeline block that works for text2img, img2img and inpainting tasks."
|
||||
" - `StableDiffusionXLDenoiseStep` (denoise) is used when no mask is provided."
|
||||
" - `StableDiffusionXLInpaintDenoiseStep` (inpaint_denoise) is used when mask is provided."
|
||||
)
|
||||
|
||||
# all task with controlnet
|
||||
class StableDiffusionXLControlNetDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [StableDiffusionXLInpaintControlNetDenoiseLoop, StableDiffusionXLControlNetDenoiseLoop]
|
||||
block_names = ["inpaint_controlnet_denoise", "controlnet_denoise"]
|
||||
block_trigger_inputs = ["mask", None]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents with controlnet. "
|
||||
"This is a auto pipeline block that works for text2img, img2img and inpainting tasks."
|
||||
" - `StableDiffusionXLControlNetDenoiseStep` (controlnet_denoise) is used when no mask is provided."
|
||||
" - `StableDiffusionXLInpaintControlNetDenoiseStep` (inpaint_controlnet_denoise) is used when mask is provided."
|
||||
)
|
||||
|
||||
# all task with or without controlnet
|
||||
class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDenoiseStep]
|
||||
block_names = ["controlnet_denoise", "denoise"]
|
||||
block_trigger_inputs = ["controlnet_cond", None]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. "
|
||||
"This is a auto pipeline block that works for text2img, img2img and inpainting tasks. And can be used with or without controlnet."
|
||||
" - `StableDiffusionXLDenoiseStep` (denoise) is used when no controlnet_cond is provided (work for text2img, img2img and inpainting tasks)."
|
||||
" - `StableDiffusionXLControlNetDenoiseStep` (controlnet_denoise) is used when controlnet_cond is provided (work for text2img, img2img and inpainting tasks)."
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -12,13 +12,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import OrderedDict
|
||||
from ..modular_pipeline_utils import InsertableOrderedDict
|
||||
|
||||
# Import all the necessary block classes
|
||||
from .denoise import (
|
||||
StableDiffusionXLAutoDenoiseStep,
|
||||
StableDiffusionXLDenoiseStep,
|
||||
StableDiffusionXLControlNetDenoiseStep
|
||||
StableDiffusionXLControlNetDenoiseStep,
|
||||
StableDiffusionXLDenoiseLoop,
|
||||
StableDiffusionXLInpaintDenoiseLoop
|
||||
)
|
||||
from .before_denoise import (
|
||||
StableDiffusionXLAutoBeforeDenoiseStep,
|
||||
@@ -50,56 +51,53 @@ from .decoders import (
|
||||
|
||||
# YiYi notes: comment out for now, work on this later
|
||||
# block mapping
|
||||
TEXT2IMAGE_BLOCKS = OrderedDict([
|
||||
TEXT2IMAGE_BLOCKS = InsertableOrderedDict([
|
||||
("text_encoder", StableDiffusionXLTextEncoderStep),
|
||||
("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
|
||||
("input", StableDiffusionXLInputStep),
|
||||
("set_timesteps", StableDiffusionXLSetTimestepsStep),
|
||||
("prepare_latents", StableDiffusionXLPrepareLatentsStep),
|
||||
("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep),
|
||||
("denoise", StableDiffusionXLDenoiseStep),
|
||||
("denoise", StableDiffusionXLDenoiseLoop),
|
||||
("decode", StableDiffusionXLDecodeStep)
|
||||
])
|
||||
|
||||
IMAGE2IMAGE_BLOCKS = OrderedDict([
|
||||
IMAGE2IMAGE_BLOCKS = InsertableOrderedDict([
|
||||
("text_encoder", StableDiffusionXLTextEncoderStep),
|
||||
("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
|
||||
("image_encoder", StableDiffusionXLVaeEncoderStep),
|
||||
("input", StableDiffusionXLInputStep),
|
||||
("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
|
||||
("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
|
||||
("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
|
||||
("denoise", StableDiffusionXLDenoiseStep),
|
||||
("denoise", StableDiffusionXLDenoiseLoop),
|
||||
("decode", StableDiffusionXLDecodeStep)
|
||||
])
|
||||
|
||||
INPAINT_BLOCKS = OrderedDict([
|
||||
INPAINT_BLOCKS = InsertableOrderedDict([
|
||||
("text_encoder", StableDiffusionXLTextEncoderStep),
|
||||
("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
|
||||
("image_encoder", StableDiffusionXLInpaintVaeEncoderStep),
|
||||
("input", StableDiffusionXLInputStep),
|
||||
("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
|
||||
("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep),
|
||||
("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
|
||||
("denoise", StableDiffusionXLDenoiseStep),
|
||||
("denoise", StableDiffusionXLInpaintDenoiseLoop),
|
||||
("decode", StableDiffusionXLInpaintDecodeStep)
|
||||
])
|
||||
|
||||
CONTROLNET_BLOCKS = OrderedDict([
|
||||
CONTROLNET_BLOCKS = InsertableOrderedDict([
|
||||
("controlnet_input", StableDiffusionXLControlNetInputStep),
|
||||
("denoise", StableDiffusionXLControlNetDenoiseStep),
|
||||
])
|
||||
|
||||
CONTROLNET_UNION_BLOCKS = OrderedDict([
|
||||
CONTROLNET_UNION_BLOCKS = InsertableOrderedDict([
|
||||
("controlnet_input", StableDiffusionXLControlNetUnionInputStep),
|
||||
("denoise", StableDiffusionXLControlNetDenoiseStep),
|
||||
])
|
||||
|
||||
IP_ADAPTER_BLOCKS = OrderedDict([
|
||||
IP_ADAPTER_BLOCKS = InsertableOrderedDict([
|
||||
("ip_adapter", StableDiffusionXLIPAdapterStep),
|
||||
])
|
||||
|
||||
AUTO_BLOCKS = OrderedDict([
|
||||
AUTO_BLOCKS = InsertableOrderedDict([
|
||||
("text_encoder", StableDiffusionXLTextEncoderStep),
|
||||
("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
|
||||
("image_encoder", StableDiffusionXLAutoVaeEncoderStep),
|
||||
@@ -108,11 +106,6 @@ AUTO_BLOCKS = OrderedDict([
|
||||
("decode", StableDiffusionXLAutoDecodeStep)
|
||||
])
|
||||
|
||||
AUTO_CORE_BLOCKS = OrderedDict([
|
||||
("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep),
|
||||
("denoise", StableDiffusionXLAutoDenoiseStep),
|
||||
])
|
||||
|
||||
|
||||
SDXL_SUPPORTED_BLOCKS = {
|
||||
"text2img": TEXT2IMAGE_BLOCKS,
|
||||
@@ -15,13 +15,16 @@
|
||||
"""Utilities to dynamically load objects from the Hub."""
|
||||
|
||||
import importlib
|
||||
import signal
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Dict, Optional, Union
|
||||
from urllib import request
|
||||
|
||||
@@ -37,6 +40,8 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# See https://huggingface.co/datasets/diffusers/community-pipelines-mirror
|
||||
COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror"
|
||||
TIME_OUT_REMOTE_CODE = int(os.getenv("DIFFUSERS_TIMEOUT_REMOTE_CODE", 15))
|
||||
_HF_REMOTE_CODE_LOCK = threading.Lock()
|
||||
|
||||
|
||||
def get_diffusers_versions():
|
||||
@@ -154,15 +159,87 @@ def check_imports(filename):
|
||||
return get_relative_imports(filename)
|
||||
|
||||
|
||||
def get_class_in_module(class_name, module_path):
|
||||
def _raise_timeout_error(signum, frame):
|
||||
raise ValueError(
|
||||
"Loading this model requires you to execute custom code contained in the model repository on your local "
|
||||
"machine. Please set the option `trust_remote_code=True` to permit loading of this model."
|
||||
)
|
||||
|
||||
|
||||
def resolve_trust_remote_code(trust_remote_code, model_name, has_remote_code):
|
||||
if trust_remote_code is None:
|
||||
if has_remote_code and TIME_OUT_REMOTE_CODE > 0:
|
||||
prev_sig_handler = None
|
||||
try:
|
||||
prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
|
||||
signal.alarm(TIME_OUT_REMOTE_CODE)
|
||||
while trust_remote_code is None:
|
||||
answer = input(
|
||||
f"The repository for {model_name} contains custom code which must be executed to correctly "
|
||||
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
|
||||
f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
|
||||
f"Do you wish to run the custom code? [y/N] "
|
||||
)
|
||||
if answer.lower() in ["yes", "y", "1"]:
|
||||
trust_remote_code = True
|
||||
elif answer.lower() in ["no", "n", "0", ""]:
|
||||
trust_remote_code = False
|
||||
signal.alarm(0)
|
||||
except Exception:
|
||||
# OS which does not support signal.SIGALRM
|
||||
raise ValueError(
|
||||
f"The repository for {model_name} contains custom code which must be executed to correctly "
|
||||
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
|
||||
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
|
||||
)
|
||||
finally:
|
||||
if prev_sig_handler is not None:
|
||||
signal.signal(signal.SIGALRM, prev_sig_handler)
|
||||
signal.alarm(0)
|
||||
elif has_remote_code:
|
||||
# For the CI which puts the timeout at 0
|
||||
_raise_timeout_error(None, None)
|
||||
|
||||
if has_remote_code and not trust_remote_code:
|
||||
raise ValueError(
|
||||
f"Loading {model_name} requires you to execute the configuration file in that"
|
||||
" repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
|
||||
" set the option `trust_remote_code=True` to remove this error."
|
||||
)
|
||||
|
||||
return trust_remote_code
|
||||
|
||||
|
||||
def get_class_in_module(class_name, module_path, force_reload=False):
|
||||
"""
|
||||
Import a module on the cache directory for modules and extract a class from it.
|
||||
"""
|
||||
module_path = module_path.replace(os.path.sep, ".")
|
||||
module = importlib.import_module(module_path)
|
||||
name = os.path.normpath(module_path)
|
||||
if name.endswith(".py"):
|
||||
name = name[:-3]
|
||||
name = name.replace(os.path.sep, ".")
|
||||
module_file: Path = Path(HF_MODULES_CACHE) / module_path
|
||||
|
||||
with _HF_REMOTE_CODE_LOCK:
|
||||
if force_reload:
|
||||
sys.modules.pop(name, None)
|
||||
importlib.invalidate_caches()
|
||||
cached_module: Optional[ModuleType] = sys.modules.get(name)
|
||||
module_spec = importlib.util.spec_from_file_location(name, location=module_file)
|
||||
|
||||
module: ModuleType
|
||||
if cached_module is None:
|
||||
module = importlib.util.module_from_spec(module_spec)
|
||||
# insert it into sys.modules before any loading begins
|
||||
sys.modules[name] = module
|
||||
else:
|
||||
module = cached_module
|
||||
|
||||
module_spec.loader.exec_module(module)
|
||||
|
||||
if class_name is None:
|
||||
return find_pipeline_class(module)
|
||||
|
||||
return getattr(module, class_name)
|
||||
|
||||
|
||||
@@ -454,4 +531,4 @@ def get_class_from_dynamic_module(
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
return get_class_in_module(class_name, final_module.replace(".py", ""))
|
||||
return get_class_in_module(class_name, final_module)
|
||||
|
||||
Reference in New Issue
Block a user