mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-07 04:54:47 +08:00
Compare commits
6 Commits
fix-mirror
...
single-mod
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
95dae4c91e | ||
|
|
cb62b4ff6b | ||
|
|
76d795a9a6 | ||
|
|
6b5ee298da | ||
|
|
062bb8dc0e | ||
|
|
5063e3b89d |
@@ -343,6 +343,7 @@ class ConfigMixin:
|
||||
user_agent = http_user_agent(user_agent)
|
||||
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
print("load_config() is called.")
|
||||
|
||||
if cls.config_name is None:
|
||||
raise ValueError(
|
||||
@@ -485,10 +486,18 @@ class ConfigMixin:
|
||||
|
||||
# remove attributes from orig class that cannot be expected
|
||||
orig_cls_name = config_dict.pop("_class_name", cls.__name__)
|
||||
if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
|
||||
if (
|
||||
isinstance(orig_cls_name, str)
|
||||
and orig_cls_name != cls.__name__
|
||||
and hasattr(diffusers_library, orig_cls_name)
|
||||
):
|
||||
orig_cls = getattr(diffusers_library, orig_cls_name)
|
||||
unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
|
||||
config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
|
||||
elif not isinstance(orig_cls_name, str) and not isinstance(orig_cls_name, (list, tuple)):
|
||||
raise ValueError(
|
||||
"Make sure that the `_class_name` is of type string or list of string (for custom pipelines)."
|
||||
)
|
||||
|
||||
# remove private attributes
|
||||
config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
|
||||
|
||||
@@ -305,13 +305,22 @@ def maybe_raise_or_warn(
|
||||
)
|
||||
|
||||
|
||||
def get_class_obj_and_candidates(library_name, class_name, importable_classes, pipelines, is_pipeline_module):
|
||||
def get_class_obj_and_candidates(
|
||||
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
|
||||
):
|
||||
"""Simple helper method to retrieve class object of module as well as potential parent class objects"""
|
||||
if is_pipeline_module:
|
||||
pipeline_module = getattr(pipelines, library_name)
|
||||
|
||||
class_obj = getattr(pipeline_module, class_name)
|
||||
class_candidates = {c: class_obj for c in importable_classes.keys()}
|
||||
elif library_name not in LOADABLE_CLASSES.keys():
|
||||
# load custom component
|
||||
component_folder = os.path.join(cache_dir, component_name)
|
||||
class_obj = get_class_from_dynamic_module(
|
||||
component_folder, module_file=library_name + ".py", class_name=class_name
|
||||
)
|
||||
class_candidates = {c: class_obj for c in importable_classes.keys()}
|
||||
else:
|
||||
# else we just import it from the library.
|
||||
library = importlib.import_module(library_name)
|
||||
@@ -323,7 +332,15 @@ def get_class_obj_and_candidates(library_name, class_name, importable_classes, p
|
||||
|
||||
|
||||
def _get_pipeline_class(
|
||||
class_obj, config, load_connected_pipeline=False, custom_pipeline=None, cache_dir=None, revision=None
|
||||
class_obj,
|
||||
config,
|
||||
load_connected_pipeline=False,
|
||||
custom_pipeline=None,
|
||||
hub_repo_id=None,
|
||||
hub_revision=None,
|
||||
class_name=None,
|
||||
cache_dir=None,
|
||||
revision=None,
|
||||
):
|
||||
if custom_pipeline is not None:
|
||||
if custom_pipeline.endswith(".py"):
|
||||
@@ -331,11 +348,19 @@ def _get_pipeline_class(
|
||||
# decompose into folder & file
|
||||
file_name = path.name
|
||||
custom_pipeline = path.parent.absolute()
|
||||
elif hub_repo_id is not None:
|
||||
file_name = f"{custom_pipeline}.py"
|
||||
custom_pipeline = hub_repo_id
|
||||
else:
|
||||
file_name = CUSTOM_PIPELINE_FILE_NAME
|
||||
|
||||
return get_class_from_dynamic_module(
|
||||
custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=revision
|
||||
custom_pipeline,
|
||||
module_file=file_name,
|
||||
class_name=class_name,
|
||||
hub_repo_id=hub_repo_id,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision if hub_revision is None else hub_revision,
|
||||
)
|
||||
|
||||
if class_obj != DiffusionPipeline:
|
||||
@@ -383,11 +408,18 @@ def load_sub_model(
|
||||
variant: str,
|
||||
low_cpu_mem_usage: bool,
|
||||
cached_folder: Union[str, os.PathLike],
|
||||
revision: str = None,
|
||||
):
|
||||
"""Helper method to load the module `name` from `library_name` and `class_name`"""
|
||||
# retrieve class candidates
|
||||
class_obj, class_candidates = get_class_obj_and_candidates(
|
||||
library_name, class_name, importable_classes, pipelines, is_pipeline_module
|
||||
library_name,
|
||||
class_name,
|
||||
importable_classes,
|
||||
pipelines,
|
||||
is_pipeline_module,
|
||||
component_name=name,
|
||||
cache_dir=cached_folder,
|
||||
)
|
||||
|
||||
load_method_name = None
|
||||
@@ -1080,11 +1112,21 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
# 3. Load the pipeline class, if using custom module then load it from the hub
|
||||
# if we load from explicit class, let's use it
|
||||
custom_class_name = None
|
||||
if os.path.isfile(os.path.join(cached_folder, f"{custom_pipeline}.py")):
|
||||
custom_pipeline = os.path.join(cached_folder, f"{custom_pipeline}.py")
|
||||
elif isinstance(config_dict["_class_name"], (list, tuple)) and os.path.isfile(
|
||||
os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
|
||||
):
|
||||
custom_pipeline = os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
|
||||
custom_class_name = config_dict["_class_name"][1]
|
||||
|
||||
pipeline_class = _get_pipeline_class(
|
||||
cls,
|
||||
config_dict,
|
||||
load_connected_pipeline=load_connected_pipeline,
|
||||
custom_pipeline=custom_pipeline,
|
||||
class_name=custom_class_name,
|
||||
cache_dir=cache_dir,
|
||||
revision=custom_revision,
|
||||
)
|
||||
@@ -1223,6 +1265,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
variant=variant,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
cached_folder=cached_folder,
|
||||
revision=revision,
|
||||
)
|
||||
logger.info(
|
||||
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
|
||||
@@ -1542,6 +1585,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
|
||||
`False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
|
||||
with `.onnx` and `.pb`.
|
||||
trust_remote_code (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to allow for custom pipelines and components defined on the Hub in their own files. This
|
||||
option should only be set to `True` for repositories you trust and in which you have read the code, as
|
||||
it will execute code present on the Hub on your local machine.
|
||||
|
||||
Returns:
|
||||
`os.PathLike`:
|
||||
@@ -1569,6 +1616,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
use_onnx = kwargs.pop("use_onnx", None)
|
||||
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
|
||||
trust_remote_code = kwargs.pop("trust_remote_code", False)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
@@ -1604,12 +1652,17 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
)
|
||||
|
||||
config_dict = cls._dict_from_json_file(config_file)
|
||||
|
||||
ignore_filenames = config_dict.pop("_ignore_files", [])
|
||||
|
||||
# retrieve all folder_names that contain relevant files
|
||||
folder_names = [k for k, v in config_dict.items() if isinstance(v, list)]
|
||||
|
||||
# optionally create a custom component <> custom file mapping
|
||||
custom_components = {}
|
||||
for component in folder_names:
|
||||
if config_dict[component][0] not in LOADABLE_CLASSES.keys():
|
||||
custom_components[component] = config_dict[component][0]
|
||||
|
||||
filenames = {sibling.rfilename for sibling in info.siblings}
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
|
||||
@@ -1636,12 +1689,21 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
|
||||
|
||||
custom_class_name = None
|
||||
if custom_pipeline is None and isinstance(config_dict["_class_name"], (list, tuple)):
|
||||
custom_pipeline = config_dict["_class_name"][0]
|
||||
custom_class_name = config_dict["_class_name"][1]
|
||||
|
||||
# all filenames compatible with variant will be added
|
||||
allow_patterns = list(model_filenames)
|
||||
|
||||
# allow all patterns from non-model folders
|
||||
# this enables downloading schedulers, tokenizers, ...
|
||||
allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
|
||||
# add custom component files
|
||||
allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()]
|
||||
# add custom pipeline file
|
||||
allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
|
||||
# also allow downloading config.json files with the model
|
||||
allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
|
||||
|
||||
@@ -1652,12 +1714,32 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
CUSTOM_PIPELINE_FILE_NAME,
|
||||
]
|
||||
|
||||
load_pipe_from_hub = custom_pipeline is not None and f"{custom_pipeline}.py" in filenames
|
||||
load_components_from_hub = len(custom_components) > 0
|
||||
|
||||
if load_pipe_from_hub and not trust_remote_code:
|
||||
raise ValueError(
|
||||
f"The repository for {pretrained_model_name} contains custom code in {custom_pipeline}.py which must be executed to correctly "
|
||||
f"load the model. You can inspect the repository content at https://hf.co/{pretrained_model_name}/blob/main/{custom_pipeline}.py.\n"
|
||||
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
|
||||
)
|
||||
|
||||
if load_components_from_hub and not trust_remote_code:
|
||||
raise ValueError(
|
||||
f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k,v in custom_components.items()])} which must be executed to correctly "
|
||||
f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k,v in custom_components.items()])}.\n"
|
||||
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
|
||||
)
|
||||
|
||||
# retrieve passed components that should not be downloaded
|
||||
pipeline_class = _get_pipeline_class(
|
||||
cls,
|
||||
config_dict,
|
||||
load_connected_pipeline=load_connected_pipeline,
|
||||
custom_pipeline=custom_pipeline,
|
||||
hub_repo_id=pretrained_model_name if load_pipe_from_hub else None,
|
||||
hub_revision=revision,
|
||||
class_name=custom_class_name,
|
||||
cache_dir=cache_dir,
|
||||
revision=custom_revision,
|
||||
)
|
||||
@@ -1754,7 +1836,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
# retrieve pipeline class from local file
|
||||
cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None)
|
||||
cls_name = cls_name[4:] if cls_name.startswith("Flax") else cls_name
|
||||
cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name
|
||||
|
||||
pipeline_class = getattr(diffusers, cls_name, None)
|
||||
|
||||
|
||||
@@ -862,6 +862,58 @@ class CustomPipelineTests(unittest.TestCase):
|
||||
# compare output to https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L102
|
||||
assert output_str == "This is a test"
|
||||
|
||||
def test_remote_components(self):
|
||||
# make sure that trust remote code has to be passed
|
||||
with self.assertRaises(ValueError):
|
||||
pipeline = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sdxl-custom-components")
|
||||
|
||||
# Check that only loading custom componets "my_unet", "my_scheduler" works
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"hf-internal-testing/tiny-sdxl-custom-components", trust_remote_code=True
|
||||
)
|
||||
|
||||
assert pipeline.config.unet == ("diffusers_modules.local.my_unet_model", "MyUNetModel")
|
||||
assert pipeline.config.scheduler == ("diffusers_modules.local.my_scheduler", "MyScheduler")
|
||||
assert pipeline.__class__.__name__ == "StableDiffusionXLPipeline"
|
||||
|
||||
pipeline = pipeline.to(torch_device)
|
||||
images = pipeline("test", num_inference_steps=2, output_type="np")[0]
|
||||
|
||||
assert images.shape == (1, 64, 64, 3)
|
||||
|
||||
# Check that only loading custom componets "my_unet", "my_scheduler" and explicit custom pipeline works
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"/home/patrick/tiny-stable-diffusion-xl-pipe", custom_pipeline="my_pipeline"
|
||||
)
|
||||
|
||||
assert pipeline.config.unet == ("diffusers_modules.local.my_unet_model", "MyUNetModel")
|
||||
assert pipeline.config.scheduler == ("diffusers_modules.local.my_scheduler", "MyScheduler")
|
||||
assert pipeline.__class__.__name__ == "MyPipeline"
|
||||
|
||||
pipeline = pipeline.to(torch_device)
|
||||
images = pipeline("test", num_inference_steps=2, output_type="np")[0]
|
||||
|
||||
assert images.shape == (1, 64, 64, 3)
|
||||
|
||||
def test_remote_auto_custom_pipe(self):
|
||||
# make sure that trust remote code has to be passed
|
||||
with self.assertRaises(ValueError):
|
||||
pipeline = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sdxl-custom-all")
|
||||
|
||||
# Check that only loading custom componets "my_unet", "my_scheduler" and auto custom pipeline works
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"hf-internal-testing/tiny-sdxl-custom-all", trust_remote_code=True
|
||||
)
|
||||
|
||||
assert pipeline.config.unet == ("diffusers_modules.local.my_unet_model", "MyUNetModel")
|
||||
assert pipeline.config.scheduler == ("diffusers_modules.local.my_scheduler", "MyScheduler")
|
||||
assert pipeline.__class__.__name__ == "MyPipeline"
|
||||
|
||||
pipeline = pipeline.to(torch_device)
|
||||
images = pipeline("test", num_inference_steps=2, output_type="np")[0]
|
||||
|
||||
assert images.shape == (1, 64, 64, 3)
|
||||
|
||||
def test_local_custom_pipeline_repo(self):
|
||||
local_custom_pipeline_path = get_tests_dir("fixtures/custom_pipeline")
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
|
||||
Reference in New Issue
Block a user