Compare commits

...

2 Commits

Author SHA1 Message Date
Dhruv Nair
2185031c44 Merge branch 'main' into controlnet-compile-fix 2024-01-03 21:46:57 +05:30
Dhruv Nair
99f34ad65f update 2024-01-03 10:35:59 +00:00

View File

@@ -530,6 +530,36 @@ def load_sub_model(
return loaded_sub_model
def _fetch_class_library_tuple(module):
# import it here to avoid circular import
diffusers_module = importlib.import_module(__name__.split(".")[0])
pipelines = getattr(diffusers_module, "pipelines")
# register the config from the original module, not the dynamo compiled one
not_compiled_module = _unwrap_model(module)
library = not_compiled_module.__module__.split(".")[0]
# check if the module is a pipeline module
module_path_items = not_compiled_module.__module__.split(".")
pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
path = not_compiled_module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# folder so we set the library to module name.
if is_pipeline_module:
library = pipeline_dir
elif library not in LOADABLE_CLASSES:
library = not_compiled_module.__module__
# retrieve class_name
class_name = not_compiled_module.__class__.__name__
return (library, class_name)
class DiffusionPipeline(ConfigMixin, PushToHubMixin):
r"""
Base class for all pipelines.
@@ -556,38 +586,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
_is_onnx = False
def register_modules(self, **kwargs):
# import it here to avoid circular import
diffusers_module = importlib.import_module(__name__.split(".")[0])
pipelines = getattr(diffusers_module, "pipelines")
for name, module in kwargs.items():
# retrieve library
if module is None or isinstance(module, (tuple, list)) and module[0] is None:
register_dict = {name: (None, None)}
else:
# register the config from the original module, not the dynamo compiled one
not_compiled_module = _unwrap_model(module)
library = not_compiled_module.__module__.split(".")[0]
# check if the module is a pipeline module
module_path_items = not_compiled_module.__module__.split(".")
pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
path = not_compiled_module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# folder so we set the library to module name.
if is_pipeline_module:
library = pipeline_dir
elif library not in LOADABLE_CLASSES:
library = not_compiled_module.__module__
# retrieve class_name
class_name = not_compiled_module.__class__.__name__
library, class_name = _fetch_class_library_tuple(module)
register_dict = {name: (library, class_name)}
# save model index config
@@ -601,7 +605,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
# We need to overwrite the config if name exists in config
if isinstance(getattr(self.config, name), (tuple, list)):
if value is not None and self.config[name][0] is not None:
class_library_tuple = (value.__module__.split(".")[0], value.__class__.__name__)
class_library_tuple = _fetch_class_library_tuple(value)
else:
class_library_tuple = (None, None)