[Pipelines] Make sure that None functions are correctly not saved (#3080)

This commit is contained in:
Patrick von Platen
2023-04-13 00:25:10 +02:00
committed by GitHub
parent d06e06940b
commit 46c52f9b96

View File

@@ -19,6 +19,7 @@ import importlib
import inspect
import os
import re
import sys
import warnings
from dataclasses import dataclass
from pathlib import Path
@@ -540,11 +541,9 @@ class DiffusionPipeline(ConfigMixin):
variant (`str`, *optional*):
If specified, weights are saved in the format pytorch_model.<variant>.bin.
"""
self.save_config(save_directory)
model_index_dict = dict(self.config)
model_index_dict.pop("_class_name")
model_index_dict.pop("_diffusers_version")
model_index_dict.pop("_class_name", None)
model_index_dict.pop("_diffusers_version", None)
model_index_dict.pop("_module", None)
expected_modules, optional_kwargs = self._get_signature_keys(self)
@@ -557,7 +556,6 @@ class DiffusionPipeline(ConfigMixin):
return True
model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)}
for pipeline_component_name in model_index_dict.keys():
sub_model = getattr(self, pipeline_component_name)
model_cls = sub_model.__class__
@@ -571,7 +569,13 @@ class DiffusionPipeline(ConfigMixin):
save_method_name = None
# search for the model's base class in LOADABLE_CLASSES
for library_name, library_classes in LOADABLE_CLASSES.items():
library = importlib.import_module(library_name)
if library_name in sys.modules:
library = importlib.import_module(library_name)
else:
logger.info(
f"{library_name} is not installed. Cannot save {pipeline_component_name} as {library_classes} from {library_name}"
)
for base_class, save_load_methods in library_classes.items():
class_candidate = getattr(library, base_class, None)
if class_candidate is not None and issubclass(model_cls, class_candidate):
@@ -581,6 +585,12 @@ class DiffusionPipeline(ConfigMixin):
if save_method_name is not None:
break
if save_method_name is None:
logger.warn(f"self.{pipeline_component_name}={sub_model} of type {type(sub_model)} cannot be saved.")
# make sure that unsaveable components are not tried to be loaded afterward
self.register_to_config(**{pipeline_component_name: (None, None)})
continue
save_method = getattr(sub_model, save_method_name)
# Call the save method with the argument safe_serialization only if it's supported
@@ -596,6 +606,9 @@ class DiffusionPipeline(ConfigMixin):
save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
# finally save the config
self.save_config(save_directory)
def to(
self,
torch_device: Optional[Union[str, torch.device]] = None,