Compare commits

...

1 Commits

Author SHA1 Message Date
DN6
eff791831f update 2026-03-13 10:28:38 +05:30

View File

@@ -14,6 +14,7 @@
import importlib
import inspect
import os
import shutil
import sys
import traceback
import warnings
@@ -1883,6 +1884,36 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
)
return pipeline
def _maybe_save_custom_code(self, save_directory: str | os.PathLike):
"""Save custom code files (blocks config and Python modules) to the save directory."""
if self._blocks is None:
return
blocks_module = type(self._blocks).__module__
is_custom_code = not blocks_module.startswith("diffusers.") and blocks_module != "diffusers"
if not is_custom_code:
return
os.makedirs(save_directory, exist_ok=True)
self._blocks.save_pretrained(save_directory)
source_file = inspect.getfile(type(self._blocks))
module_file = os.path.basename(source_file)
dest_file = os.path.join(save_directory, module_file)
if os.path.abspath(source_file) != os.path.abspath(dest_file):
shutil.copyfile(source_file, dest_file)
from ..utils.dynamic_modules_utils import get_relative_import_files
for rel_file in get_relative_import_files(source_file):
rel_name = os.path.relpath(rel_file, os.path.dirname(source_file))
rel_dest = os.path.join(save_directory, rel_name)
if os.path.abspath(rel_file) != os.path.abspath(rel_dest):
os.makedirs(os.path.dirname(rel_dest), exist_ok=True)
shutil.copyfile(rel_file, rel_dest)
def save_pretrained(
self,
save_directory: str | os.PathLike,
@@ -1998,6 +2029,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
component_spec_dict["subfolder"] = component_name
self.register_to_config(**{component_name: (library, class_name, component_spec_dict)})
self._maybe_save_custom_code(save_directory)
self.save_config(save_directory=save_directory)
if push_to_hub: