mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 20:44:33 +08:00
Compare commits
19 Commits
torchao-in
...
variants-f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f35f83b9cd | ||
|
|
f56880506f | ||
|
|
30628b482f | ||
|
|
a29f742ddd | ||
|
|
02b089206b | ||
|
|
b79e720f52 | ||
|
|
3db5a69b9f | ||
|
|
6899f400d5 | ||
|
|
abba8e0ff8 | ||
|
|
420c78cb90 | ||
|
|
ac4c23c154 | ||
|
|
c40f60cd46 | ||
|
|
04d7dc3afa | ||
|
|
a4bdc970ca | ||
|
|
2089700d4b | ||
|
|
9f9db3bfc8 | ||
|
|
974f67e1e2 | ||
|
|
9f0ae2f523 | ||
|
|
403417e926 |
@@ -104,7 +104,7 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
|
||||
extension is replaced with ".safetensors"
|
||||
"""
|
||||
passed_components = passed_components or []
|
||||
if folder_names is not None:
|
||||
if folder_names:
|
||||
filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}
|
||||
|
||||
# extract all components of the pipeline and their associated files
|
||||
@@ -141,7 +141,25 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
|
||||
return True
|
||||
|
||||
|
||||
def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
|
||||
def filter_model_files(filenames):
|
||||
"""Filter model repo files for just files/folders that contain model weights"""
|
||||
weight_names = [
|
||||
WEIGHTS_NAME,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
ONNX_WEIGHTS_NAME,
|
||||
ONNX_EXTERNAL_WEIGHTS_NAME,
|
||||
]
|
||||
|
||||
if is_transformers_available():
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
|
||||
|
||||
allowed_extensions = [wn.split(".")[-1] for wn in weight_names]
|
||||
|
||||
return [f for f in filenames if any(f.endswith(extension) for extension in allowed_extensions)]
|
||||
|
||||
|
||||
def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]:
|
||||
weight_names = [
|
||||
WEIGHTS_NAME,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
@@ -169,6 +187,10 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
|
||||
variant_index_re = re.compile(
|
||||
rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
|
||||
)
|
||||
legacy_variant_file_re = re.compile(rf".*-{transformers_index_format}\.{variant}\.[a-z]+$")
|
||||
legacy_variant_index_re = re.compile(
|
||||
rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.{variant}\.index\.json$"
|
||||
)
|
||||
|
||||
# `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
|
||||
non_variant_file_re = re.compile(
|
||||
@@ -177,54 +199,68 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
|
||||
# `text_encoder/pytorch_model.bin.index.json`
|
||||
non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")
|
||||
|
||||
if variant is not None:
|
||||
variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
|
||||
variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None}
|
||||
variant_filenames = variant_weights | variant_indexes
|
||||
else:
|
||||
variant_filenames = set()
|
||||
def filter_for_compatible_extensions(filenames, ignore_patterns=None):
|
||||
if not ignore_patterns:
|
||||
return filenames
|
||||
|
||||
non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None}
|
||||
non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None}
|
||||
non_variant_filenames = non_variant_weights | non_variant_indexes
|
||||
# ignore patterns uses glob style patterns e.g *.safetensors but we're only
|
||||
# interested in the extension name
|
||||
return {f for f in filenames if not any(f.endswith(pat.lstrip("*.")) for pat in ignore_patterns)}
|
||||
|
||||
# all variant filenames will be used by default
|
||||
usable_filenames = set(variant_filenames)
|
||||
def filter_with_regex(filenames, pattern_re):
|
||||
return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None}
|
||||
|
||||
def convert_to_variant(filename):
|
||||
if "index" in filename:
|
||||
variant_filename = filename.replace("index", f"index.{variant}")
|
||||
elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
|
||||
variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
|
||||
else:
|
||||
variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
|
||||
return variant_filename
|
||||
|
||||
def find_component(filename):
|
||||
# Group files by component
|
||||
components = {}
|
||||
for filename in filenames:
|
||||
if not len(filename.split("/")) == 2:
|
||||
return
|
||||
component = filename.split("/")[0]
|
||||
return component
|
||||
components.setdefault("", []).append(filename)
|
||||
continue
|
||||
|
||||
def has_sharded_variant(component, variant, variant_filenames):
|
||||
# If component exists check for sharded variant index filename
|
||||
# If component doesn't exist check main dir for sharded variant index filename
|
||||
component = component + "/" if component else ""
|
||||
variant_index_re = re.compile(
|
||||
rf"{component}({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
|
||||
component, _ = filename.split("/")
|
||||
components.setdefault(component, []).append(filename)
|
||||
|
||||
usable_filenames = set()
|
||||
variant_filenames = set()
|
||||
for component, component_filenames in components.items():
|
||||
component_filenames = filter_for_compatible_extensions(component_filenames, ignore_patterns=ignore_patterns)
|
||||
|
||||
component_variants = set()
|
||||
component_legacy_variants = set()
|
||||
component_non_variants = set()
|
||||
if variant is not None:
|
||||
component_variants = filter_with_regex(component_filenames, variant_file_re)
|
||||
component_variant_index_files = filter_with_regex(component_filenames, variant_index_re)
|
||||
|
||||
component_legacy_variants = filter_with_regex(component_filenames, legacy_variant_file_re)
|
||||
component_legacy_variant_index_files = filter_with_regex(component_filenames, legacy_variant_index_re)
|
||||
|
||||
if component_variants or component_legacy_variants:
|
||||
variant_filenames.update(
|
||||
component_variants | component_variant_index_files
|
||||
if component_variants
|
||||
else component_legacy_variants | component_legacy_variant_index_files
|
||||
)
|
||||
|
||||
else:
|
||||
component_non_variants = filter_with_regex(component_filenames, non_variant_file_re)
|
||||
component_variant_index_files = filter_with_regex(component_filenames, non_variant_index_re)
|
||||
|
||||
usable_filenames.update(component_non_variants | component_variant_index_files)
|
||||
|
||||
usable_filenames.update(variant_filenames)
|
||||
|
||||
if len(variant_filenames) == 0 and variant is not None:
|
||||
error_message = f"You are trying to load model files of the `variant={variant}`, but no such modeling files are available. "
|
||||
raise ValueError(error_message)
|
||||
|
||||
if len(variant_filenames) > 0 and usable_filenames != variant_filenames:
|
||||
logger.warning(
|
||||
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
|
||||
f"[{', '.join(variant_filenames)}]\nLoaded non-{variant} filenames:\n"
|
||||
f"[{', '.join(usable_filenames - variant_filenames)}\nIf this behavior is not "
|
||||
f"expected, please check your folder structure."
|
||||
)
|
||||
return any(f for f in variant_filenames if variant_index_re.match(f) is not None)
|
||||
|
||||
for filename in non_variant_filenames:
|
||||
if convert_to_variant(filename) in variant_filenames:
|
||||
continue
|
||||
|
||||
component = find_component(filename)
|
||||
# If a sharded variant exists skip adding to allowed patterns
|
||||
if has_sharded_variant(component, variant, variant_filenames):
|
||||
continue
|
||||
|
||||
usable_filenames.add(filename)
|
||||
|
||||
return usable_filenames, variant_filenames
|
||||
|
||||
@@ -922,10 +958,6 @@ def _get_custom_components_and_folders(
|
||||
f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
|
||||
)
|
||||
|
||||
if len(variant_filenames) == 0 and variant is not None:
|
||||
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
|
||||
raise ValueError(error_message)
|
||||
|
||||
return custom_components, folder_names
|
||||
|
||||
|
||||
@@ -933,7 +965,6 @@ def _get_ignore_patterns(
|
||||
passed_components,
|
||||
model_folder_names: List[str],
|
||||
model_filenames: List[str],
|
||||
variant_filenames: List[str],
|
||||
use_safetensors: bool,
|
||||
from_flax: bool,
|
||||
allow_pickle: bool,
|
||||
@@ -964,16 +995,6 @@ def _get_ignore_patterns(
|
||||
if not use_onnx:
|
||||
ignore_patterns += ["*.onnx", "*.pb"]
|
||||
|
||||
safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
|
||||
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
|
||||
if len(safetensors_variant_filenames) > 0 and safetensors_model_filenames != safetensors_variant_filenames:
|
||||
logger.warning(
|
||||
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
|
||||
f"[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
|
||||
f"[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not "
|
||||
f"expected, please check your folder structure."
|
||||
)
|
||||
|
||||
else:
|
||||
ignore_patterns = ["*.safetensors", "*.msgpack"]
|
||||
|
||||
@@ -981,16 +1002,6 @@ def _get_ignore_patterns(
|
||||
if not use_onnx:
|
||||
ignore_patterns += ["*.onnx", "*.pb"]
|
||||
|
||||
bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
|
||||
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
|
||||
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
|
||||
logger.warning(
|
||||
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
|
||||
f"[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
|
||||
f"[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check "
|
||||
f"your folder structure."
|
||||
)
|
||||
|
||||
return ignore_patterns
|
||||
|
||||
|
||||
|
||||
@@ -89,6 +89,7 @@ from .pipeline_loading_utils import (
|
||||
_resolve_custom_pipeline_and_cls,
|
||||
_unwrap_model,
|
||||
_update_init_kwargs_with_connected_pipeline,
|
||||
filter_model_files,
|
||||
load_sub_model,
|
||||
maybe_raise_or_warn,
|
||||
variant_compatible_siblings,
|
||||
@@ -1387,10 +1388,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = True
|
||||
allow_pickle = True
|
||||
allow_pickle = True if (use_safetensors is None or use_safetensors is False) else False
|
||||
use_safetensors = use_safetensors if use_safetensors is not None else True
|
||||
|
||||
allow_patterns = None
|
||||
ignore_patterns = None
|
||||
@@ -1405,6 +1404,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
model_info_call_error = e # save error to reraise it if model is not cached locally
|
||||
|
||||
if not local_files_only:
|
||||
config_file = hf_hub_download(
|
||||
pretrained_model_name,
|
||||
cls.config_name,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
proxies=proxies,
|
||||
force_download=force_download,
|
||||
token=token,
|
||||
)
|
||||
config_dict = cls._dict_from_json_file(config_file)
|
||||
ignore_filenames = config_dict.pop("_ignore_files", [])
|
||||
|
||||
filenames = {sibling.rfilename for sibling in info.siblings}
|
||||
if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant):
|
||||
warn_msg = (
|
||||
@@ -1419,61 +1430,20 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
)
|
||||
logger.warning(warn_msg)
|
||||
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
|
||||
config_file = hf_hub_download(
|
||||
pretrained_model_name,
|
||||
cls.config_name,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
proxies=proxies,
|
||||
force_download=force_download,
|
||||
token=token,
|
||||
)
|
||||
|
||||
config_dict = cls._dict_from_json_file(config_file)
|
||||
ignore_filenames = config_dict.pop("_ignore_files", [])
|
||||
|
||||
# remove ignored filenames
|
||||
model_filenames = set(model_filenames) - set(ignore_filenames)
|
||||
variant_filenames = set(variant_filenames) - set(ignore_filenames)
|
||||
|
||||
filenames = set(filenames) - set(ignore_filenames)
|
||||
if revision in DEPRECATED_REVISION_ARGS and version.parse(
|
||||
version.parse(__version__).base_version
|
||||
) >= version.parse("0.22.0"):
|
||||
warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames)
|
||||
warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, filenames)
|
||||
|
||||
custom_components, folder_names = _get_custom_components_and_folders(
|
||||
pretrained_model_name, config_dict, filenames, variant_filenames, variant
|
||||
pretrained_model_name, config_dict, filenames, variant
|
||||
)
|
||||
model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
|
||||
|
||||
custom_class_name = None
|
||||
if custom_pipeline is None and isinstance(config_dict["_class_name"], (list, tuple)):
|
||||
custom_pipeline = config_dict["_class_name"][0]
|
||||
custom_class_name = config_dict["_class_name"][1]
|
||||
|
||||
# all filenames compatible with variant will be added
|
||||
allow_patterns = list(model_filenames)
|
||||
|
||||
# allow all patterns from non-model folders
|
||||
# this enables downloading schedulers, tokenizers, ...
|
||||
allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
|
||||
# add custom component files
|
||||
allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()]
|
||||
# add custom pipeline file
|
||||
allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
|
||||
# also allow downloading config.json files with the model
|
||||
allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
|
||||
# also allow downloading generation_config.json of the transformers model
|
||||
allow_patterns += [os.path.join(k, "generation_config.json") for k in model_folder_names]
|
||||
allow_patterns += [
|
||||
SCHEDULER_CONFIG_NAME,
|
||||
CONFIG_NAME,
|
||||
cls.config_name,
|
||||
CUSTOM_PIPELINE_FILE_NAME,
|
||||
]
|
||||
|
||||
load_pipe_from_hub = custom_pipeline is not None and f"{custom_pipeline}.py" in filenames
|
||||
load_components_from_hub = len(custom_components) > 0
|
||||
|
||||
@@ -1506,12 +1476,15 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
expected_components, _ = cls._get_signature_keys(pipeline_class)
|
||||
passed_components = [k for k in expected_components if k in kwargs]
|
||||
|
||||
# retrieve the names of the folders containing model weights
|
||||
model_folder_names = {
|
||||
os.path.split(f)[0] for f in filter_model_files(filenames) if os.path.split(f)[0] in folder_names
|
||||
}
|
||||
# retrieve all patterns that should not be downloaded and error out when needed
|
||||
ignore_patterns = _get_ignore_patterns(
|
||||
passed_components,
|
||||
model_folder_names,
|
||||
model_filenames,
|
||||
variant_filenames,
|
||||
filenames,
|
||||
use_safetensors,
|
||||
from_flax,
|
||||
allow_pickle,
|
||||
@@ -1520,6 +1493,29 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
variant,
|
||||
)
|
||||
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
|
||||
# all filenames compatible with variant will be added
|
||||
allow_patterns = list(model_filenames)
|
||||
|
||||
# allow all patterns from non-model folders
|
||||
# this enables downloading schedulers, tokenizers, ...
|
||||
allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
|
||||
# add custom component files
|
||||
allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()]
|
||||
# add custom pipeline file
|
||||
allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
|
||||
# also allow downloading config.json files with the model
|
||||
allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
|
||||
allow_patterns += [
|
||||
SCHEDULER_CONFIG_NAME,
|
||||
CONFIG_NAME,
|
||||
cls.config_name,
|
||||
CUSTOM_PIPELINE_FILE_NAME,
|
||||
]
|
||||
|
||||
# Don't download any objects that are passed
|
||||
allow_patterns = [
|
||||
p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components)
|
||||
|
||||
@@ -212,6 +212,7 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
|
||||
|
||||
class VariantCompatibleSiblingsTest(unittest.TestCase):
|
||||
def test_only_non_variants_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"vae/diffusion_pytorch_model.{variant}.safetensors",
|
||||
@@ -222,10 +223,13 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
|
||||
"unet/diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=None, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant not in f for f in model_filenames)
|
||||
|
||||
def test_only_variants_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"vae/diffusion_pytorch_model.{variant}.safetensors",
|
||||
@@ -236,10 +240,13 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
|
||||
"unet/diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant in f for f in model_filenames)
|
||||
|
||||
def test_mixed_variants_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
non_variant_file = "text_encoder/model.safetensors"
|
||||
filenames = [
|
||||
@@ -249,23 +256,27 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
|
||||
f"unet/diffusion_pytorch_model.{variant}.safetensors",
|
||||
"unet/diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)
|
||||
|
||||
def test_non_variants_in_main_dir_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
"model.safetensors",
|
||||
f"model.{variant}.safetensors",
|
||||
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=None, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant not in f for f in model_filenames)
|
||||
|
||||
def test_variants_in_main_dir_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||
@@ -275,23 +286,76 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
|
||||
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant in f for f in model_filenames)
|
||||
|
||||
def test_mixed_variants_in_main_dir_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
non_variant_file = "model.safetensors"
|
||||
filenames = [
|
||||
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
"model.safetensors",
|
||||
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)
|
||||
|
||||
def test_sharded_variants_in_main_dir_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
"diffusion_pytorch_model.safetensors.index.json",
|
||||
"diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
f"diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
|
||||
f"diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
|
||||
f"diffusion_pytorch_model.safetensors.index.{variant}.json",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant in f for f in model_filenames)
|
||||
|
||||
def test_mixed_sharded_and_variant_in_main_dir_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
"diffusion_pytorch_model.safetensors.index.json",
|
||||
"diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant in f for f in model_filenames)
|
||||
|
||||
def test_mixed_sharded_non_variants_in_main_dir_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"diffusion_pytorch_model.safetensors.index.{variant}.json",
|
||||
"diffusion_pytorch_model.safetensors.index.json",
|
||||
"diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
f"diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
|
||||
f"diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=None, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant not in f for f in model_filenames)
|
||||
|
||||
def test_sharded_non_variants_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",
|
||||
@@ -302,10 +366,13 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
|
||||
f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
|
||||
f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=None, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant not in f for f in model_filenames)
|
||||
|
||||
def test_sharded_variants_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",
|
||||
@@ -316,10 +383,49 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
|
||||
f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
|
||||
f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant in f for f in model_filenames)
|
||||
assert model_filenames == variant_filenames
|
||||
|
||||
def test_single_variant_with_sharded_non_variant_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
"unet/diffusion_pytorch_model.safetensors.index.json",
|
||||
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
f"unet/diffusion_pytorch_model.{variant}.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant in f for f in model_filenames)
|
||||
|
||||
def test_mixed_single_variant_with_sharded_non_variant_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
allowed_non_variant = "unet"
|
||||
filenames = [
|
||||
"vae/diffusion_pytorch_model.safetensors.index.json",
|
||||
"vae/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
f"vae/diffusion_pytorch_model.{variant}.safetensors",
|
||||
"unet/diffusion_pytorch_model.safetensors.index.json",
|
||||
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
|
||||
|
||||
def test_sharded_mixed_variants_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
allowed_non_variant = "unet"
|
||||
filenames = [
|
||||
@@ -335,9 +441,144 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
|
||||
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
|
||||
|
||||
def test_downloading_when_no_variant_exists(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = ["model.safetensors", "diffusion_pytorch_model.safetensors"]
|
||||
with self.assertRaisesRegex(ValueError, "but no such modeling files are available. "):
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
|
||||
def test_downloading_use_safetensors_false(self):
|
||||
ignore_patterns = ["*.safetensors"]
|
||||
filenames = [
|
||||
"text_encoder/model.bin",
|
||||
"unet/diffusion_pytorch_model.bin",
|
||||
"unet/diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=None, ignore_patterns=ignore_patterns
|
||||
)
|
||||
|
||||
assert all(".safetensors" not in f for f in model_filenames)
|
||||
|
||||
def test_non_variant_in_main_dir_with_variant_in_subfolder(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
allowed_non_variant = "diffusion_pytorch_model.safetensors"
|
||||
filenames = [
|
||||
f"unet/diffusion_pytorch_model.{variant}.safetensors",
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
|
||||
|
||||
def test_download_variants_when_component_has_no_safetensors_variant(self):
|
||||
ignore_patterns = None
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"unet/diffusion_pytorch_model.{variant}.bin",
|
||||
"vae/diffusion_pytorch_model.safetensors",
|
||||
f"vae/diffusion_pytorch_model.{variant}.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert {
|
||||
f"unet/diffusion_pytorch_model.{variant}.bin",
|
||||
f"vae/diffusion_pytorch_model.{variant}.safetensors",
|
||||
} == model_filenames
|
||||
|
||||
def test_error_when_download_sharded_variants_when_component_has_no_safetensors_variant(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"vae/diffusion_pytorch_model.bin.index.{variant}.json",
|
||||
"vae/diffusion_pytorch_model.safetensors.index.json",
|
||||
f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.bin",
|
||||
"vae/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model.safetensors.index.json",
|
||||
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.bin",
|
||||
]
|
||||
with self.assertRaisesRegex(ValueError, "but no such modeling files are available. "):
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
|
||||
def test_download_sharded_variants_when_component_has_no_safetensors_variant_and_safetensors_false(self):
|
||||
ignore_patterns = ["*.safetensors"]
|
||||
allowed_non_variant = "unet"
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"vae/diffusion_pytorch_model.bin.index.{variant}.json",
|
||||
"vae/diffusion_pytorch_model.safetensors.index.json",
|
||||
f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.bin",
|
||||
"vae/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model.safetensors.index.json",
|
||||
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.bin",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
|
||||
|
||||
def test_download_sharded_legacy_variants(self):
|
||||
ignore_patterns = None
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"vae/transformer/diffusion_pytorch_model.safetensors.{variant}.index.json",
|
||||
"vae/diffusion_pytorch_model.safetensors.index.json",
|
||||
f"vae/diffusion_pytorch_model-00002-of-00002.{variant}.safetensors",
|
||||
"vae/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
f"vae/diffusion_pytorch_model-00001-of-00002.{variant}.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant in f for f in model_filenames)
|
||||
|
||||
def test_download_onnx_models(self):
|
||||
ignore_patterns = ["*.safetensors"]
|
||||
filenames = [
|
||||
"vae/model.onnx",
|
||||
"unet/model.onnx",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=None, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert model_filenames == set(filenames)
|
||||
|
||||
def test_download_flax_models(self):
|
||||
ignore_patterns = ["*.safetensors", "*.bin"]
|
||||
filenames = [
|
||||
"vae/diffusion_flax_model.msgpack",
|
||||
"unet/diffusion_flax_model.msgpack",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=None, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert model_filenames == set(filenames)
|
||||
|
||||
|
||||
class ProgressBarTests(unittest.TestCase):
|
||||
def get_dummy_components_image_generation(self):
|
||||
|
||||
Reference in New Issue
Block a user