mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 12:34:13 +08:00
[Pipelines] Make sure that None functions are correctly not saved (#3080)
This commit is contained in:
committed by
GitHub
parent
d06e06940b
commit
46c52f9b96
@@ -19,6 +19,7 @@ import importlib
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
@@ -540,11 +541,9 @@ class DiffusionPipeline(ConfigMixin):
|
||||
variant (`str`, *optional*):
|
||||
If specified, weights are saved in the format pytorch_model.<variant>.bin.
|
||||
"""
|
||||
self.save_config(save_directory)
|
||||
|
||||
model_index_dict = dict(self.config)
|
||||
model_index_dict.pop("_class_name")
|
||||
model_index_dict.pop("_diffusers_version")
|
||||
model_index_dict.pop("_class_name", None)
|
||||
model_index_dict.pop("_diffusers_version", None)
|
||||
model_index_dict.pop("_module", None)
|
||||
|
||||
expected_modules, optional_kwargs = self._get_signature_keys(self)
|
||||
@@ -557,7 +556,6 @@ class DiffusionPipeline(ConfigMixin):
|
||||
return True
|
||||
|
||||
model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)}
|
||||
|
||||
for pipeline_component_name in model_index_dict.keys():
|
||||
sub_model = getattr(self, pipeline_component_name)
|
||||
model_cls = sub_model.__class__
|
||||
@@ -571,7 +569,13 @@ class DiffusionPipeline(ConfigMixin):
|
||||
save_method_name = None
|
||||
# search for the model's base class in LOADABLE_CLASSES
|
||||
for library_name, library_classes in LOADABLE_CLASSES.items():
|
||||
library = importlib.import_module(library_name)
|
||||
if library_name in sys.modules:
|
||||
library = importlib.import_module(library_name)
|
||||
else:
|
||||
logger.info(
|
||||
f"{library_name} is not installed. Cannot save {pipeline_component_name} as {library_classes} from {library_name}"
|
||||
)
|
||||
|
||||
for base_class, save_load_methods in library_classes.items():
|
||||
class_candidate = getattr(library, base_class, None)
|
||||
if class_candidate is not None and issubclass(model_cls, class_candidate):
|
||||
@@ -581,6 +585,12 @@ class DiffusionPipeline(ConfigMixin):
|
||||
if save_method_name is not None:
|
||||
break
|
||||
|
||||
if save_method_name is None:
|
||||
logger.warn(f"self.{pipeline_component_name}={sub_model} of type {type(sub_model)} cannot be saved.")
|
||||
# make sure that unsaveable components are not tried to be loaded afterward
|
||||
self.register_to_config(**{pipeline_component_name: (None, None)})
|
||||
continue
|
||||
|
||||
save_method = getattr(sub_model, save_method_name)
|
||||
|
||||
# Call the save method with the argument safe_serialization only if it's supported
|
||||
@@ -596,6 +606,9 @@ class DiffusionPipeline(ConfigMixin):
|
||||
|
||||
save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
|
||||
|
||||
# finally save the config
|
||||
self.save_config(save_directory)
|
||||
|
||||
def to(
|
||||
self,
|
||||
torch_device: Optional[Union[str, torch.device]] = None,
|
||||
|
||||
Reference in New Issue
Block a user