Compare commits

...

18 Commits

Author SHA1 Message Date
sayakpaul
a88d11bc90 resolve conflicts. 2025-11-06 10:29:24 +05:30
Sayak Paul
a9165eb749 Merge branch 'main' into requirements-custom-blocks 2025-11-03 12:12:08 +05:30
Sayak Paul
eeb3445444 Merge branch 'main' into requirements-custom-blocks 2025-11-01 08:36:16 +05:30
Sayak Paul
5b7d0dfab6 Merge branch 'main' into requirements-custom-blocks 2025-10-29 16:30:46 +05:30
sayakpaul
1de4402c26 up 2025-10-27 13:55:17 +05:30
sayakpaul
024c2b9839 Merge branch 'main' into requirements-custom-blocks 2025-10-27 11:56:00 +05:30
Sayak Paul
35d8d97c02 Merge branch 'main' into requirements-custom-blocks 2025-10-22 21:57:45 +05:30
Sayak Paul
e52cabeff2 Merge branch 'main' into requirements-custom-blocks 2025-10-22 06:23:40 +05:30
Sayak Paul
2c4d73d72d Merge branch 'main' into requirements-custom-blocks 2025-10-21 01:54:38 +05:30
sayakpaul
046be83946 up 2025-10-02 15:43:44 +05:30
Sayak Paul
b7fba892f5 Merge branch 'main' into requirements-custom-blocks 2025-09-23 13:35:49 +05:30
Sayak Paul
ecbd907e76 Merge branch 'main' into requirements-custom-blocks 2025-09-12 15:47:22 +05:30
Sayak Paul
d159ae025d Merge branch 'main' into requirements-custom-blocks 2025-09-02 10:04:22 +05:30
Sayak Paul
756a1567f5 Merge branch 'main' into requirements-custom-blocks 2025-08-29 08:03:00 +02:00
Sayak Paul
d2731ababa Merge branch 'main' into requirements-custom-blocks 2025-08-21 07:59:54 +05:30
sayakpaul
37d3887194 unify. 2025-08-20 12:09:33 +05:30
sayakpaul
127e9a39d8 up 2025-08-20 11:51:15 +05:30
sayakpaul
12ceecf077 feat: implement requirements validation for custom blocks. 2025-08-20 11:04:28 +05:30
3 changed files with 114 additions and 13 deletions

View File

@@ -89,8 +89,6 @@ class CustomBlocksCommand(BaseDiffusersCLICommand):
# automap = self._create_automap(parent_class=parent_class, child_class=child_class)
# with open(CONFIG, "w") as f:
# json.dump(automap, f)
with open("requirements.txt", "w") as f:
f.write("")
def _choose_block(self, candidates, chosen=None):
for cls, base in candidates:

View File

@@ -39,6 +39,7 @@ from .modular_pipeline_utils import (
InputParam,
InsertableDict,
OutputParam,
_validate_requirements,
format_components,
format_configs,
make_doc_string,
@@ -239,6 +240,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
config_name = "modular_config.json"
model_name = None
_requirements: Optional[Dict[str, str]] = None
@classmethod
def _get_signature_keys(cls, obj):
@@ -301,6 +303,19 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
trust_remote_code: bool = False,
**kwargs,
):
config = cls.load_config(pretrained_model_name_or_path)
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_remote_code
)
if not (has_remote_code and trust_remote_code):
raise ValueError(
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
)
if "requirements" in config and config["requirements"] is not None:
_ = _validate_requirements(config["requirements"])
hub_kwargs_names = [
"cache_dir",
"force_download",
@@ -313,16 +328,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
]
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
config = cls.load_config(pretrained_model_name_or_path, **hub_kwargs)
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_remote_code
)
if not has_remote_code and trust_remote_code:
raise ValueError(
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
)
class_ref = config["auto_map"][cls.__name__]
module_file, class_name = class_ref.split(".")
module_file = module_file + ".py"
@@ -347,8 +352,13 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "")
parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0]
auto_map = {f"{parent_module}": f"{module}.{cls_name}"}
self.register_to_config(auto_map=auto_map)
# resolve requirements
requirements = _validate_requirements(getattr(self, "_requirements", None))
if requirements:
self.register_to_config(requirements=requirements)
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
config = dict(self.config)
self._internal_dict = FrozenDict(config)
@@ -1132,6 +1142,14 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
expected_configs=self.expected_configs,
)
@property
def _requirements(self) -> Dict[str, str]:
requirements = {}
for block_name, block in self.sub_blocks.items():
if getattr(block, "_requirements", None):
requirements[block_name] = block._requirements
return requirements
class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
"""

View File

@@ -19,9 +19,11 @@ from dataclasses import dataclass, field, fields
from typing import Any, Dict, List, Literal, Optional, Type, Union
import torch
from packaging.specifiers import InvalidSpecifier, SpecifierSet
from ..configuration_utils import ConfigMixin, FrozenDict
from ..utils import is_torch_available, logging
from ..utils.import_utils import _is_package_available
if is_torch_available():
@@ -670,3 +672,86 @@ def make_doc_string(
output += format_output_params(outputs, indent_level=2)
return output
def _validate_requirements(reqs):
if reqs is None:
normalized_reqs = {}
else:
if not isinstance(reqs, dict):
raise ValueError(
"Requirements must be provided as a dictionary mapping package names to version specifiers."
)
normalized_reqs = _normalize_requirements(reqs)
if not normalized_reqs:
return {}
final: Dict[str, str] = {}
for req, specified_ver in normalized_reqs.items():
req_available, req_actual_ver = _is_package_available(req)
if not req_available:
logger.warning(f"{req} was specified in the requirements but wasn't found in the current environment.")
if specified_ver:
try:
specifier = SpecifierSet(specified_ver)
except InvalidSpecifier as err:
raise ValueError(f"Requirement specifier '{specified_ver}' for {req} is invalid.") from err
if req_actual_ver == "N/A":
logger.warning(
f"Version of {req} could not be determined to validate requirement '{specified_ver}'. Things might work unexpected."
)
elif not specifier.contains(req_actual_ver, prereleases=True):
logger.warning(
f"{req} requirement '{specified_ver}' is not satisfied by the installed version {req_actual_ver}. Things might work unexpected."
)
final[req] = specified_ver
return final
def _normalize_requirements(reqs):
if not reqs:
return {}
normalized: "OrderedDict[str, str]" = OrderedDict()
def _accumulate(mapping: Dict[str, Any]):
for pkg, spec in mapping.items():
if isinstance(spec, dict):
# This is recursive because blocks are composable. This way, we can merge requirements
# from multiple blocks.
_accumulate(spec)
continue
pkg_name = str(pkg).strip()
if not pkg_name:
raise ValueError("Requirement package name cannot be empty.")
spec_str = "" if spec is None else str(spec).strip()
if spec_str and not spec_str.startswith(("<", ">", "=", "!", "~")):
spec_str = f"=={spec_str}"
existing_spec = normalized.get(pkg_name)
if existing_spec is not None:
if not existing_spec and spec_str:
normalized[pkg_name] = spec_str
elif existing_spec and spec_str and existing_spec != spec_str:
try:
combined_spec = SpecifierSet(",".join(filter(None, [existing_spec, spec_str])))
except InvalidSpecifier:
logger.warning(
f"Conflicting requirements for '{pkg_name}' detected: '{existing_spec}' vs '{spec_str}'. Keeping '{existing_spec}'."
)
else:
normalized[pkg_name] = str(combined_spec)
continue
normalized[pkg_name] = spec_str
_accumulate(reqs)
return normalized