Compare commits

...

19 Commits

Author SHA1 Message Date
DN6
f35f83b9cd update 2025-03-04 06:26:15 +05:30
DN6
f56880506f update 2025-03-04 06:05:14 +05:30
Dhruv Nair
30628b482f Merge branch 'main' into variants-fetching-fix 2025-02-27 14:27:34 +05:30
Dhruv Nair
a29f742ddd Merge branch 'main' into variants-fetching-fix 2025-02-25 09:54:35 +05:30
DN6
02b089206b update 2025-02-21 12:33:42 +05:30
DN6
b79e720f52 Merge branch 'main' into variants-fetching-fix 2025-02-21 10:51:42 +05:30
DN6
3db5a69b9f update 2025-02-20 23:59:27 +05:30
DN6
6899f400d5 update 2025-02-20 23:43:19 +05:30
DN6
abba8e0ff8 update 2025-02-20 22:33:37 +05:30
DN6
420c78cb90 update 2025-02-20 21:38:27 +05:30
Dhruv Nair
ac4c23c154 update 2025-02-20 13:57:18 +01:00
Dhruv Nair
c40f60cd46 update 2025-02-20 11:10:48 +01:00
DN6
04d7dc3afa update 2025-01-29 21:54:30 +05:30
DN6
a4bdc970ca update 2025-01-29 19:08:18 +05:30
DN6
2089700d4b update 2025-01-29 11:37:30 +05:30
Dhruv Nair
9f9db3bfc8 update 2025-01-24 17:37:28 +01:00
Dhruv Nair
974f67e1e2 update 2025-01-24 17:22:13 +01:00
Dhruv Nair
9f0ae2f523 update 2025-01-24 15:15:36 +01:00
Dhruv Nair
403417e926 update 2025-01-24 10:31:26 +01:00
3 changed files with 380 additions and 132 deletions

View File

@@ -104,7 +104,7 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
extension is replaced with ".safetensors"
"""
passed_components = passed_components or []
if folder_names is not None:
if folder_names:
filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}
# extract all components of the pipeline and their associated files
@@ -141,7 +141,25 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
return True
def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
def filter_model_files(filenames):
"""Filter model repo files for just files/folders that contain model weights"""
weight_names = [
WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
FLAX_WEIGHTS_NAME,
ONNX_WEIGHTS_NAME,
ONNX_EXTERNAL_WEIGHTS_NAME,
]
if is_transformers_available():
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
allowed_extensions = [wn.split(".")[-1] for wn in weight_names]
return [f for f in filenames if any(f.endswith(extension) for extension in allowed_extensions)]
def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]:
weight_names = [
WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
@@ -169,6 +187,10 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
variant_index_re = re.compile(
rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
)
legacy_variant_file_re = re.compile(rf".*-{transformers_index_format}\.{variant}\.[a-z]+$")
legacy_variant_index_re = re.compile(
rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.{variant}\.index\.json$"
)
# `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
non_variant_file_re = re.compile(
@@ -177,54 +199,68 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
# `text_encoder/pytorch_model.bin.index.json`
non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")
if variant is not None:
variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None}
variant_filenames = variant_weights | variant_indexes
else:
variant_filenames = set()
def filter_for_compatible_extensions(filenames, ignore_patterns=None):
if not ignore_patterns:
return filenames
non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None}
non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None}
non_variant_filenames = non_variant_weights | non_variant_indexes
# ignore patterns uses glob style patterns e.g *.safetensors but we're only
# interested in the extension name
return {f for f in filenames if not any(f.endswith(pat.lstrip("*.")) for pat in ignore_patterns)}
# all variant filenames will be used by default
usable_filenames = set(variant_filenames)
def filter_with_regex(filenames, pattern_re):
return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None}
def convert_to_variant(filename):
if "index" in filename:
variant_filename = filename.replace("index", f"index.{variant}")
elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
else:
variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
return variant_filename
def find_component(filename):
# Group files by component
components = {}
for filename in filenames:
if not len(filename.split("/")) == 2:
return
component = filename.split("/")[0]
return component
components.setdefault("", []).append(filename)
continue
def has_sharded_variant(component, variant, variant_filenames):
# If component exists check for sharded variant index filename
# If component doesn't exist check main dir for sharded variant index filename
component = component + "/" if component else ""
variant_index_re = re.compile(
rf"{component}({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
component, _ = filename.split("/")
components.setdefault(component, []).append(filename)
usable_filenames = set()
variant_filenames = set()
for component, component_filenames in components.items():
component_filenames = filter_for_compatible_extensions(component_filenames, ignore_patterns=ignore_patterns)
component_variants = set()
component_legacy_variants = set()
component_non_variants = set()
if variant is not None:
component_variants = filter_with_regex(component_filenames, variant_file_re)
component_variant_index_files = filter_with_regex(component_filenames, variant_index_re)
component_legacy_variants = filter_with_regex(component_filenames, legacy_variant_file_re)
component_legacy_variant_index_files = filter_with_regex(component_filenames, legacy_variant_index_re)
if component_variants or component_legacy_variants:
variant_filenames.update(
component_variants | component_variant_index_files
if component_variants
else component_legacy_variants | component_legacy_variant_index_files
)
else:
component_non_variants = filter_with_regex(component_filenames, non_variant_file_re)
component_variant_index_files = filter_with_regex(component_filenames, non_variant_index_re)
usable_filenames.update(component_non_variants | component_variant_index_files)
usable_filenames.update(variant_filenames)
if len(variant_filenames) == 0 and variant is not None:
error_message = f"You are trying to load model files of the `variant={variant}`, but no such modeling files are available. "
raise ValueError(error_message)
if len(variant_filenames) > 0 and usable_filenames != variant_filenames:
logger.warning(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
f"[{', '.join(variant_filenames)}]\nLoaded non-{variant} filenames:\n"
f"[{', '.join(usable_filenames - variant_filenames)}\nIf this behavior is not "
f"expected, please check your folder structure."
)
return any(f for f in variant_filenames if variant_index_re.match(f) is not None)
for filename in non_variant_filenames:
if convert_to_variant(filename) in variant_filenames:
continue
component = find_component(filename)
# If a sharded variant exists skip adding to allowed patterns
if has_sharded_variant(component, variant, variant_filenames):
continue
usable_filenames.add(filename)
return usable_filenames, variant_filenames
@@ -922,10 +958,6 @@ def _get_custom_components_and_folders(
f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
)
if len(variant_filenames) == 0 and variant is not None:
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
raise ValueError(error_message)
return custom_components, folder_names
@@ -933,7 +965,6 @@ def _get_ignore_patterns(
passed_components,
model_folder_names: List[str],
model_filenames: List[str],
variant_filenames: List[str],
use_safetensors: bool,
from_flax: bool,
allow_pickle: bool,
@@ -964,16 +995,6 @@ def _get_ignore_patterns(
if not use_onnx:
ignore_patterns += ["*.onnx", "*.pb"]
safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
if len(safetensors_variant_filenames) > 0 and safetensors_model_filenames != safetensors_variant_filenames:
logger.warning(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
f"[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
f"[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not "
f"expected, please check your folder structure."
)
else:
ignore_patterns = ["*.safetensors", "*.msgpack"]
@@ -981,16 +1002,6 @@ def _get_ignore_patterns(
if not use_onnx:
ignore_patterns += ["*.onnx", "*.pb"]
bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
logger.warning(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
f"[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
f"[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check "
f"your folder structure."
)
return ignore_patterns

View File

@@ -89,6 +89,7 @@ from .pipeline_loading_utils import (
_resolve_custom_pipeline_and_cls,
_unwrap_model,
_update_init_kwargs_with_connected_pipeline,
filter_model_files,
load_sub_model,
maybe_raise_or_warn,
variant_compatible_siblings,
@@ -1387,10 +1388,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
revision=revision,
)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
allow_pickle = True if (use_safetensors is None or use_safetensors is False) else False
use_safetensors = use_safetensors if use_safetensors is not None else True
allow_patterns = None
ignore_patterns = None
@@ -1405,6 +1404,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
model_info_call_error = e # save error to reraise it if model is not cached locally
if not local_files_only:
config_file = hf_hub_download(
pretrained_model_name,
cls.config_name,
cache_dir=cache_dir,
revision=revision,
proxies=proxies,
force_download=force_download,
token=token,
)
config_dict = cls._dict_from_json_file(config_file)
ignore_filenames = config_dict.pop("_ignore_files", [])
filenames = {sibling.rfilename for sibling in info.siblings}
if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant):
warn_msg = (
@@ -1419,61 +1430,20 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
)
logger.warning(warn_msg)
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
config_file = hf_hub_download(
pretrained_model_name,
cls.config_name,
cache_dir=cache_dir,
revision=revision,
proxies=proxies,
force_download=force_download,
token=token,
)
config_dict = cls._dict_from_json_file(config_file)
ignore_filenames = config_dict.pop("_ignore_files", [])
# remove ignored filenames
model_filenames = set(model_filenames) - set(ignore_filenames)
variant_filenames = set(variant_filenames) - set(ignore_filenames)
filenames = set(filenames) - set(ignore_filenames)
if revision in DEPRECATED_REVISION_ARGS and version.parse(
version.parse(__version__).base_version
) >= version.parse("0.22.0"):
warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames)
warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, filenames)
custom_components, folder_names = _get_custom_components_and_folders(
pretrained_model_name, config_dict, filenames, variant_filenames, variant
pretrained_model_name, config_dict, filenames, variant
)
model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
custom_class_name = None
if custom_pipeline is None and isinstance(config_dict["_class_name"], (list, tuple)):
custom_pipeline = config_dict["_class_name"][0]
custom_class_name = config_dict["_class_name"][1]
# all filenames compatible with variant will be added
allow_patterns = list(model_filenames)
# allow all patterns from non-model folders
# this enables downloading schedulers, tokenizers, ...
allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
# add custom component files
allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()]
# add custom pipeline file
allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
# also allow downloading config.json files with the model
allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
# also allow downloading generation_config.json of the transformers model
allow_patterns += [os.path.join(k, "generation_config.json") for k in model_folder_names]
allow_patterns += [
SCHEDULER_CONFIG_NAME,
CONFIG_NAME,
cls.config_name,
CUSTOM_PIPELINE_FILE_NAME,
]
load_pipe_from_hub = custom_pipeline is not None and f"{custom_pipeline}.py" in filenames
load_components_from_hub = len(custom_components) > 0
@@ -1506,12 +1476,15 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
expected_components, _ = cls._get_signature_keys(pipeline_class)
passed_components = [k for k in expected_components if k in kwargs]
# retrieve the names of the folders containing model weights
model_folder_names = {
os.path.split(f)[0] for f in filter_model_files(filenames) if os.path.split(f)[0] in folder_names
}
# retrieve all patterns that should not be downloaded and error out when needed
ignore_patterns = _get_ignore_patterns(
passed_components,
model_folder_names,
model_filenames,
variant_filenames,
filenames,
use_safetensors,
from_flax,
allow_pickle,
@@ -1520,6 +1493,29 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
variant,
)
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
# all filenames compatible with variant will be added
allow_patterns = list(model_filenames)
# allow all patterns from non-model folders
# this enables downloading schedulers, tokenizers, ...
allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
# add custom component files
allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()]
# add custom pipeline file
allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
# also allow downloading config.json files with the model
allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
allow_patterns += [
SCHEDULER_CONFIG_NAME,
CONFIG_NAME,
cls.config_name,
CUSTOM_PIPELINE_FILE_NAME,
]
# Don't download any objects that are passed
allow_patterns = [
p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components)

View File

@@ -212,6 +212,7 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
class VariantCompatibleSiblingsTest(unittest.TestCase):
def test_only_non_variants_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
f"vae/diffusion_pytorch_model.{variant}.safetensors",
@@ -222,10 +223,13 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
"unet/diffusion_pytorch_model.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=None, ignore_patterns=ignore_patterns
)
assert all(variant not in f for f in model_filenames)
def test_only_variants_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
f"vae/diffusion_pytorch_model.{variant}.safetensors",
@@ -236,10 +240,13 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
"unet/diffusion_pytorch_model.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f for f in model_filenames)
def test_mixed_variants_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
non_variant_file = "text_encoder/model.safetensors"
filenames = [
@@ -249,23 +256,27 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
f"unet/diffusion_pytorch_model.{variant}.safetensors",
"unet/diffusion_pytorch_model.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)
def test_non_variants_in_main_dir_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
f"diffusion_pytorch_model.{variant}.safetensors",
"diffusion_pytorch_model.safetensors",
"model.safetensors",
f"model.{variant}.safetensors",
f"diffusion_pytorch_model.{variant}.safetensors",
"diffusion_pytorch_model.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=None, ignore_patterns=ignore_patterns
)
assert all(variant not in f for f in model_filenames)
def test_variants_in_main_dir_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
f"diffusion_pytorch_model.{variant}.safetensors",
@@ -275,23 +286,76 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
f"diffusion_pytorch_model.{variant}.safetensors",
"diffusion_pytorch_model.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f for f in model_filenames)
def test_mixed_variants_in_main_dir_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
non_variant_file = "model.safetensors"
filenames = [
f"diffusion_pytorch_model.{variant}.safetensors",
"diffusion_pytorch_model.safetensors",
"model.safetensors",
f"diffusion_pytorch_model.{variant}.safetensors",
"diffusion_pytorch_model.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)
def test_sharded_variants_in_main_dir_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
"diffusion_pytorch_model.safetensors.index.json",
"diffusion_pytorch_model-00001-of-00003.safetensors",
"diffusion_pytorch_model-00002-of-00003.safetensors",
"diffusion_pytorch_model-00003-of-00003.safetensors",
f"diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
f"diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
f"diffusion_pytorch_model.safetensors.index.{variant}.json",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f for f in model_filenames)
def test_mixed_sharded_and_variant_in_main_dir_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
"diffusion_pytorch_model.safetensors.index.json",
"diffusion_pytorch_model-00001-of-00003.safetensors",
"diffusion_pytorch_model-00002-of-00003.safetensors",
"diffusion_pytorch_model-00003-of-00003.safetensors",
f"diffusion_pytorch_model.{variant}.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f for f in model_filenames)
def test_mixed_sharded_non_variants_in_main_dir_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
f"diffusion_pytorch_model.safetensors.index.{variant}.json",
"diffusion_pytorch_model.safetensors.index.json",
"diffusion_pytorch_model-00001-of-00003.safetensors",
"diffusion_pytorch_model-00002-of-00003.safetensors",
"diffusion_pytorch_model-00003-of-00003.safetensors",
f"diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
f"diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=None, ignore_patterns=ignore_patterns
)
assert all(variant not in f for f in model_filenames)
def test_sharded_non_variants_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",
@@ -302,10 +366,13 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=None, ignore_patterns=ignore_patterns
)
assert all(variant not in f for f in model_filenames)
def test_sharded_variants_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",
@@ -316,10 +383,49 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f for f in model_filenames)
assert model_filenames == variant_filenames
def test_single_variant_with_sharded_non_variant_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
"unet/diffusion_pytorch_model.safetensors.index.json",
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
f"unet/diffusion_pytorch_model.{variant}.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f for f in model_filenames)
def test_mixed_single_variant_with_sharded_non_variant_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
allowed_non_variant = "unet"
filenames = [
"vae/diffusion_pytorch_model.safetensors.index.json",
"vae/diffusion_pytorch_model-00001-of-00003.safetensors",
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
f"vae/diffusion_pytorch_model.{variant}.safetensors",
"unet/diffusion_pytorch_model.safetensors.index.json",
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
def test_sharded_mixed_variants_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
allowed_non_variant = "unet"
filenames = [
@@ -335,9 +441,144 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
def test_downloading_when_no_variant_exists(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = ["model.safetensors", "diffusion_pytorch_model.safetensors"]
with self.assertRaisesRegex(ValueError, "but no such modeling files are available. "):
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
def test_downloading_use_safetensors_false(self):
ignore_patterns = ["*.safetensors"]
filenames = [
"text_encoder/model.bin",
"unet/diffusion_pytorch_model.bin",
"unet/diffusion_pytorch_model.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=None, ignore_patterns=ignore_patterns
)
assert all(".safetensors" not in f for f in model_filenames)
def test_non_variant_in_main_dir_with_variant_in_subfolder(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
allowed_non_variant = "diffusion_pytorch_model.safetensors"
filenames = [
f"unet/diffusion_pytorch_model.{variant}.safetensors",
"diffusion_pytorch_model.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
def test_download_variants_when_component_has_no_safetensors_variant(self):
ignore_patterns = None
variant = "fp16"
filenames = [
f"unet/diffusion_pytorch_model.{variant}.bin",
"vae/diffusion_pytorch_model.safetensors",
f"vae/diffusion_pytorch_model.{variant}.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert {
f"unet/diffusion_pytorch_model.{variant}.bin",
f"vae/diffusion_pytorch_model.{variant}.safetensors",
} == model_filenames
def test_error_when_download_sharded_variants_when_component_has_no_safetensors_variant(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
f"vae/diffusion_pytorch_model.bin.index.{variant}.json",
"vae/diffusion_pytorch_model.safetensors.index.json",
f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.bin",
"vae/diffusion_pytorch_model-00001-of-00003.safetensors",
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
"unet/diffusion_pytorch_model.safetensors.index.json",
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.bin",
]
with self.assertRaisesRegex(ValueError, "but no such modeling files are available. "):
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
def test_download_sharded_variants_when_component_has_no_safetensors_variant_and_safetensors_false(self):
ignore_patterns = ["*.safetensors"]
allowed_non_variant = "unet"
variant = "fp16"
filenames = [
f"vae/diffusion_pytorch_model.bin.index.{variant}.json",
"vae/diffusion_pytorch_model.safetensors.index.json",
f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.bin",
"vae/diffusion_pytorch_model-00001-of-00003.safetensors",
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
"unet/diffusion_pytorch_model.safetensors.index.json",
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.bin",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
def test_download_sharded_legacy_variants(self):
ignore_patterns = None
variant = "fp16"
filenames = [
f"vae/transformer/diffusion_pytorch_model.safetensors.{variant}.index.json",
"vae/diffusion_pytorch_model.safetensors.index.json",
f"vae/diffusion_pytorch_model-00002-of-00002.{variant}.safetensors",
"vae/diffusion_pytorch_model-00001-of-00003.safetensors",
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
f"vae/diffusion_pytorch_model-00001-of-00002.{variant}.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f for f in model_filenames)
def test_download_onnx_models(self):
ignore_patterns = ["*.safetensors"]
filenames = [
"vae/model.onnx",
"unet/model.onnx",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=None, ignore_patterns=ignore_patterns
)
assert model_filenames == set(filenames)
def test_download_flax_models(self):
ignore_patterns = ["*.safetensors", "*.bin"]
filenames = [
"vae/diffusion_flax_model.msgpack",
"unet/diffusion_flax_model.msgpack",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=None, ignore_patterns=ignore_patterns
)
assert model_filenames == set(filenames)
class ProgressBarTests(unittest.TestCase):
def get_dummy_components_image_generation(self):