Compare commits

..

22 Commits

Author SHA1 Message Date
sayakpaul
79fa0e2bd5 resolve merge conflicts. 2026-02-16 11:06:09 +05:30
Sayak Paul
60e3284003 Merge branch 'main' into requirements-custom-blocks 2026-01-20 19:10:24 +05:30
sayakpaul
7b43d0e409 add tests 2026-01-20 09:29:32 +05:30
Sayak Paul
3879e32254 Merge branch 'main' into requirements-custom-blocks 2026-01-20 08:20:38 +05:30
sayakpaul
a88d11bc90 resolve conflicts. 2025-11-06 10:29:24 +05:30
Sayak Paul
a9165eb749 Merge branch 'main' into requirements-custom-blocks 2025-11-03 12:12:08 +05:30
Sayak Paul
eeb3445444 Merge branch 'main' into requirements-custom-blocks 2025-11-01 08:36:16 +05:30
Sayak Paul
5b7d0dfab6 Merge branch 'main' into requirements-custom-blocks 2025-10-29 16:30:46 +05:30
sayakpaul
1de4402c26 up 2025-10-27 13:55:17 +05:30
sayakpaul
024c2b9839 Merge branch 'main' into requirements-custom-blocks 2025-10-27 11:56:00 +05:30
Sayak Paul
35d8d97c02 Merge branch 'main' into requirements-custom-blocks 2025-10-22 21:57:45 +05:30
Sayak Paul
e52cabeff2 Merge branch 'main' into requirements-custom-blocks 2025-10-22 06:23:40 +05:30
Sayak Paul
2c4d73d72d Merge branch 'main' into requirements-custom-blocks 2025-10-21 01:54:38 +05:30
sayakpaul
046be83946 up 2025-10-02 15:43:44 +05:30
Sayak Paul
b7fba892f5 Merge branch 'main' into requirements-custom-blocks 2025-09-23 13:35:49 +05:30
Sayak Paul
ecbd907e76 Merge branch 'main' into requirements-custom-blocks 2025-09-12 15:47:22 +05:30
Sayak Paul
d159ae025d Merge branch 'main' into requirements-custom-blocks 2025-09-02 10:04:22 +05:30
Sayak Paul
756a1567f5 Merge branch 'main' into requirements-custom-blocks 2025-08-29 08:03:00 +02:00
Sayak Paul
d2731ababa Merge branch 'main' into requirements-custom-blocks 2025-08-21 07:59:54 +05:30
sayakpaul
37d3887194 unify. 2025-08-20 12:09:33 +05:30
sayakpaul
127e9a39d8 up 2025-08-20 11:51:15 +05:30
sayakpaul
12ceecf077 feat: implement requirements validation for custom blocks. 2025-08-20 11:04:28 +05:30
6 changed files with 192 additions and 16 deletions

View File

@@ -89,8 +89,6 @@ class CustomBlocksCommand(BaseDiffusersCLICommand):
# automap = self._create_automap(parent_class=parent_class, child_class=child_class)
# with open(CONFIG, "w") as f:
# json.dump(automap, f)
with open("requirements.txt", "w") as f:
f.write("")
def _choose_block(self, candidates, chosen=None):
for cls, base in candidates:

View File

@@ -40,6 +40,7 @@ from .modular_pipeline_utils import (
InputParam,
InsertableDict,
OutputParam,
_validate_requirements,
combine_inputs,
combine_outputs,
format_components,
@@ -290,6 +291,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
config_name = "modular_config.json"
model_name = None
_requirements: dict[str, str] | None = None
_workflow_map = None
@classmethod
@@ -382,6 +384,19 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
trust_remote_code: bool = False,
**kwargs,
):
config = cls.load_config(pretrained_model_name_or_path)
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_remote_code
)
if not (has_remote_code and trust_remote_code):
raise ValueError(
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
)
if "requirements" in config and config["requirements"] is not None:
_ = _validate_requirements(config["requirements"])
hub_kwargs_names = [
"cache_dir",
"force_download",
@@ -394,16 +409,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
]
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
config = cls.load_config(pretrained_model_name_or_path, **hub_kwargs)
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_remote_code
)
if not has_remote_code and trust_remote_code:
raise ValueError(
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
)
class_ref = config["auto_map"][cls.__name__]
module_file, class_name = class_ref.split(".")
module_file = module_file + ".py"
@@ -428,8 +433,13 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "")
parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0]
auto_map = {f"{parent_module}": f"{module}.{cls_name}"}
self.register_to_config(auto_map=auto_map)
# resolve requirements
requirements = _validate_requirements(getattr(self, "_requirements", None))
if requirements:
self.register_to_config(requirements=requirements)
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
config = dict(self.config)
self._internal_dict = FrozenDict(config)
@@ -1240,6 +1250,14 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
expected_configs=self.expected_configs,
)
@property
def _requirements(self) -> dict[str, str]:
requirements = {}
for block_name, block in self.sub_blocks.items():
if getattr(block, "_requirements", None):
requirements[block_name] = block._requirements
return requirements
class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
"""

View File

@@ -22,10 +22,12 @@ from typing import Any, Literal, Type, Union, get_args, get_origin
import PIL.Image
import torch
from packaging.specifiers import InvalidSpecifier, SpecifierSet
from ..configuration_utils import ConfigMixin, FrozenDict
from ..loaders.single_file_utils import _is_single_file_path_or_url
from ..utils import DIFFUSERS_LOAD_ID_FIELDS, is_torch_available, logging
from ..utils.import_utils import _is_package_available
if is_torch_available():
@@ -972,6 +974,89 @@ def make_doc_string(
return output
def _validate_requirements(reqs):
if reqs is None:
normalized_reqs = {}
else:
if not isinstance(reqs, dict):
raise ValueError(
"Requirements must be provided as a dictionary mapping package names to version specifiers."
)
normalized_reqs = _normalize_requirements(reqs)
if not normalized_reqs:
return {}
final: dict[str, str] = {}
for req, specified_ver in normalized_reqs.items():
req_available, req_actual_ver = _is_package_available(req)
if not req_available:
logger.warning(f"{req} was specified in the requirements but wasn't found in the current environment.")
if specified_ver:
try:
specifier = SpecifierSet(specified_ver)
except InvalidSpecifier as err:
raise ValueError(f"Requirement specifier '{specified_ver}' for {req} is invalid.") from err
if req_actual_ver == "N/A":
logger.warning(
f"Version of {req} could not be determined to validate requirement '{specified_ver}'. Things might work unexpected."
)
elif not specifier.contains(req_actual_ver, prereleases=True):
logger.warning(
f"{req} requirement '{specified_ver}' is not satisfied by the installed version {req_actual_ver}. Things might work unexpected."
)
final[req] = specified_ver
return final
def _normalize_requirements(reqs):
if not reqs:
return {}
normalized: "OrderedDict[str, str]" = OrderedDict()
def _accumulate(mapping: dict[str, Any]):
for pkg, spec in mapping.items():
if isinstance(spec, dict):
# This is recursive because blocks are composable. This way, we can merge requirements
# from multiple blocks.
_accumulate(spec)
continue
pkg_name = str(pkg).strip()
if not pkg_name:
raise ValueError("Requirement package name cannot be empty.")
spec_str = "" if spec is None else str(spec).strip()
if spec_str and not spec_str.startswith(("<", ">", "=", "!", "~")):
spec_str = f"=={spec_str}"
existing_spec = normalized.get(pkg_name)
if existing_spec is not None:
if not existing_spec and spec_str:
normalized[pkg_name] = spec_str
elif existing_spec and spec_str and existing_spec != spec_str:
try:
combined_spec = SpecifierSet(",".join(filter(None, [existing_spec, spec_str])))
except InvalidSpecifier:
logger.warning(
f"Conflicting requirements for '{pkg_name}' detected: '{existing_spec}' vs '{spec_str}'. Keeping '{existing_spec}'."
)
else:
normalized[pkg_name] = str(combined_spec)
continue
normalized[pkg_name] = spec_str
_accumulate(reqs)
return normalized
def combine_inputs(*named_input_lists: list[tuple[str, list[InputParam]]]) -> list[InputParam]:
"""
Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if current

View File

@@ -81,7 +81,7 @@ class TorchCompileTesterMixin:
_ = model(**inputs_dict)
@torch.no_grad()
def test_torch_compile_repeated_blocks(self, recompile_limit=1):
def test_torch_compile_repeated_blocks(self):
if self.model_class._repeated_blocks is None:
pytest.skip("Skipping test as the model class doesn't have `_repeated_blocks` set.")
@@ -92,6 +92,7 @@ class TorchCompileTesterMixin:
model.eval()
model.compile_repeated_blocks(fullgraph=True)
recompile_limit = 1
if self.model_class.__name__ == "UNet2DConditionModel":
recompile_limit = 2

View File

@@ -147,7 +147,22 @@ class TestWanVACETransformer3DCompile(WanVACETransformer3DTesterConfig, TorchCom
def test_torch_compile_repeated_blocks(self):
# WanVACE has two block types (WanTransformerBlock and WanVACETransformerBlock),
# so we need recompile_limit=2 instead of the default 1.
super().test_torch_compile_repeated_blocks(recompile_limit=2)
import torch._dynamo
import torch._inductor.utils
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model.compile_repeated_blocks(fullgraph=True)
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(recompile_limit=2),
):
_ = model(**inputs_dict)
_ = model(**inputs_dict)
class TestWanVACETransformer3DBitsAndBytes(WanVACETransformer3DTesterConfig, BitsAndBytesTesterMixin):

View File

@@ -1,4 +1,6 @@
import gc
import json
import os
import tempfile
from typing import Callable
@@ -8,6 +10,7 @@ import torch
import diffusers
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.modular_pipeline_utils import (
ComponentSpec,
ConfigSpec,
@@ -17,7 +20,13 @@ from diffusers.modular_pipelines.modular_pipeline_utils import (
)
from diffusers.utils import logging
from ..testing_utils import backend_empty_cache, numpy_cosine_similarity_distance, require_accelerator, torch_device
from ..testing_utils import (
CaptureLogger,
backend_empty_cache,
numpy_cosine_similarity_distance,
require_accelerator,
torch_device,
)
class ModularPipelineTesterMixin:
@@ -400,6 +409,56 @@ class ModularGuiderTesterMixin:
assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"
class TestCustomBlockRequirements:
def get_dummy_block_pipe(self):
class DummyBlockOne:
# keep two arbitrary deps so that we can test warnings.
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
# keep two dependencies that will be available during testing.
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
pipe = SequentialPipelineBlocks.from_blocks_dict(
{"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
)
return pipe
def test_custom_requirements_save_load(self):
pipe = self.get_dummy_block_pipe()
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir)
config_path = os.path.join(tmpdir, "modular_config.json")
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
requirements = config["requirements"]
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == requirements
def test_warnings(self):
pipe = self.get_dummy_block_pipe()
with tempfile.TemporaryDirectory() as tmpdir:
logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
pipe.save_pretrained(tmpdir)
template = "{req} was specified in the requirements but wasn't found in the current environment"
msg_xyz = template.format(req="xyz")
msg_abc = template.format(req="abc")
assert msg_xyz in str(cap_logger.out)
assert msg_abc in str(cap_logger.out)
class TestModularModelCardContent:
def create_mock_block(self, name="TestBlock", description="Test block description"):
class MockBlock: