Compare commits

...

18 Commits

Author SHA1 Message Date
sayakpaul
a88d11bc90 resolve conflicts. 2025-11-06 10:29:24 +05:30
Sayak Paul
a9165eb749 Merge branch 'main' into requirements-custom-blocks 2025-11-03 12:12:08 +05:30
Sayak Paul
eeb3445444 Merge branch 'main' into requirements-custom-blocks 2025-11-01 08:36:16 +05:30
Sayak Paul
5b7d0dfab6 Merge branch 'main' into requirements-custom-blocks 2025-10-29 16:30:46 +05:30
sayakpaul
1de4402c26 up 2025-10-27 13:55:17 +05:30
sayakpaul
024c2b9839 Merge branch 'main' into requirements-custom-blocks 2025-10-27 11:56:00 +05:30
Sayak Paul
35d8d97c02 Merge branch 'main' into requirements-custom-blocks 2025-10-22 21:57:45 +05:30
Sayak Paul
e52cabeff2 Merge branch 'main' into requirements-custom-blocks 2025-10-22 06:23:40 +05:30
Sayak Paul
2c4d73d72d Merge branch 'main' into requirements-custom-blocks 2025-10-21 01:54:38 +05:30
sayakpaul
046be83946 up 2025-10-02 15:43:44 +05:30
Sayak Paul
b7fba892f5 Merge branch 'main' into requirements-custom-blocks 2025-09-23 13:35:49 +05:30
Sayak Paul
ecbd907e76 Merge branch 'main' into requirements-custom-blocks 2025-09-12 15:47:22 +05:30
Sayak Paul
d159ae025d Merge branch 'main' into requirements-custom-blocks 2025-09-02 10:04:22 +05:30
Sayak Paul
756a1567f5 Merge branch 'main' into requirements-custom-blocks 2025-08-29 08:03:00 +02:00
Sayak Paul
d2731ababa Merge branch 'main' into requirements-custom-blocks 2025-08-21 07:59:54 +05:30
sayakpaul
37d3887194 unify. 2025-08-20 12:09:33 +05:30
sayakpaul
127e9a39d8 up 2025-08-20 11:51:15 +05:30
sayakpaul
12ceecf077 feat: implement requirements validation for custom blocks. 2025-08-20 11:04:28 +05:30
3 changed files with 114 additions and 13 deletions

View File

@@ -89,8 +89,6 @@ class CustomBlocksCommand(BaseDiffusersCLICommand):
# automap = self._create_automap(parent_class=parent_class, child_class=child_class) # automap = self._create_automap(parent_class=parent_class, child_class=child_class)
# with open(CONFIG, "w") as f: # with open(CONFIG, "w") as f:
# json.dump(automap, f) # json.dump(automap, f)
with open("requirements.txt", "w") as f:
f.write("")
def _choose_block(self, candidates, chosen=None): def _choose_block(self, candidates, chosen=None):
for cls, base in candidates: for cls, base in candidates:

View File

@@ -39,6 +39,7 @@ from .modular_pipeline_utils import (
InputParam, InputParam,
InsertableDict, InsertableDict,
OutputParam, OutputParam,
_validate_requirements,
format_components, format_components,
format_configs, format_configs,
make_doc_string, make_doc_string,
@@ -239,6 +240,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
config_name = "modular_config.json" config_name = "modular_config.json"
model_name = None model_name = None
_requirements: Optional[Dict[str, str]] = None
@classmethod @classmethod
def _get_signature_keys(cls, obj): def _get_signature_keys(cls, obj):
@@ -301,6 +303,19 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
trust_remote_code: bool = False, trust_remote_code: bool = False,
**kwargs, **kwargs,
): ):
config = cls.load_config(pretrained_model_name_or_path)
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_remote_code
)
if not (has_remote_code and trust_remote_code):
raise ValueError(
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
)
if "requirements" in config and config["requirements"] is not None:
_ = _validate_requirements(config["requirements"])
hub_kwargs_names = [ hub_kwargs_names = [
"cache_dir", "cache_dir",
"force_download", "force_download",
@@ -313,16 +328,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
] ]
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs} hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
config = cls.load_config(pretrained_model_name_or_path, **hub_kwargs)
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_remote_code
)
if not has_remote_code and trust_remote_code:
raise ValueError(
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
)
class_ref = config["auto_map"][cls.__name__] class_ref = config["auto_map"][cls.__name__]
module_file, class_name = class_ref.split(".") module_file, class_name = class_ref.split(".")
module_file = module_file + ".py" module_file = module_file + ".py"
@@ -347,8 +352,13 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "") module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "")
parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0] parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0]
auto_map = {f"{parent_module}": f"{module}.{cls_name}"} auto_map = {f"{parent_module}": f"{module}.{cls_name}"}
self.register_to_config(auto_map=auto_map) self.register_to_config(auto_map=auto_map)
# resolve requirements
requirements = _validate_requirements(getattr(self, "_requirements", None))
if requirements:
self.register_to_config(requirements=requirements)
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
config = dict(self.config) config = dict(self.config)
self._internal_dict = FrozenDict(config) self._internal_dict = FrozenDict(config)
@@ -1132,6 +1142,14 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
expected_configs=self.expected_configs, expected_configs=self.expected_configs,
) )
@property
def _requirements(self) -> Dict[str, str]:
requirements = {}
for block_name, block in self.sub_blocks.items():
if getattr(block, "_requirements", None):
requirements[block_name] = block._requirements
return requirements
class LoopSequentialPipelineBlocks(ModularPipelineBlocks): class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
""" """

View File

@@ -19,9 +19,11 @@ from dataclasses import dataclass, field, fields
from typing import Any, Dict, List, Literal, Optional, Type, Union from typing import Any, Dict, List, Literal, Optional, Type, Union
import torch import torch
from packaging.specifiers import InvalidSpecifier, SpecifierSet
from ..configuration_utils import ConfigMixin, FrozenDict from ..configuration_utils import ConfigMixin, FrozenDict
from ..utils import is_torch_available, logging from ..utils import is_torch_available, logging
from ..utils.import_utils import _is_package_available
if is_torch_available(): if is_torch_available():
@@ -670,3 +672,86 @@ def make_doc_string(
output += format_output_params(outputs, indent_level=2) output += format_output_params(outputs, indent_level=2)
return output return output
def _validate_requirements(reqs):
if reqs is None:
normalized_reqs = {}
else:
if not isinstance(reqs, dict):
raise ValueError(
"Requirements must be provided as a dictionary mapping package names to version specifiers."
)
normalized_reqs = _normalize_requirements(reqs)
if not normalized_reqs:
return {}
final: Dict[str, str] = {}
for req, specified_ver in normalized_reqs.items():
req_available, req_actual_ver = _is_package_available(req)
if not req_available:
logger.warning(f"{req} was specified in the requirements but wasn't found in the current environment.")
if specified_ver:
try:
specifier = SpecifierSet(specified_ver)
except InvalidSpecifier as err:
raise ValueError(f"Requirement specifier '{specified_ver}' for {req} is invalid.") from err
if req_actual_ver == "N/A":
logger.warning(
f"Version of {req} could not be determined to validate requirement '{specified_ver}'. Things might work unexpected."
)
elif not specifier.contains(req_actual_ver, prereleases=True):
logger.warning(
f"{req} requirement '{specified_ver}' is not satisfied by the installed version {req_actual_ver}. Things might work unexpected."
)
final[req] = specified_ver
return final
def _normalize_requirements(reqs):
if not reqs:
return {}
normalized: "OrderedDict[str, str]" = OrderedDict()
def _accumulate(mapping: Dict[str, Any]):
for pkg, spec in mapping.items():
if isinstance(spec, dict):
# This is recursive because blocks are composable. This way, we can merge requirements
# from multiple blocks.
_accumulate(spec)
continue
pkg_name = str(pkg).strip()
if not pkg_name:
raise ValueError("Requirement package name cannot be empty.")
spec_str = "" if spec is None else str(spec).strip()
if spec_str and not spec_str.startswith(("<", ">", "=", "!", "~")):
spec_str = f"=={spec_str}"
existing_spec = normalized.get(pkg_name)
if existing_spec is not None:
if not existing_spec and spec_str:
normalized[pkg_name] = spec_str
elif existing_spec and spec_str and existing_spec != spec_str:
try:
combined_spec = SpecifierSet(",".join(filter(None, [existing_spec, spec_str])))
except InvalidSpecifier:
logger.warning(
f"Conflicting requirements for '{pkg_name}' detected: '{existing_spec}' vs '{spec_str}'. Keeping '{existing_spec}'."
)
else:
normalized[pkg_name] = str(combined_spec)
continue
normalized[pkg_name] = spec_str
_accumulate(reqs)
return normalized