@@ -14,6 +14,7 @@
 import importlib
+import inspect
 import os
 import sys
 import traceback
 import warnings
 from collections import OrderedDict
@@ -28,10 +29,16 @@ from tqdm.auto import tqdm
 from typing_extensions import Self
 
 from ..configuration_utils import ConfigMixin, FrozenDict
-from ..pipelines.pipeline_loading_utils import _fetch_class_library_tuple, simple_get_class_obj
+from ..pipelines.pipeline_loading_utils import (
+    LOADABLE_CLASSES,
+    _fetch_class_library_tuple,
+    _unwrap_model,
+    simple_get_class_obj,
+)
 from ..utils import PushToHubMixin, is_accelerate_available, logging
 from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
 from ..utils.hub_utils import load_or_create_model_card, populate_model_card
+from ..utils.torch_utils import is_compiled_module
 from .components_manager import ComponentsManager
 from .modular_pipeline_utils import (
     MODULAR_MODEL_CARD_TEMPLATE,
@@ -1819,29 +1826,111 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
         )
         return pipeline
 
-    def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs):
+    def save_pretrained(
+        self,
+        save_directory: str | os.PathLike,
+        safe_serialization: bool = True,
+        variant: str | None = None,
+        max_shard_size: int | str | None = None,
+        push_to_hub: bool = False,
+        **kwargs,
+    ):
         """
-        Save the pipeline to a directory. It does not save components, you need to save them separately.
+        Save the pipeline and all its components to a directory, so that it can be re-loaded using the
+        [`~ModularPipeline.from_pretrained`] class method.
 
         Args:
             save_directory (`str` or `os.PathLike`):
-                Path to the directory where the pipeline will be saved.
-            push_to_hub (`bool`, optional):
-                Whether to push the pipeline to the huggingface hub.
-            **kwargs: Additional arguments passed to `save_config()` method
+                Directory to save the pipeline to. Will be created if it doesn't exist.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+            variant (`str`, *optional*):
+                If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
+            max_shard_size (`int` or `str`, defaults to `None`):
+                The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be of a
+                size lower than this size. If expressed as a string, it needs to be digits followed by a unit
+                (like `"5GB"`). If expressed as an integer, the unit is bytes.
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether to push the pipeline to the Hugging Face model hub after saving it.
+            **kwargs: Additional keyword arguments passed along to the push to hub method.
         """
+        overwrite_modular_index = kwargs.pop("overwrite_modular_index", False)
+        repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+
+        for component_name, component_spec in self._component_specs.items():
+            sub_model = getattr(self, component_name, None)
+            if sub_model is None:
+                continue
+
+            model_cls = sub_model.__class__
+            if is_compiled_module(sub_model):
+                sub_model = _unwrap_model(sub_model)
+                model_cls = sub_model.__class__
+
+            save_method_name = None
+            for library_name, library_classes in LOADABLE_CLASSES.items():
+                if library_name in sys.modules:
+                    library = importlib.import_module(library_name)
+                else:
+                    logger.info(
+                        f"{library_name} is not installed. Cannot save {component_name} as {library_classes} from {library_name}"
+                    )
+                    continue
+
+                for base_class, save_load_methods in library_classes.items():
+                    class_candidate = getattr(library, base_class, None)
+                    if class_candidate is not None and issubclass(model_cls, class_candidate):
+                        save_method_name = save_load_methods[0]
+                        break
+                if save_method_name is not None:
+                    break
+
+            if save_method_name is None:
+                logger.warning(f"self.{component_name}={sub_model} of type {type(sub_model)} cannot be saved.")
+                continue
+
+            save_method = getattr(sub_model, save_method_name)
+            save_method_signature = inspect.signature(save_method)
+            save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
+            save_method_accept_variant = "variant" in save_method_signature.parameters
+            save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters
+
+            save_kwargs = {}
+            if save_method_accept_safe:
+                save_kwargs["safe_serialization"] = safe_serialization
+            if save_method_accept_variant:
+                save_kwargs["variant"] = variant
+            if save_method_accept_max_shard_size and max_shard_size is not None:
+                save_kwargs["max_shard_size"] = max_shard_size
+
+            save_method(os.path.join(save_directory, component_name), **save_kwargs)
 
         if push_to_hub:
             commit_message = kwargs.pop("commit_message", None)
             private = kwargs.pop("private", None)
             create_pr = kwargs.pop("create_pr", False)
             token = kwargs.pop("token", None)
-            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
             repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
 
-            # Generate modular pipeline card content
-            card_content = generate_modular_model_card_content(self.blocks)
-            # Create a new empty model card and eventually tag it
+        if overwrite_modular_index:
+            for component_name, component_spec in self._component_specs.items():
+                if component_spec.default_creation_method != "from_pretrained":
+                    continue
+                if component_name not in self.config:
+                    continue
+
+                library, class_name, component_spec_dict = self.config[component_name]
+                component_spec_dict["pretrained_model_name_or_path"] = repo_id
+                component_spec_dict["subfolder"] = component_name
+                if variant is not None and "variant" in component_spec_dict:
+                    component_spec_dict["variant"] = variant
+
+                self.register_to_config(**{component_name: (library, class_name, component_spec_dict)})
+
+        self.save_config(save_directory=save_directory)
+
+        if push_to_hub:
+            card_content = generate_modular_model_card_content(self.blocks)
             model_card = load_or_create_model_card(
                 repo_id,
                 token=token,
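# A minimal standalone sketch of the save-method lookup used by the new save_pretrained body
# above, assuming the real LOADABLE_CLASSES mapping from pipeline_loading_utils; the subset
# shown here and the helper name resolve_save_method_name are illustrative only.
import importlib
import sys

LOADABLE_CLASSES = {
    "diffusers": {"ModelMixin": ["save_pretrained", "from_pretrained"]},
    "transformers": {"PreTrainedModel": ["save_pretrained", "from_pretrained"]},
}

def resolve_save_method_name(model_cls):
    # Walk the library -> {base_class: [save_method, load_method]} table and return the
    # save method of the first base class the component subclasses, as the diff does.
    for library_name, library_classes in LOADABLE_CLASSES.items():
        if library_name not in sys.modules:
            continue  # library not importable; the diff logs this case and skips it
        library = importlib.import_module(library_name)
        for base_class, save_load_methods in library_classes.items():
            class_candidate = getattr(library, base_class, None)
            if class_candidate is not None and issubclass(model_cls, class_candidate):
                return save_load_methods[0]  # e.g. "save_pretrained"
    return None  # component cannot be saved; the diff warns and skips it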
@@ -1850,13 +1939,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
                 is_modular=True,
             )
             model_card = populate_model_card(model_card, tags=card_content["tags"])
-
             model_card.save(os.path.join(save_directory, "README.md"))
 
-        # YiYi TODO: maybe order the json file to make it more readable: configs first, then components
-        self.save_config(save_directory=save_directory)
-
-        if push_to_hub:
             self._upload_folder(
                 save_directory,
                 repo_id,
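# A minimal usage sketch of the extended save_pretrained signature, assuming a modular
# pipeline checkout; the repo id, local paths, and top-level import location are
# assumptions for illustration, not confirmed by the diff.
from diffusers import ModularPipeline  # import location is an assumption

pipe = ModularPipeline.from_pretrained("some-org/some-modular-pipeline")  # hypothetical repo
pipe.save_pretrained(
    "./my-modular-pipeline",   # each component is written to its own subfolder
    safe_serialization=True,   # forwarded only to component save methods that accept it
    variant="fp16",            # e.g. weights saved with an fp16 variant suffix
    max_shard_size="5GB",      # likewise forwarded only when supported
    push_to_hub=False,
)
reloaded = ModularPipeline.from_pretrained("./my-modular-pipeline")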