Compare commits

..

4 Commits

Author SHA1 Message Date
Sayak Paul
936e6192b1 Merge branch 'main' into improve-security-stylebot 2025-02-27 09:09:52 +05:30
sayakpaul
6a478569f0 2025-02-26 15:34:34 +05:30
Sayak Paul
719d0ce7a7 Merge branch 'main' into improve-security-stylebot 2025-02-26 15:32:22 +05:30
sayakpaul
be16b1bcdf improve security for the stylebot. 2025-02-26 15:30:08 +05:30
8 changed files with 169 additions and 408 deletions

View File

@@ -64,18 +64,38 @@ jobs:
run: |
pip install .[quality]
- name: Download Makefile from main branch
- name: Download necessary files from main branch of Diffusers
run: |
curl -o main_Makefile https://raw.githubusercontent.com/huggingface/diffusers/main/Makefile
curl -o main_setup.py https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/setup.py
curl -o main_check_doc_toc.py https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/utils/check_doc_toc.py
- name: Compare Makefiles
- name: Compare the files and raise error if needed
run: |
diff_failed=0
if ! diff -q main_Makefile Makefile; then
echo "Error: The Makefile has changed. Please ensure it matches the main branch."
diff_failed=1
fi
if ! diff -q main_setup.py setup.py; then
echo "Error: The setup.py has changed. Please ensure it matches the main branch."
diff_failed=1
fi
if ! diff -q main_check_doc_toc.py utils/check_doc_toc.py; then
echo "Error: The utils/check_doc_toc.py has changed. Please ensure it matches the main branch."
diff_failed=1
fi
if [ $diff_failed -eq 1 ]; then
echo "❌ Error happened as we detected changes in the files that should not be changed ❌"
exit 1
fi
echo "No changes in Makefile. Proceeding..."
rm -rf main_Makefile
echo "No changes in the files. Proceeding..."
rm -rf main_Makefile main_setup.py main_check_doc_toc.py
- name: Run make style and make quality
run: |

View File

@@ -11,8 +11,6 @@ on:
- "src/diffusers/loaders/lora_base.py"
- "src/diffusers/loaders/lora_pipeline.py"
- "src/diffusers/loaders/peft.py"
- "tests/pipelines/test_pipelines_common.py"
- "tests/models/test_modeling_common.py"
workflow_dispatch:
concurrency:
@@ -106,18 +104,11 @@ jobs:
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
if [ "${{ matrix.module }}" = "ip_adapters" ]; then
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
else
pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx and $pattern" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
fi
pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx and $pattern" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
- name: Failure short reports
if: ${{ failure() }}

View File

@@ -104,7 +104,7 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
extension is replaced with ".safetensors"
"""
passed_components = passed_components or []
if folder_names:
if folder_names is not None:
filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}
# extract all components of the pipeline and their associated files
@@ -141,25 +141,7 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
return True
def filter_model_files(filenames):
"""Filter model repo files for just files/folders that contain model weights"""
weight_names = [
WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
FLAX_WEIGHTS_NAME,
ONNX_WEIGHTS_NAME,
ONNX_EXTERNAL_WEIGHTS_NAME,
]
if is_transformers_available():
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
allowed_extensions = [wn.split(".")[-1] for wn in weight_names]
return [f for f in filenames if any(f.endswith(extension) for extension in allowed_extensions)]
def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]:
def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
weight_names = [
WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
@@ -187,10 +169,6 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -
variant_index_re = re.compile(
rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
)
legacy_variant_file_re = re.compile(rf".*-{transformers_index_format}\.{variant}\.[a-z]+$")
legacy_variant_index_re = re.compile(
rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.{variant}\.index\.json$"
)
# `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
non_variant_file_re = re.compile(
@@ -199,68 +177,54 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -
# `text_encoder/pytorch_model.bin.index.json`
non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")
def filter_for_compatible_extensions(filenames, ignore_patterns=None):
if not ignore_patterns:
return filenames
if variant is not None:
variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None}
variant_filenames = variant_weights | variant_indexes
else:
variant_filenames = set()
# ignore patterns uses glob style patterns e.g *.safetensors but we're only
# interested in the extension name
return {f for f in filenames if not any(f.endswith(pat.lstrip("*.")) for pat in ignore_patterns)}
non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None}
non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None}
non_variant_filenames = non_variant_weights | non_variant_indexes
def filter_with_regex(filenames, pattern_re):
return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None}
# all variant filenames will be used by default
usable_filenames = set(variant_filenames)
# Group files by component
components = {}
for filename in filenames:
def convert_to_variant(filename):
if "index" in filename:
variant_filename = filename.replace("index", f"index.{variant}")
elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
else:
variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
return variant_filename
def find_component(filename):
if not len(filename.split("/")) == 2:
components.setdefault("", []).append(filename)
return
component = filename.split("/")[0]
return component
def has_sharded_variant(component, variant, variant_filenames):
# If component exists check for sharded variant index filename
# If component doesn't exist check main dir for sharded variant index filename
component = component + "/" if component else ""
variant_index_re = re.compile(
rf"{component}({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
)
return any(f for f in variant_filenames if variant_index_re.match(f) is not None)
for filename in non_variant_filenames:
if convert_to_variant(filename) in variant_filenames:
continue
component, _ = filename.split("/")
components.setdefault(component, []).append(filename)
component = find_component(filename)
# If a sharded variant exists skip adding to allowed patterns
if has_sharded_variant(component, variant, variant_filenames):
continue
usable_filenames = set()
variant_filenames = set()
for component, component_filenames in components.items():
component_filenames = filter_for_compatible_extensions(component_filenames, ignore_patterns=ignore_patterns)
component_variants = set()
component_legacy_variants = set()
component_non_variants = set()
if variant is not None:
component_variants = filter_with_regex(component_filenames, variant_file_re)
component_variant_index_files = filter_with_regex(component_filenames, variant_index_re)
component_legacy_variants = filter_with_regex(component_filenames, legacy_variant_file_re)
component_legacy_variant_index_files = filter_with_regex(component_filenames, legacy_variant_index_re)
if component_variants or component_legacy_variants:
variant_filenames.update(
component_variants | component_variant_index_files
if component_variants
else component_legacy_variants | component_legacy_variant_index_files
)
else:
component_non_variants = filter_with_regex(component_filenames, non_variant_file_re)
component_variant_index_files = filter_with_regex(component_filenames, non_variant_index_re)
usable_filenames.update(component_non_variants | component_variant_index_files)
usable_filenames.update(variant_filenames)
if len(variant_filenames) == 0 and variant is not None:
error_message = f"You are trying to load model files of the `variant={variant}`, but no such modeling files are available. "
raise ValueError(error_message)
if len(variant_filenames) > 0 and usable_filenames != variant_filenames:
logger.warning(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
f"[{', '.join(variant_filenames)}]\nLoaded non-{variant} filenames:\n"
f"[{', '.join(usable_filenames - variant_filenames)}\nIf this behavior is not "
f"expected, please check your folder structure."
)
usable_filenames.add(filename)
return usable_filenames, variant_filenames
@@ -958,6 +922,10 @@ def _get_custom_components_and_folders(
f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
)
if len(variant_filenames) == 0 and variant is not None:
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
raise ValueError(error_message)
return custom_components, folder_names
@@ -965,6 +933,7 @@ def _get_ignore_patterns(
passed_components,
model_folder_names: List[str],
model_filenames: List[str],
variant_filenames: List[str],
use_safetensors: bool,
from_flax: bool,
allow_pickle: bool,
@@ -995,6 +964,16 @@ def _get_ignore_patterns(
if not use_onnx:
ignore_patterns += ["*.onnx", "*.pb"]
safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
if len(safetensors_variant_filenames) > 0 and safetensors_model_filenames != safetensors_variant_filenames:
logger.warning(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
f"[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
f"[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not "
f"expected, please check your folder structure."
)
else:
ignore_patterns = ["*.safetensors", "*.msgpack"]
@@ -1002,6 +981,16 @@ def _get_ignore_patterns(
if not use_onnx:
ignore_patterns += ["*.onnx", "*.pb"]
bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
logger.warning(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
f"[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
f"[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check "
f"your folder structure."
)
return ignore_patterns

View File

@@ -89,7 +89,6 @@ from .pipeline_loading_utils import (
_resolve_custom_pipeline_and_cls,
_unwrap_model,
_update_init_kwargs_with_connected_pipeline,
filter_model_files,
load_sub_model,
maybe_raise_or_warn,
variant_compatible_siblings,
@@ -1388,8 +1387,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
revision=revision,
)
allow_pickle = True if (use_safetensors is None or use_safetensors is False) else False
use_safetensors = use_safetensors if use_safetensors is not None else True
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
allow_patterns = None
ignore_patterns = None
@@ -1404,18 +1405,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
model_info_call_error = e # save error to reraise it if model is not cached locally
if not local_files_only:
config_file = hf_hub_download(
pretrained_model_name,
cls.config_name,
cache_dir=cache_dir,
revision=revision,
proxies=proxies,
force_download=force_download,
token=token,
)
config_dict = cls._dict_from_json_file(config_file)
ignore_filenames = config_dict.pop("_ignore_files", [])
filenames = {sibling.rfilename for sibling in info.siblings}
if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant):
warn_msg = (
@@ -1430,20 +1419,61 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
)
logger.warning(warn_msg)
filenames = set(filenames) - set(ignore_filenames)
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
config_file = hf_hub_download(
pretrained_model_name,
cls.config_name,
cache_dir=cache_dir,
revision=revision,
proxies=proxies,
force_download=force_download,
token=token,
)
config_dict = cls._dict_from_json_file(config_file)
ignore_filenames = config_dict.pop("_ignore_files", [])
# remove ignored filenames
model_filenames = set(model_filenames) - set(ignore_filenames)
variant_filenames = set(variant_filenames) - set(ignore_filenames)
if revision in DEPRECATED_REVISION_ARGS and version.parse(
version.parse(__version__).base_version
) >= version.parse("0.22.0"):
warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, filenames)
warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames)
custom_components, folder_names = _get_custom_components_and_folders(
pretrained_model_name, config_dict, filenames, variant
pretrained_model_name, config_dict, filenames, variant_filenames, variant
)
model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
custom_class_name = None
if custom_pipeline is None and isinstance(config_dict["_class_name"], (list, tuple)):
custom_pipeline = config_dict["_class_name"][0]
custom_class_name = config_dict["_class_name"][1]
# all filenames compatible with variant will be added
allow_patterns = list(model_filenames)
# allow all patterns from non-model folders
# this enables downloading schedulers, tokenizers, ...
allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
# add custom component files
allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()]
# add custom pipeline file
allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
# also allow downloading config.json files with the model
allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
# also allow downloading generation_config.json of the transformers model
allow_patterns += [os.path.join(k, "generation_config.json") for k in model_folder_names]
allow_patterns += [
SCHEDULER_CONFIG_NAME,
CONFIG_NAME,
cls.config_name,
CUSTOM_PIPELINE_FILE_NAME,
]
load_pipe_from_hub = custom_pipeline is not None and f"{custom_pipeline}.py" in filenames
load_components_from_hub = len(custom_components) > 0
@@ -1476,15 +1506,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
expected_components, _ = cls._get_signature_keys(pipeline_class)
passed_components = [k for k in expected_components if k in kwargs]
# retrieve the names of the folders containing model weights
model_folder_names = {
os.path.split(f)[0] for f in filter_model_files(filenames) if os.path.split(f)[0] in folder_names
}
# retrieve all patterns that should not be downloaded and error out when needed
ignore_patterns = _get_ignore_patterns(
passed_components,
model_folder_names,
filenames,
model_filenames,
variant_filenames,
use_safetensors,
from_flax,
allow_pickle,
@@ -1493,29 +1520,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
variant,
)
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
# all filenames compatible with variant will be added
allow_patterns = list(model_filenames)
# allow all patterns from non-model folders
# this enables downloading schedulers, tokenizers, ...
allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
# add custom component files
allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()]
# add custom pipeline file
allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
# also allow downloading config.json files with the model
allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
allow_patterns += [
SCHEDULER_CONFIG_NAME,
CONFIG_NAME,
cls.config_name,
CUSTOM_PIPELINE_FILE_NAME,
]
# Don't download any objects that are passed
allow_patterns = [
p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components)

View File

@@ -1169,16 +1169,17 @@ class ModelTesterMixin:
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
max_size = int(self.model_split_percents[0] * model_size)
# Force disk offload by setting very small CPU memory
max_memory = {0: max_size, "cpu": int(0.1 * max_size)}
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
with self.assertRaises(ValueError):
max_size = int(self.model_split_percents[0] * model_size)
max_memory = {0: max_size, "cpu": max_size}
# This errors out because it's missing an offload folder
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
max_size = int(self.model_split_percents[0] * model_size)
max_memory = {0: max_size, "cpu": max_size}
new_model = self.model_class.from_pretrained(
tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
)

View File

@@ -30,7 +30,6 @@ class OmniGenTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = OmniGenTransformer2DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
model_split_percents = [0.1, 0.1, 0.1]
@property
def dummy_input(self):
@@ -74,9 +73,9 @@ class OmniGenTransformerTests(ModelTesterMixin, unittest.TestCase):
"num_attention_heads": 4,
"num_key_value_heads": 4,
"intermediate_size": 32,
"num_layers": 20,
"num_layers": 1,
"pad_token_id": 0,
"vocab_size": 1000,
"vocab_size": 100,
"in_channels": 4,
"time_step_dim": 4,
"rope_scaling": {"long_factor": list(range(1, 3)), "short_factor": list(range(1, 3))},

View File

@@ -33,7 +33,6 @@ enable_full_determinism()
class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = SD3Transformer2DModel
main_input_name = "hidden_states"
model_split_percents = [0.8, 0.8, 0.9]
@property
def dummy_input(self):
@@ -68,7 +67,7 @@ class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
"sample_size": 32,
"patch_size": 1,
"in_channels": 4,
"num_layers": 4,
"num_layers": 1,
"attention_head_dim": 8,
"num_attention_heads": 4,
"caption_projection_dim": 32,
@@ -108,7 +107,6 @@ class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = SD3Transformer2DModel
main_input_name = "hidden_states"
model_split_percents = [0.8, 0.8, 0.9]
@property
def dummy_input(self):
@@ -143,7 +141,7 @@ class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
"sample_size": 32,
"patch_size": 1,
"in_channels": 4,
"num_layers": 4,
"num_layers": 2,
"attention_head_dim": 8,
"num_attention_heads": 4,
"caption_projection_dim": 32,

View File

@@ -212,7 +212,6 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
class VariantCompatibleSiblingsTest(unittest.TestCase):
def test_only_non_variants_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
f"vae/diffusion_pytorch_model.{variant}.safetensors",
@@ -223,13 +222,10 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
"unet/diffusion_pytorch_model.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=None, ignore_patterns=ignore_patterns
)
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
assert all(variant not in f for f in model_filenames)
def test_only_variants_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
f"vae/diffusion_pytorch_model.{variant}.safetensors",
@@ -240,13 +236,10 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
"unet/diffusion_pytorch_model.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
assert all(variant in f for f in model_filenames)
def test_mixed_variants_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
non_variant_file = "text_encoder/model.safetensors"
filenames = [
@@ -256,27 +249,23 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
f"unet/diffusion_pytorch_model.{variant}.safetensors",
"unet/diffusion_pytorch_model.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)
def test_non_variants_in_main_dir_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
f"diffusion_pytorch_model.{variant}.safetensors",
"diffusion_pytorch_model.safetensors",
"model.safetensors",
f"model.{variant}.safetensors",
f"diffusion_pytorch_model.{variant}.safetensors",
"diffusion_pytorch_model.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=None, ignore_patterns=ignore_patterns
)
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
assert all(variant not in f for f in model_filenames)
def test_variants_in_main_dir_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
f"diffusion_pytorch_model.{variant}.safetensors",
@@ -286,76 +275,23 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
f"diffusion_pytorch_model.{variant}.safetensors",
"diffusion_pytorch_model.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
assert all(variant in f for f in model_filenames)
def test_mixed_variants_in_main_dir_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
non_variant_file = "model.safetensors"
filenames = [
f"diffusion_pytorch_model.{variant}.safetensors",
"diffusion_pytorch_model.safetensors",
"model.safetensors",
f"diffusion_pytorch_model.{variant}.safetensors",
"diffusion_pytorch_model.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)
def test_sharded_variants_in_main_dir_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
"diffusion_pytorch_model.safetensors.index.json",
"diffusion_pytorch_model-00001-of-00003.safetensors",
"diffusion_pytorch_model-00002-of-00003.safetensors",
"diffusion_pytorch_model-00003-of-00003.safetensors",
f"diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
f"diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
f"diffusion_pytorch_model.safetensors.index.{variant}.json",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f for f in model_filenames)
def test_mixed_sharded_and_variant_in_main_dir_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
"diffusion_pytorch_model.safetensors.index.json",
"diffusion_pytorch_model-00001-of-00003.safetensors",
"diffusion_pytorch_model-00002-of-00003.safetensors",
"diffusion_pytorch_model-00003-of-00003.safetensors",
f"diffusion_pytorch_model.{variant}.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f for f in model_filenames)
def test_mixed_sharded_non_variants_in_main_dir_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
f"diffusion_pytorch_model.safetensors.index.{variant}.json",
"diffusion_pytorch_model.safetensors.index.json",
"diffusion_pytorch_model-00001-of-00003.safetensors",
"diffusion_pytorch_model-00002-of-00003.safetensors",
"diffusion_pytorch_model-00003-of-00003.safetensors",
f"diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
f"diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=None, ignore_patterns=ignore_patterns
)
assert all(variant not in f for f in model_filenames)
def test_sharded_non_variants_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",
@@ -366,13 +302,10 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=None, ignore_patterns=ignore_patterns
)
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
assert all(variant not in f for f in model_filenames)
def test_sharded_variants_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",
@@ -383,49 +316,10 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
assert all(variant in f for f in model_filenames)
assert model_filenames == variant_filenames
def test_single_variant_with_sharded_non_variant_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
"unet/diffusion_pytorch_model.safetensors.index.json",
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
f"unet/diffusion_pytorch_model.{variant}.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f for f in model_filenames)
def test_mixed_single_variant_with_sharded_non_variant_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
allowed_non_variant = "unet"
filenames = [
"vae/diffusion_pytorch_model.safetensors.index.json",
"vae/diffusion_pytorch_model-00001-of-00003.safetensors",
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
f"vae/diffusion_pytorch_model.{variant}.safetensors",
"unet/diffusion_pytorch_model.safetensors.index.json",
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
def test_sharded_mixed_variants_downloaded(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
allowed_non_variant = "unet"
filenames = [
@@ -441,144 +335,9 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
def test_downloading_when_no_variant_exists(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = ["model.safetensors", "diffusion_pytorch_model.safetensors"]
with self.assertRaisesRegex(ValueError, "but no such modeling files are available. "):
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
def test_downloading_use_safetensors_false(self):
ignore_patterns = ["*.safetensors"]
filenames = [
"text_encoder/model.bin",
"unet/diffusion_pytorch_model.bin",
"unet/diffusion_pytorch_model.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=None, ignore_patterns=ignore_patterns
)
assert all(".safetensors" not in f for f in model_filenames)
def test_non_variant_in_main_dir_with_variant_in_subfolder(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
allowed_non_variant = "diffusion_pytorch_model.safetensors"
filenames = [
f"unet/diffusion_pytorch_model.{variant}.safetensors",
"diffusion_pytorch_model.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
def test_download_variants_when_component_has_no_safetensors_variant(self):
ignore_patterns = None
variant = "fp16"
filenames = [
f"unet/diffusion_pytorch_model.{variant}.bin",
"vae/diffusion_pytorch_model.safetensors",
f"vae/diffusion_pytorch_model.{variant}.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert {
f"unet/diffusion_pytorch_model.{variant}.bin",
f"vae/diffusion_pytorch_model.{variant}.safetensors",
} == model_filenames
def test_error_when_download_sharded_variants_when_component_has_no_safetensors_variant(self):
ignore_patterns = ["*.bin"]
variant = "fp16"
filenames = [
f"vae/diffusion_pytorch_model.bin.index.{variant}.json",
"vae/diffusion_pytorch_model.safetensors.index.json",
f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.bin",
"vae/diffusion_pytorch_model-00001-of-00003.safetensors",
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
"unet/diffusion_pytorch_model.safetensors.index.json",
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.bin",
]
with self.assertRaisesRegex(ValueError, "but no such modeling files are available. "):
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
def test_download_sharded_variants_when_component_has_no_safetensors_variant_and_safetensors_false(self):
ignore_patterns = ["*.safetensors"]
allowed_non_variant = "unet"
variant = "fp16"
filenames = [
f"vae/diffusion_pytorch_model.bin.index.{variant}.json",
"vae/diffusion_pytorch_model.safetensors.index.json",
f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.bin",
"vae/diffusion_pytorch_model-00001-of-00003.safetensors",
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
"unet/diffusion_pytorch_model.safetensors.index.json",
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.bin",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
def test_download_sharded_legacy_variants(self):
ignore_patterns = None
variant = "fp16"
filenames = [
f"vae/transformer/diffusion_pytorch_model.safetensors.{variant}.index.json",
"vae/diffusion_pytorch_model.safetensors.index.json",
f"vae/diffusion_pytorch_model-00002-of-00002.{variant}.safetensors",
"vae/diffusion_pytorch_model-00001-of-00003.safetensors",
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
f"vae/diffusion_pytorch_model-00001-of-00002.{variant}.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
assert all(variant in f for f in model_filenames)
def test_download_onnx_models(self):
ignore_patterns = ["*.safetensors"]
filenames = [
"vae/model.onnx",
"unet/model.onnx",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=None, ignore_patterns=ignore_patterns
)
assert model_filenames == set(filenames)
def test_download_flax_models(self):
ignore_patterns = ["*.safetensors", "*.bin"]
filenames = [
"vae/diffusion_flax_model.msgpack",
"unet/diffusion_flax_model.msgpack",
]
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=None, ignore_patterns=ignore_patterns
)
assert model_filenames == set(filenames)
class ProgressBarTests(unittest.TestCase):
def get_dummy_components_image_generation(self):