diff --git a/models/vision/ddpm/modeling_ddpm.py b/models/vision/ddpm/modeling_ddpm.py index ae049a8c0a..4a3f0b24b7 100644 --- a/models/vision/ddpm/modeling_ddpm.py +++ b/models/vision/ddpm/modeling_ddpm.py @@ -23,7 +23,7 @@ class DDPM(DiffusionPipeline): modeling_file = "modeling_ddpm.py" - def __init__(self, unet, noise_scheduler): + def __init__(self, unet, noise_scheduler, vqvae): super().__init__() self.register_modules(unet=unet, noise_scheduler=noise_scheduler) diff --git a/models/vision/glide/modeling_vqvae.py.py b/models/vision/glide/modeling_vqvae.py.py new file mode 100755 index 0000000000..e5a0d9b483 --- /dev/null +++ b/models/vision/glide/modeling_vqvae.py.py @@ -0,0 +1 @@ +#!/usr/bin/env python3 diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index d4e050681a..6b56f78232 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -71,6 +71,10 @@ class DiffusionPipeline(ConfigMixin): for name, (library_name, class_name) in self._dict_to_save.items(): importable_classes = LOADABLE_CLASSES[library_name] + # TODO: Suraj + if library_name == self.__module__: + library_name = self + library = importlib.import_module(library_name) class_obj = getattr(library, class_name) class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} @@ -91,12 +95,18 @@ class DiffusionPipeline(ConfigMixin): module = pipeline_kwargs["_module"] # TODO(Suraj) - make from hub import work + # Make `ddpm = DiffusionPipeline.from_pretrained("fusing/ddpm-lsun-bedroom-pipe")` work + # Add Sylvains code from transformers init_kwargs = {} for name, (library_name, class_name) in config_dict.items(): importable_classes = LOADABLE_CLASSES[library_name] + if library_name == module: + # TODO(Suraj) + pass + library = importlib.import_module(library_name) class_obj = getattr(library, class_name) class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} @@ -110,7 +120,7 @@ class DiffusionPipeline(ConfigMixin): loaded_sub_model = load_method(os.path.join(cached_folder, name)) - init_kwargs[name] = loaded_sub_model + init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) model = cls(**init_kwargs) return model