Compare commits

...

6 Commits

Author SHA1 Message Date
sayakpaul
95dae4c91e start 2023-10-23 10:06:45 +05:30
sayakpaul
cb62b4ff6b Merge remote-tracking branch 'origin/add_custom_remote_pipelines' into single-model-remote 2023-10-23 09:55:42 +05:30
Sayak Paul
76d795a9a6 Merge branch 'main' into add_custom_remote_pipelines 2023-10-20 21:51:16 +05:30
Patrick von Platen
6b5ee298da make style 2023-10-20 17:33:54 +02:00
Patrick von Platen
062bb8dc0e up 2023-10-20 17:30:52 +02:00
Patrick von Platen
5063e3b89d upload custom remote poc 2023-10-20 16:21:29 +02:00
3 changed files with 150 additions and 7 deletions

View File

@@ -343,6 +343,7 @@ class ConfigMixin:
user_agent = http_user_agent(user_agent)
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
print("load_config() is called.")
if cls.config_name is None:
raise ValueError(
@@ -485,10 +486,18 @@ class ConfigMixin:
# remove attributes from orig class that cannot be expected
orig_cls_name = config_dict.pop("_class_name", cls.__name__)
if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
if (
isinstance(orig_cls_name, str)
and orig_cls_name != cls.__name__
and hasattr(diffusers_library, orig_cls_name)
):
orig_cls = getattr(diffusers_library, orig_cls_name)
unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
elif not isinstance(orig_cls_name, str) and not isinstance(orig_cls_name, (list, tuple)):
raise ValueError(
"Make sure that the `_class_name` is of type string or list of string (for custom pipelines)."
)
# remove private attributes
config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}

View File

@@ -305,13 +305,22 @@ def maybe_raise_or_warn(
)
def get_class_obj_and_candidates(library_name, class_name, importable_classes, pipelines, is_pipeline_module):
def get_class_obj_and_candidates(
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
):
"""Simple helper method to retrieve class object of module as well as potential parent class objects"""
if is_pipeline_module:
pipeline_module = getattr(pipelines, library_name)
class_obj = getattr(pipeline_module, class_name)
class_candidates = {c: class_obj for c in importable_classes.keys()}
elif library_name not in LOADABLE_CLASSES.keys():
# load custom component
component_folder = os.path.join(cache_dir, component_name)
class_obj = get_class_from_dynamic_module(
component_folder, module_file=library_name + ".py", class_name=class_name
)
class_candidates = {c: class_obj for c in importable_classes.keys()}
else:
# else we just import it from the library.
library = importlib.import_module(library_name)
@@ -323,7 +332,15 @@ def get_class_obj_and_candidates(library_name, class_name, importable_classes, p
def _get_pipeline_class(
class_obj, config, load_connected_pipeline=False, custom_pipeline=None, cache_dir=None, revision=None
class_obj,
config,
load_connected_pipeline=False,
custom_pipeline=None,
hub_repo_id=None,
hub_revision=None,
class_name=None,
cache_dir=None,
revision=None,
):
if custom_pipeline is not None:
if custom_pipeline.endswith(".py"):
@@ -331,11 +348,19 @@ def _get_pipeline_class(
# decompose into folder & file
file_name = path.name
custom_pipeline = path.parent.absolute()
elif hub_repo_id is not None:
file_name = f"{custom_pipeline}.py"
custom_pipeline = hub_repo_id
else:
file_name = CUSTOM_PIPELINE_FILE_NAME
return get_class_from_dynamic_module(
custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=revision
custom_pipeline,
module_file=file_name,
class_name=class_name,
hub_repo_id=hub_repo_id,
cache_dir=cache_dir,
revision=revision if hub_revision is None else hub_revision,
)
if class_obj != DiffusionPipeline:
@@ -383,11 +408,18 @@ def load_sub_model(
variant: str,
low_cpu_mem_usage: bool,
cached_folder: Union[str, os.PathLike],
revision: str = None,
):
"""Helper method to load the module `name` from `library_name` and `class_name`"""
# retrieve class candidates
class_obj, class_candidates = get_class_obj_and_candidates(
library_name, class_name, importable_classes, pipelines, is_pipeline_module
library_name,
class_name,
importable_classes,
pipelines,
is_pipeline_module,
component_name=name,
cache_dir=cached_folder,
)
load_method_name = None
@@ -1080,11 +1112,21 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
# 3. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
custom_class_name = None
if os.path.isfile(os.path.join(cached_folder, f"{custom_pipeline}.py")):
custom_pipeline = os.path.join(cached_folder, f"{custom_pipeline}.py")
elif isinstance(config_dict["_class_name"], (list, tuple)) and os.path.isfile(
os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
):
custom_pipeline = os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
custom_class_name = config_dict["_class_name"][1]
pipeline_class = _get_pipeline_class(
cls,
config_dict,
load_connected_pipeline=load_connected_pipeline,
custom_pipeline=custom_pipeline,
class_name=custom_class_name,
cache_dir=cache_dir,
revision=custom_revision,
)
@@ -1223,6 +1265,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
variant=variant,
low_cpu_mem_usage=low_cpu_mem_usage,
cached_folder=cached_folder,
revision=revision,
)
logger.info(
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
@@ -1542,6 +1585,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
`False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
with `.onnx` and `.pb`.
trust_remote_code (`bool`, *optional*, defaults to `False`):
Whether or not to allow for custom pipelines and components defined on the Hub in their own files. This
option should only be set to `True` for repositories you trust and in which you have read the code, as
it will execute code present on the Hub on your local machine.
Returns:
`os.PathLike`:
@@ -1569,6 +1616,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
use_safetensors = kwargs.pop("use_safetensors", None)
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
trust_remote_code = kwargs.pop("trust_remote_code", False)
allow_pickle = False
if use_safetensors is None:
@@ -1604,12 +1652,17 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
)
config_dict = cls._dict_from_json_file(config_file)
ignore_filenames = config_dict.pop("_ignore_files", [])
# retrieve all folder_names that contain relevant files
folder_names = [k for k, v in config_dict.items() if isinstance(v, list)]
# optionally create a custom component <> custom file mapping
custom_components = {}
for component in folder_names:
if config_dict[component][0] not in LOADABLE_CLASSES.keys():
custom_components[component] = config_dict[component][0]
filenames = {sibling.rfilename for sibling in info.siblings}
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
@@ -1636,12 +1689,21 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
custom_class_name = None
if custom_pipeline is None and isinstance(config_dict["_class_name"], (list, tuple)):
custom_pipeline = config_dict["_class_name"][0]
custom_class_name = config_dict["_class_name"][1]
# all filenames compatible with variant will be added
allow_patterns = list(model_filenames)
# allow all patterns from non-model folders
# this enables downloading schedulers, tokenizers, ...
allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
# add custom component files
allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()]
# add custom pipeline file
allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
# also allow downloading config.json files with the model
allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
@@ -1652,12 +1714,32 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
CUSTOM_PIPELINE_FILE_NAME,
]
load_pipe_from_hub = custom_pipeline is not None and f"{custom_pipeline}.py" in filenames
load_components_from_hub = len(custom_components) > 0
if load_pipe_from_hub and not trust_remote_code:
raise ValueError(
f"The repository for {pretrained_model_name} contains custom code in {custom_pipeline}.py which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/{pretrained_model_name}/blob/main/{custom_pipeline}.py.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)
if load_components_from_hub and not trust_remote_code:
raise ValueError(
f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k,v in custom_components.items()])} which must be executed to correctly "
f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k,v in custom_components.items()])}.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)
# retrieve passed components that should not be downloaded
pipeline_class = _get_pipeline_class(
cls,
config_dict,
load_connected_pipeline=load_connected_pipeline,
custom_pipeline=custom_pipeline,
hub_repo_id=pretrained_model_name if load_pipe_from_hub else None,
hub_revision=revision,
class_name=custom_class_name,
cache_dir=cache_dir,
revision=custom_revision,
)
@@ -1754,7 +1836,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
# retrieve pipeline class from local file
cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None)
cls_name = cls_name[4:] if cls_name.startswith("Flax") else cls_name
cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name
pipeline_class = getattr(diffusers, cls_name, None)

View File

@@ -862,6 +862,58 @@ class CustomPipelineTests(unittest.TestCase):
# compare output to https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L102
assert output_str == "This is a test"
def test_remote_components(self):
# make sure that trust remote code has to be passed
with self.assertRaises(ValueError):
pipeline = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sdxl-custom-components")
# Check that only loading custom componets "my_unet", "my_scheduler" works
pipeline = DiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-sdxl-custom-components", trust_remote_code=True
)
assert pipeline.config.unet == ("diffusers_modules.local.my_unet_model", "MyUNetModel")
assert pipeline.config.scheduler == ("diffusers_modules.local.my_scheduler", "MyScheduler")
assert pipeline.__class__.__name__ == "StableDiffusionXLPipeline"
pipeline = pipeline.to(torch_device)
images = pipeline("test", num_inference_steps=2, output_type="np")[0]
assert images.shape == (1, 64, 64, 3)
# Check that only loading custom componets "my_unet", "my_scheduler" and explicit custom pipeline works
pipeline = DiffusionPipeline.from_pretrained(
"/home/patrick/tiny-stable-diffusion-xl-pipe", custom_pipeline="my_pipeline"
)
assert pipeline.config.unet == ("diffusers_modules.local.my_unet_model", "MyUNetModel")
assert pipeline.config.scheduler == ("diffusers_modules.local.my_scheduler", "MyScheduler")
assert pipeline.__class__.__name__ == "MyPipeline"
pipeline = pipeline.to(torch_device)
images = pipeline("test", num_inference_steps=2, output_type="np")[0]
assert images.shape == (1, 64, 64, 3)
def test_remote_auto_custom_pipe(self):
# make sure that trust remote code has to be passed
with self.assertRaises(ValueError):
pipeline = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sdxl-custom-all")
# Check that only loading custom componets "my_unet", "my_scheduler" and auto custom pipeline works
pipeline = DiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-sdxl-custom-all", trust_remote_code=True
)
assert pipeline.config.unet == ("diffusers_modules.local.my_unet_model", "MyUNetModel")
assert pipeline.config.scheduler == ("diffusers_modules.local.my_scheduler", "MyScheduler")
assert pipeline.__class__.__name__ == "MyPipeline"
pipeline = pipeline.to(torch_device)
images = pipeline("test", num_inference_steps=2, output_type="np")[0]
assert images.shape == (1, 64, 64, 3)
def test_local_custom_pipeline_repo(self):
local_custom_pipeline_path = get_tests_dir("fixtures/custom_pipeline")
pipeline = DiffusionPipeline.from_pretrained(