mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-15 17:04:52 +08:00
Compare commits
18 Commits
pr-tests-f
...
requiremen
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a88d11bc90 | ||
|
|
a9165eb749 | ||
|
|
eeb3445444 | ||
|
|
5b7d0dfab6 | ||
|
|
1de4402c26 | ||
|
|
024c2b9839 | ||
|
|
35d8d97c02 | ||
|
|
e52cabeff2 | ||
|
|
2c4d73d72d | ||
|
|
046be83946 | ||
|
|
b7fba892f5 | ||
|
|
ecbd907e76 | ||
|
|
d159ae025d | ||
|
|
756a1567f5 | ||
|
|
d2731ababa | ||
|
|
37d3887194 | ||
|
|
127e9a39d8 | ||
|
|
12ceecf077 |
@@ -89,8 +89,6 @@ class CustomBlocksCommand(BaseDiffusersCLICommand):
|
|||||||
# automap = self._create_automap(parent_class=parent_class, child_class=child_class)
|
# automap = self._create_automap(parent_class=parent_class, child_class=child_class)
|
||||||
# with open(CONFIG, "w") as f:
|
# with open(CONFIG, "w") as f:
|
||||||
# json.dump(automap, f)
|
# json.dump(automap, f)
|
||||||
with open("requirements.txt", "w") as f:
|
|
||||||
f.write("")
|
|
||||||
|
|
||||||
def _choose_block(self, candidates, chosen=None):
|
def _choose_block(self, candidates, chosen=None):
|
||||||
for cls, base in candidates:
|
for cls, base in candidates:
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ from .modular_pipeline_utils import (
|
|||||||
InputParam,
|
InputParam,
|
||||||
InsertableDict,
|
InsertableDict,
|
||||||
OutputParam,
|
OutputParam,
|
||||||
|
_validate_requirements,
|
||||||
format_components,
|
format_components,
|
||||||
format_configs,
|
format_configs,
|
||||||
make_doc_string,
|
make_doc_string,
|
||||||
@@ -239,6 +240,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
|||||||
|
|
||||||
config_name = "modular_config.json"
|
config_name = "modular_config.json"
|
||||||
model_name = None
|
model_name = None
|
||||||
|
_requirements: Optional[Dict[str, str]] = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_signature_keys(cls, obj):
|
def _get_signature_keys(cls, obj):
|
||||||
@@ -301,6 +303,19 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
|||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
config = cls.load_config(pretrained_model_name_or_path)
|
||||||
|
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
|
||||||
|
trust_remote_code = resolve_trust_remote_code(
|
||||||
|
trust_remote_code, pretrained_model_name_or_path, has_remote_code
|
||||||
|
)
|
||||||
|
if not (has_remote_code and trust_remote_code):
|
||||||
|
raise ValueError(
|
||||||
|
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
|
||||||
|
)
|
||||||
|
|
||||||
|
if "requirements" in config and config["requirements"] is not None:
|
||||||
|
_ = _validate_requirements(config["requirements"])
|
||||||
|
|
||||||
hub_kwargs_names = [
|
hub_kwargs_names = [
|
||||||
"cache_dir",
|
"cache_dir",
|
||||||
"force_download",
|
"force_download",
|
||||||
@@ -313,16 +328,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
|||||||
]
|
]
|
||||||
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
|
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
|
||||||
|
|
||||||
config = cls.load_config(pretrained_model_name_or_path, **hub_kwargs)
|
|
||||||
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
|
|
||||||
trust_remote_code = resolve_trust_remote_code(
|
|
||||||
trust_remote_code, pretrained_model_name_or_path, has_remote_code
|
|
||||||
)
|
|
||||||
if not has_remote_code and trust_remote_code:
|
|
||||||
raise ValueError(
|
|
||||||
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
|
|
||||||
)
|
|
||||||
|
|
||||||
class_ref = config["auto_map"][cls.__name__]
|
class_ref = config["auto_map"][cls.__name__]
|
||||||
module_file, class_name = class_ref.split(".")
|
module_file, class_name = class_ref.split(".")
|
||||||
module_file = module_file + ".py"
|
module_file = module_file + ".py"
|
||||||
@@ -347,8 +352,13 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
|||||||
module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "")
|
module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "")
|
||||||
parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0]
|
parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0]
|
||||||
auto_map = {f"{parent_module}": f"{module}.{cls_name}"}
|
auto_map = {f"{parent_module}": f"{module}.{cls_name}"}
|
||||||
|
|
||||||
self.register_to_config(auto_map=auto_map)
|
self.register_to_config(auto_map=auto_map)
|
||||||
|
|
||||||
|
# resolve requirements
|
||||||
|
requirements = _validate_requirements(getattr(self, "_requirements", None))
|
||||||
|
if requirements:
|
||||||
|
self.register_to_config(requirements=requirements)
|
||||||
|
|
||||||
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
|
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
|
||||||
config = dict(self.config)
|
config = dict(self.config)
|
||||||
self._internal_dict = FrozenDict(config)
|
self._internal_dict = FrozenDict(config)
|
||||||
@@ -1132,6 +1142,14 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
|||||||
expected_configs=self.expected_configs,
|
expected_configs=self.expected_configs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _requirements(self) -> Dict[str, str]:
|
||||||
|
requirements = {}
|
||||||
|
for block_name, block in self.sub_blocks.items():
|
||||||
|
if getattr(block, "_requirements", None):
|
||||||
|
requirements[block_name] = block._requirements
|
||||||
|
return requirements
|
||||||
|
|
||||||
|
|
||||||
class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -19,9 +19,11 @@ from dataclasses import dataclass, field, fields
|
|||||||
from typing import Any, Dict, List, Literal, Optional, Type, Union
|
from typing import Any, Dict, List, Literal, Optional, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from packaging.specifiers import InvalidSpecifier, SpecifierSet
|
||||||
|
|
||||||
from ..configuration_utils import ConfigMixin, FrozenDict
|
from ..configuration_utils import ConfigMixin, FrozenDict
|
||||||
from ..utils import is_torch_available, logging
|
from ..utils import is_torch_available, logging
|
||||||
|
from ..utils.import_utils import _is_package_available
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -670,3 +672,86 @@ def make_doc_string(
|
|||||||
output += format_output_params(outputs, indent_level=2)
|
output += format_output_params(outputs, indent_level=2)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_requirements(reqs):
|
||||||
|
if reqs is None:
|
||||||
|
normalized_reqs = {}
|
||||||
|
else:
|
||||||
|
if not isinstance(reqs, dict):
|
||||||
|
raise ValueError(
|
||||||
|
"Requirements must be provided as a dictionary mapping package names to version specifiers."
|
||||||
|
)
|
||||||
|
normalized_reqs = _normalize_requirements(reqs)
|
||||||
|
|
||||||
|
if not normalized_reqs:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
final: Dict[str, str] = {}
|
||||||
|
for req, specified_ver in normalized_reqs.items():
|
||||||
|
req_available, req_actual_ver = _is_package_available(req)
|
||||||
|
if not req_available:
|
||||||
|
logger.warning(f"{req} was specified in the requirements but wasn't found in the current environment.")
|
||||||
|
|
||||||
|
if specified_ver:
|
||||||
|
try:
|
||||||
|
specifier = SpecifierSet(specified_ver)
|
||||||
|
except InvalidSpecifier as err:
|
||||||
|
raise ValueError(f"Requirement specifier '{specified_ver}' for {req} is invalid.") from err
|
||||||
|
|
||||||
|
if req_actual_ver == "N/A":
|
||||||
|
logger.warning(
|
||||||
|
f"Version of {req} could not be determined to validate requirement '{specified_ver}'. Things might work unexpected."
|
||||||
|
)
|
||||||
|
elif not specifier.contains(req_actual_ver, prereleases=True):
|
||||||
|
logger.warning(
|
||||||
|
f"{req} requirement '{specified_ver}' is not satisfied by the installed version {req_actual_ver}. Things might work unexpected."
|
||||||
|
)
|
||||||
|
|
||||||
|
final[req] = specified_ver
|
||||||
|
|
||||||
|
return final
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_requirements(reqs):
|
||||||
|
if not reqs:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
normalized: "OrderedDict[str, str]" = OrderedDict()
|
||||||
|
|
||||||
|
def _accumulate(mapping: Dict[str, Any]):
|
||||||
|
for pkg, spec in mapping.items():
|
||||||
|
if isinstance(spec, dict):
|
||||||
|
# This is recursive because blocks are composable. This way, we can merge requirements
|
||||||
|
# from multiple blocks.
|
||||||
|
_accumulate(spec)
|
||||||
|
continue
|
||||||
|
|
||||||
|
pkg_name = str(pkg).strip()
|
||||||
|
if not pkg_name:
|
||||||
|
raise ValueError("Requirement package name cannot be empty.")
|
||||||
|
|
||||||
|
spec_str = "" if spec is None else str(spec).strip()
|
||||||
|
if spec_str and not spec_str.startswith(("<", ">", "=", "!", "~")):
|
||||||
|
spec_str = f"=={spec_str}"
|
||||||
|
|
||||||
|
existing_spec = normalized.get(pkg_name)
|
||||||
|
if existing_spec is not None:
|
||||||
|
if not existing_spec and spec_str:
|
||||||
|
normalized[pkg_name] = spec_str
|
||||||
|
elif existing_spec and spec_str and existing_spec != spec_str:
|
||||||
|
try:
|
||||||
|
combined_spec = SpecifierSet(",".join(filter(None, [existing_spec, spec_str])))
|
||||||
|
except InvalidSpecifier:
|
||||||
|
logger.warning(
|
||||||
|
f"Conflicting requirements for '{pkg_name}' detected: '{existing_spec}' vs '{spec_str}'. Keeping '{existing_spec}'."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
normalized[pkg_name] = str(combined_spec)
|
||||||
|
continue
|
||||||
|
|
||||||
|
normalized[pkg_name] = spec_str
|
||||||
|
|
||||||
|
_accumulate(reqs)
|
||||||
|
|
||||||
|
return normalized
|
||||||
|
|||||||
Reference in New Issue
Block a user