mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-08 13:34:27 +08:00
Compare commits
2 Commits
export-vid
...
controlnet
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2185031c44 | ||
|
|
99f34ad65f |
@@ -530,6 +530,36 @@ def load_sub_model(
|
|||||||
return loaded_sub_model
|
return loaded_sub_model
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_class_library_tuple(module):
|
||||||
|
# import it here to avoid circular import
|
||||||
|
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||||
|
pipelines = getattr(diffusers_module, "pipelines")
|
||||||
|
|
||||||
|
# register the config from the original module, not the dynamo compiled one
|
||||||
|
not_compiled_module = _unwrap_model(module)
|
||||||
|
library = not_compiled_module.__module__.split(".")[0]
|
||||||
|
|
||||||
|
# check if the module is a pipeline module
|
||||||
|
module_path_items = not_compiled_module.__module__.split(".")
|
||||||
|
pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
|
||||||
|
|
||||||
|
path = not_compiled_module.__module__.split(".")
|
||||||
|
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
|
||||||
|
|
||||||
|
# if library is not in LOADABLE_CLASSES, then it is a custom module.
|
||||||
|
# Or if it's a pipeline module, then the module is inside the pipeline
|
||||||
|
# folder so we set the library to module name.
|
||||||
|
if is_pipeline_module:
|
||||||
|
library = pipeline_dir
|
||||||
|
elif library not in LOADABLE_CLASSES:
|
||||||
|
library = not_compiled_module.__module__
|
||||||
|
|
||||||
|
# retrieve class_name
|
||||||
|
class_name = not_compiled_module.__class__.__name__
|
||||||
|
|
||||||
|
return (library, class_name)
|
||||||
|
|
||||||
|
|
||||||
class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||||
r"""
|
r"""
|
||||||
Base class for all pipelines.
|
Base class for all pipelines.
|
||||||
@@ -556,38 +586,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|||||||
_is_onnx = False
|
_is_onnx = False
|
||||||
|
|
||||||
def register_modules(self, **kwargs):
|
def register_modules(self, **kwargs):
|
||||||
# import it here to avoid circular import
|
|
||||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
|
||||||
pipelines = getattr(diffusers_module, "pipelines")
|
|
||||||
|
|
||||||
for name, module in kwargs.items():
|
for name, module in kwargs.items():
|
||||||
# retrieve library
|
# retrieve library
|
||||||
if module is None or isinstance(module, (tuple, list)) and module[0] is None:
|
if module is None or isinstance(module, (tuple, list)) and module[0] is None:
|
||||||
register_dict = {name: (None, None)}
|
register_dict = {name: (None, None)}
|
||||||
else:
|
else:
|
||||||
# register the config from the original module, not the dynamo compiled one
|
library, class_name = _fetch_class_library_tuple(module)
|
||||||
not_compiled_module = _unwrap_model(module)
|
|
||||||
|
|
||||||
library = not_compiled_module.__module__.split(".")[0]
|
|
||||||
|
|
||||||
# check if the module is a pipeline module
|
|
||||||
module_path_items = not_compiled_module.__module__.split(".")
|
|
||||||
pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
|
|
||||||
|
|
||||||
path = not_compiled_module.__module__.split(".")
|
|
||||||
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
|
|
||||||
|
|
||||||
# if library is not in LOADABLE_CLASSES, then it is a custom module.
|
|
||||||
# Or if it's a pipeline module, then the module is inside the pipeline
|
|
||||||
# folder so we set the library to module name.
|
|
||||||
if is_pipeline_module:
|
|
||||||
library = pipeline_dir
|
|
||||||
elif library not in LOADABLE_CLASSES:
|
|
||||||
library = not_compiled_module.__module__
|
|
||||||
|
|
||||||
# retrieve class_name
|
|
||||||
class_name = not_compiled_module.__class__.__name__
|
|
||||||
|
|
||||||
register_dict = {name: (library, class_name)}
|
register_dict = {name: (library, class_name)}
|
||||||
|
|
||||||
# save model index config
|
# save model index config
|
||||||
@@ -601,7 +605,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|||||||
# We need to overwrite the config if name exists in config
|
# We need to overwrite the config if name exists in config
|
||||||
if isinstance(getattr(self.config, name), (tuple, list)):
|
if isinstance(getattr(self.config, name), (tuple, list)):
|
||||||
if value is not None and self.config[name][0] is not None:
|
if value is not None and self.config[name][0] is not None:
|
||||||
class_library_tuple = (value.__module__.split(".")[0], value.__class__.__name__)
|
class_library_tuple = _fetch_class_library_tuple(value)
|
||||||
else:
|
else:
|
||||||
class_library_tuple = (None, None)
|
class_library_tuple = (None, None)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user