mirror of
https://github.com/huggingface/diffusers.git
synced 2026-02-14 14:55:26 +08:00
add more logic for dynamic loading
This commit is contained in:
@@ -23,7 +23,7 @@ class DDPM(DiffusionPipeline):
|
||||
|
||||
modeling_file = "modeling_ddpm.py"
|
||||
|
||||
def __init__(self, unet, noise_scheduler):
|
||||
def __init__(self, unet, noise_scheduler, vqvae):
|
||||
super().__init__()
|
||||
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
|
||||
|
||||
|
||||
1
models/vision/glide/modeling_vqvae.py.py
Executable file
1
models/vision/glide/modeling_vqvae.py.py
Executable file
@@ -0,0 +1 @@
|
||||
#!/usr/bin/env python3
|
||||
@@ -71,6 +71,10 @@ class DiffusionPipeline(ConfigMixin):
|
||||
for name, (library_name, class_name) in self._dict_to_save.items():
|
||||
importable_classes = LOADABLE_CLASSES[library_name]
|
||||
|
||||
# TODO: Suraj
|
||||
if library_name == self.__module__:
|
||||
library_name = self
|
||||
|
||||
library = importlib.import_module(library_name)
|
||||
class_obj = getattr(library, class_name)
|
||||
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
|
||||
@@ -91,12 +95,18 @@ class DiffusionPipeline(ConfigMixin):
|
||||
|
||||
module = pipeline_kwargs["_module"]
|
||||
# TODO(Suraj) - make from hub import work
|
||||
# Make `ddpm = DiffusionPipeline.from_pretrained("fusing/ddpm-lsun-bedroom-pipe")` work
|
||||
# Add Sylvains code from transformers
|
||||
|
||||
init_kwargs = {}
|
||||
|
||||
for name, (library_name, class_name) in config_dict.items():
|
||||
importable_classes = LOADABLE_CLASSES[library_name]
|
||||
|
||||
if library_name == module:
|
||||
# TODO(Suraj)
|
||||
pass
|
||||
|
||||
library = importlib.import_module(library_name)
|
||||
class_obj = getattr(library, class_name)
|
||||
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
|
||||
@@ -110,7 +120,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
|
||||
loaded_sub_model = load_method(os.path.join(cached_folder, name))
|
||||
|
||||
init_kwargs[name] = loaded_sub_model
|
||||
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
|
||||
|
||||
model = cls(**init_kwargs)
|
||||
return model
|
||||
|
||||
Reference in New Issue
Block a user