Compare commits

...

1 Commits

Author SHA1 Message Date
sayakpaul
ca39584566 fix single file config loading when passing dicts. 2024-05-13 19:17:42 +05:30

View File

@@ -193,6 +193,7 @@ class FromOriginalModelMixin:
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[class_name] mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[class_name]
checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"] checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"]
diffusers_model_config = None
if original_config: if original_config:
if "config_mapping_fn" in mapping_functions: if "config_mapping_fn" in mapping_functions:
config_mapping_fn = mapping_functions["config_mapping_fn"] config_mapping_fn = mapping_functions["config_mapping_fn"]
@@ -220,6 +221,8 @@ class FromOriginalModelMixin:
if config: if config:
if isinstance(config, str): if isinstance(config, str):
default_pretrained_model_config_name = config default_pretrained_model_config_name = config
elif isinstance(config, dict):
diffusers_model_config = config
else: else:
raise ValueError( raise ValueError(
( (
@@ -239,11 +242,12 @@ class FromOriginalModelMixin:
"subfolder", None "subfolder", None
) # some configs contain a subfolder key, e.g. StableCascadeUNet ) # some configs contain a subfolder key, e.g. StableCascadeUNet
diffusers_model_config = cls.load_config( if diffusers_model_config is None:
pretrained_model_name_or_path=default_pretrained_model_config_name, diffusers_model_config = cls.load_config(
subfolder=subfolder, pretrained_model_name_or_path=default_pretrained_model_config_name,
local_files_only=local_files_only, subfolder=subfolder,
) local_files_only=local_files_only,
)
expected_kwargs, optional_kwargs = cls._get_signature_keys(cls) expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
# Map legacy kwargs to new kwargs # Map legacy kwargs to new kwargs