Compare commits

...

20 Commits

Author       SHA1        Message                                                      Date
sayakpaul    7b43d0e409  add tests                                                    2026-01-20 09:29:32 +05:30
Sayak Paul   3879e32254  Merge branch 'main' into requirements-custom-blocks          2026-01-20 08:20:38 +05:30
sayakpaul    a88d11bc90  resolve conflicts.                                           2025-11-06 10:29:24 +05:30
Sayak Paul   a9165eb749  Merge branch 'main' into requirements-custom-blocks          2025-11-03 12:12:08 +05:30
Sayak Paul   eeb3445444  Merge branch 'main' into requirements-custom-blocks          2025-11-01 08:36:16 +05:30
Sayak Paul   5b7d0dfab6  Merge branch 'main' into requirements-custom-blocks          2025-10-29 16:30:46 +05:30
sayakpaul    1de4402c26  up                                                           2025-10-27 13:55:17 +05:30
sayakpaul    024c2b9839  Merge branch 'main' into requirements-custom-blocks          2025-10-27 11:56:00 +05:30
Sayak Paul   35d8d97c02  Merge branch 'main' into requirements-custom-blocks          2025-10-22 21:57:45 +05:30
Sayak Paul   e52cabeff2  Merge branch 'main' into requirements-custom-blocks          2025-10-22 06:23:40 +05:30
Sayak Paul   2c4d73d72d  Merge branch 'main' into requirements-custom-blocks          2025-10-21 01:54:38 +05:30
sayakpaul    046be83946  up                                                           2025-10-02 15:43:44 +05:30
Sayak Paul   b7fba892f5  Merge branch 'main' into requirements-custom-blocks          2025-09-23 13:35:49 +05:30
Sayak Paul   ecbd907e76  Merge branch 'main' into requirements-custom-blocks          2025-09-12 15:47:22 +05:30
Sayak Paul   d159ae025d  Merge branch 'main' into requirements-custom-blocks          2025-09-02 10:04:22 +05:30
Sayak Paul   756a1567f5  Merge branch 'main' into requirements-custom-blocks          2025-08-29 08:03:00 +02:00
Sayak Paul   d2731ababa  Merge branch 'main' into requirements-custom-blocks          2025-08-21 07:59:54 +05:30
sayakpaul    37d3887194  unify.                                                       2025-08-20 12:09:33 +05:30
sayakpaul    127e9a39d8  up                                                           2025-08-20 11:51:15 +05:30
sayakpaul    12ceecf077  feat: implement requirements validation for custom blocks.   2025-08-20 11:04:28 +05:30
4 changed files with 174 additions and 14 deletions

View File

@@ -89,8 +89,6 @@ class CustomBlocksCommand(BaseDiffusersCLICommand):
         # automap = self._create_automap(parent_class=parent_class, child_class=child_class)
         # with open(CONFIG, "w") as f:
         #     json.dump(automap, f)
-        with open("requirements.txt", "w") as f:
-            f.write("")
 
     def _choose_block(self, candidates, chosen=None):
         for cls, base in candidates:

View File

@@ -39,6 +39,7 @@ from .modular_pipeline_utils import (
     InputParam,
     InsertableDict,
     OutputParam,
+    _validate_requirements,
     format_components,
     format_configs,
     make_doc_string,
@@ -242,6 +243,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
     config_name = "modular_config.json"
     model_name = None
 
+    _requirements: Optional[Dict[str, str]] = None
 
     @classmethod
     def _get_signature_keys(cls, obj):
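In practice a block author opts in by setting this new class attribute on their block. A minimal sketch (the block class here is hypothetical; the specifier mirrors the ones used in the tests below):

```python
from typing import Dict, Optional

from diffusers import ModularPipelineBlocks


class MyCustomBlock(ModularPipelineBlocks):
    # Hypothetical block: packages and version specifiers declared here are
    # validated and written into modular_config.json by save_pretrained().
    _requirements: Optional[Dict[str, str]] = {"transformers": ">=4.44.0"}
```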
@@ -304,6 +306,19 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
         trust_remote_code: bool = False,
         **kwargs,
     ):
+        config = cls.load_config(pretrained_model_name_or_path)
+        has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
+        trust_remote_code = resolve_trust_remote_code(
+            trust_remote_code, pretrained_model_name_or_path, has_remote_code
+        )
+        if not (has_remote_code and trust_remote_code):
+            raise ValueError(
+                "Selected model repository does not appear to have any custom code or does not have a valid `config.json` file."
+            )
+
+        if "requirements" in config and config["requirements"] is not None:
+            _ = _validate_requirements(config["requirements"])
+
         hub_kwargs_names = [
             "cache_dir",
             "force_download",
@@ -316,16 +331,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
         ]
         hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
 
-        config = cls.load_config(pretrained_model_name_or_path, **hub_kwargs)
-        has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
-        trust_remote_code = resolve_trust_remote_code(
-            trust_remote_code, pretrained_model_name_or_path, has_remote_code
-        )
-        if not has_remote_code and trust_remote_code:
-            raise ValueError(
-                "Selected model repository does not appear to have any custom code or does not have a valid `config.json` file."
-            )
-
         class_ref = config["auto_map"][cls.__name__]
         module_file, class_name = class_ref.split(".")
         module_file = module_file + ".py"
@@ -350,8 +355,13 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
         module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "")
         parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0]
         auto_map = {f"{parent_module}": f"{module}.{cls_name}"}
         self.register_to_config(auto_map=auto_map)
+
+        # resolve requirements
+        requirements = _validate_requirements(getattr(self, "_requirements", None))
+        if requirements:
+            self.register_to_config(requirements=requirements)
         self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
 
         config = dict(self.config)
         self._internal_dict = FrozenDict(config)
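On the save side, whatever `_validate_requirements` returns is registered into the config, so the saved `modular_config.json` gains a `requirements` key. Roughly (the values mirror the test expectations below):

```python
import json

# After pipe.save_pretrained(save_directory), the saved config carries the
# merged, normalized requirements of all participating blocks.
with open("modular_config.json") as f:
    config = json.load(f)
print(config["requirements"])
# e.g. {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
```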
@@ -1154,6 +1164,14 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
             expected_configs=self.expected_configs,
         )
 
+    @property
+    def _requirements(self) -> Dict[str, str]:
+        requirements = {}
+        for block_name, block in self.sub_blocks.items():
+            if getattr(block, "_requirements", None):
+                requirements[block_name] = block._requirements
+        return requirements
+
 
 class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
     """

View File

@@ -19,10 +19,12 @@ from dataclasses import dataclass, field, fields
 from typing import Any, Dict, List, Literal, Optional, Type, Union
 
 import torch
+from packaging.specifiers import InvalidSpecifier, SpecifierSet
 
 from ..configuration_utils import ConfigMixin, FrozenDict
 from ..loaders.single_file_utils import _is_single_file_path_or_url
 from ..utils import is_torch_available, logging
+from ..utils.import_utils import _is_package_available
 
 
 if is_torch_available():
@@ -690,3 +692,86 @@ def make_doc_string(
     output += format_output_params(outputs, indent_level=2)
 
     return output
+
+
+def _validate_requirements(reqs):
+    if reqs is None:
+        normalized_reqs = {}
+    else:
+        if not isinstance(reqs, dict):
+            raise ValueError(
+                "Requirements must be provided as a dictionary mapping package names to version specifiers."
+            )
+        normalized_reqs = _normalize_requirements(reqs)
+
+    if not normalized_reqs:
+        return {}
+
+    final: Dict[str, str] = {}
+    for req, specified_ver in normalized_reqs.items():
+        req_available, req_actual_ver = _is_package_available(req)
+        if not req_available:
+            logger.warning(f"{req} was specified in the requirements but wasn't found in the current environment.")
+        if specified_ver:
+            try:
+                specifier = SpecifierSet(specified_ver)
+            except InvalidSpecifier as err:
+                raise ValueError(f"Requirement specifier '{specified_ver}' for {req} is invalid.") from err
+            if req_actual_ver == "N/A":
+                logger.warning(
+                    f"Version of {req} could not be determined to validate requirement '{specified_ver}'. Things might not work as expected."
+                )
+            elif not specifier.contains(req_actual_ver, prereleases=True):
+                logger.warning(
+                    f"{req} requirement '{specified_ver}' is not satisfied by the installed version {req_actual_ver}. Things might not work as expected."
+                )
+        final[req] = specified_ver
+    return final
+
+
+def _normalize_requirements(reqs):
+    if not reqs:
+        return {}
+
+    normalized: "OrderedDict[str, str]" = OrderedDict()
+
+    def _accumulate(mapping: Dict[str, Any]):
+        for pkg, spec in mapping.items():
+            if isinstance(spec, dict):
+                # This is recursive because blocks are composable. This way, we can merge requirements
+                # from multiple blocks.
+                _accumulate(spec)
+                continue
+
+            pkg_name = str(pkg).strip()
+            if not pkg_name:
+                raise ValueError("Requirement package name cannot be empty.")
+
+            # Bare versions like "1.2.3" are treated as exact pins ("==1.2.3").
+            spec_str = "" if spec is None else str(spec).strip()
+            if spec_str and not spec_str.startswith(("<", ">", "=", "!", "~")):
+                spec_str = f"=={spec_str}"
+
+            existing_spec = normalized.get(pkg_name)
+            if existing_spec is not None:
+                if not existing_spec and spec_str:
+                    normalized[pkg_name] = spec_str
+                elif existing_spec and spec_str and existing_spec != spec_str:
+                    # Try to merge differing specifiers for the same package; keep
+                    # the first one if the combination does not parse.
+                    try:
+                        combined_spec = SpecifierSet(",".join(filter(None, [existing_spec, spec_str])))
+                    except InvalidSpecifier:
+                        logger.warning(
+                            f"Conflicting requirements for '{pkg_name}' detected: '{existing_spec}' vs '{spec_str}'. Keeping '{existing_spec}'."
+                        )
+                    else:
+                        normalized[pkg_name] = str(combined_spec)
+                continue
+
+            normalized[pkg_name] = spec_str
+
+    _accumulate(reqs)
+    return normalized
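To make the normalization rules concrete, here is the behavior implied by the code above (package names are arbitrary; the module path matches the logger name used in the tests below):

```python
from diffusers.modular_pipelines.modular_pipeline_utils import _normalize_requirements

# Bare versions are coerced to exact pins.
print(dict(_normalize_requirements({"numpy": "1.26.4"})))
# {'numpy': '==1.26.4'}

# Differing specifiers for the same package merge into one SpecifierSet when
# they parse together; otherwise the first one wins and a warning is logged.
print(dict(_normalize_requirements({"a": {"numpy": ">=1.24"}, "b": {"numpy": ">=1.26"}})))
# {'numpy': '>=1.24,>=1.26'}
```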

View File

@@ -1,4 +1,6 @@
 import gc
+import json
+import os
 import tempfile
 from typing import Callable, Union
 
@@ -8,9 +10,16 @@ import torch
 
 import diffusers
 from diffusers import ComponentsManager, ModularPipeline, ModularPipelineBlocks
 from diffusers.guiders import ClassifierFreeGuidance
+from diffusers.modular_pipelines import SequentialPipelineBlocks
+from diffusers.utils import logging
 
-from ..testing_utils import backend_empty_cache, numpy_cosine_similarity_distance, require_accelerator, torch_device
+from ..testing_utils import (
+    CaptureLogger,
+    backend_empty_cache,
+    numpy_cosine_similarity_distance,
+    require_accelerator,
+    torch_device,
+)
 
 
 class ModularPipelineTesterMixin:
@@ -335,3 +344,53 @@ class ModularGuiderTesterMixin:
         assert out_cfg.shape == out_no_cfg.shape
         max_diff = torch.abs(out_cfg - out_no_cfg).max()
         assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"
+
+
+class TestCustomBlockRequirements:
+    def get_dummy_block_pipe(self):
+        class DummyBlockOne:
+            # keep two arbitrary deps so that we can test warnings.
+            _requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
+
+        class DummyBlockTwo:
+            # keep two dependencies that will be available during testing.
+            _requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
+
+        pipe = SequentialPipelineBlocks.from_blocks_dict(
+            {"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
+        )
+        return pipe
+
+    def test_custom_requirements_save_load(self):
+        pipe = self.get_dummy_block_pipe()
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pipe.save_pretrained(tmpdir)
+            config_path = os.path.join(tmpdir, "modular_config.json")
+            with open(config_path, "r") as f:
+                config = json.load(f)
+
+        assert "requirements" in config
+        requirements = config["requirements"]
+        expected_requirements = {
+            "xyz": ">=0.8.0",
+            "abc": ">=10.0.0",
+            "transformers": ">=4.44.0",
+            "diffusers": ">=0.2.0",
+        }
+        assert expected_requirements == requirements
+
+    def test_warnings(self):
+        pipe = self.get_dummy_block_pipe()
+        with tempfile.TemporaryDirectory() as tmpdir:
+            logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
+            logger.setLevel(30)  # 30 == logging.WARNING
+            with CaptureLogger(logger) as cap_logger:
+                pipe.save_pretrained(tmpdir)
+
+        template = "{req} was specified in the requirements but wasn't found in the current environment"
+        msg_xyz = template.format(req="xyz")
+        msg_abc = template.format(req="abc")
+        assert msg_xyz in str(cap_logger.out)
+        assert msg_abc in str(cap_logger.out)
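Since the test module's path isn't visible in this compare view, selecting the new tests by class name is the safest bet; something like:

```python
# Hypothetical runner: -k selects the new test class wherever the module lives.
import pytest

raise SystemExit(pytest.main(["-q", "-k", "TestCustomBlockRequirements"]))
```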