Compare commits


4 Commits

Author      SHA1        Message                                              Date
Sayak Paul  936e6192b1  Merge branch 'main' into improve-security-stylebot   2025-02-27 09:09:52 +05:30
sayakpaul   6a478569f0                                                       2025-02-26 15:34:34 +05:30
Sayak Paul  719d0ce7a7  Merge branch 'main' into improve-security-stylebot   2025-02-26 15:32:22 +05:30
sayakpaul   be16b1bcdf  improve security for the stylebot.                   2025-02-26 15:30:08 +05:30
8 changed files with 169 additions and 408 deletions

View File

@@ -64,18 +64,38 @@ jobs:
       run: |
         pip install .[quality]
-    - name: Download Makefile from main branch
+    - name: Download necessary files from main branch of Diffusers
       run: |
         curl -o main_Makefile https://raw.githubusercontent.com/huggingface/diffusers/main/Makefile
+        curl -o main_setup.py https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/setup.py
+        curl -o main_check_doc_toc.py https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/utils/check_doc_toc.py
-    - name: Compare Makefiles
+    - name: Compare the files and raise error if needed
       run: |
+        diff_failed=0
         if ! diff -q main_Makefile Makefile; then
           echo "Error: The Makefile has changed. Please ensure it matches the main branch."
+          diff_failed=1
+        fi
+        if ! diff -q main_setup.py setup.py; then
+          echo "Error: The setup.py has changed. Please ensure it matches the main branch."
+          diff_failed=1
+        fi
+        if ! diff -q main_check_doc_toc.py utils/check_doc_toc.py; then
+          echo "Error: The utils/check_doc_toc.py has changed. Please ensure it matches the main branch."
+          diff_failed=1
+        fi
+        if [ $diff_failed -eq 1 ]; then
+          echo "❌ Error happened as we detected changes in the files that should not be changed ❌"
           exit 1
         fi
-        echo "No changes in Makefile. Proceeding..."
-        rm -rf main_Makefile
+        echo "No changes in the files. Proceeding..."
+        rm -rf main_Makefile main_setup.py main_check_doc_toc.py
     - name: Run make style and make quality
       run: |
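
Editor's note: the updated step downloads pristine copies of Makefile, setup.py, and utils/check_doc_toc.py from main and fails the job if a PR modified any of them. A minimal Python sketch of the same guard, using only the standard library (the URLs are the ones the workflow fetches; paths are relative to the repo checkout):

    import filecmp
    import sys
    import urllib.request

    # Files that must stay identical to the main branch, mapped to their raw URLs.
    PROTECTED = {
        "Makefile": "https://raw.githubusercontent.com/huggingface/diffusers/main/Makefile",
        "setup.py": "https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/setup.py",
        "utils/check_doc_toc.py": "https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/utils/check_doc_toc.py",
    }

    failed = False
    for local_path, url in PROTECTED.items():
        ref_path, _ = urllib.request.urlretrieve(url)  # fetch main's copy into a temp file
        if not filecmp.cmp(ref_path, local_path, shallow=False):  # byte-for-byte compare
            print(f"Error: {local_path} has changed. Please ensure it matches the main branch.")
            failed = True
    if failed:
        sys.exit(1)  # a non-zero exit fails the CI job, like `exit 1` in the shell step

Accumulating failures before exiting, instead of bailing on the first mismatch, reports every offending file in one run, which is the same reason the shell version introduces a diff_failed flag.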

View File

@@ -11,8 +11,6 @@ on:
- "src/diffusers/loaders/lora_base.py" - "src/diffusers/loaders/lora_base.py"
- "src/diffusers/loaders/lora_pipeline.py" - "src/diffusers/loaders/lora_pipeline.py"
- "src/diffusers/loaders/peft.py" - "src/diffusers/loaders/peft.py"
- "tests/pipelines/test_pipelines_common.py"
- "tests/models/test_modeling_common.py"
workflow_dispatch: workflow_dispatch:
concurrency: concurrency:
@@ -106,18 +104,11 @@ jobs:
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8 CUBLAS_WORKSPACE_CONFIG: :16:8
run: | run: |
if [ "${{ matrix.module }}" = "ip_adapters" ]; then pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \ python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \ -s -v -k "not Flax and not Onnx and $pattern" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \ --make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }} tests/pipelines/${{ matrix.module }}
else
pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx and $pattern" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
fi
- name: Failure short reports - name: Failure short reports
if: ${{ failure() }} if: ${{ failure() }}
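
Editor's note: with the ip_adapters special case gone, every matrix module filters tests through the same `-k` expression. A sketch of how the final pytest invocation is assembled, assuming the pattern file written by the extract_tests step (the local path below is illustrative):

    import shlex
    import subprocess

    module = "ip_adapters"  # example value of ${{ matrix.module }}
    with open("pattern_file.txt") as f:  # stands in for steps.extract_tests.outputs.pattern_file
        pattern = f.read().strip()

    cmd = [
        "python", "-m", "pytest", "-n", "1", "--max-worker-restart=0", "--dist=loadfile",
        "-s", "-v", "-k", f"not Flax and not Onnx and {pattern}",
        f"--make-reports=tests_pipeline_{module}_cuda",
        f"tests/pipelines/{module}",
    ]
    print(shlex.join(cmd))  # inspect the exact command before running it
    subprocess.run(cmd, check=True)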

View File

@@ -104,7 +104,7 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=None):
     extension is replaced with ".safetensors"
     """
     passed_components = passed_components or []
-    if folder_names:
+    if folder_names is not None:
         filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}

     # extract all components of the pipeline and their associated files

@@ -141,25 +141,7 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=None):
     return True


-def filter_model_files(filenames):
-    """Filter model repo files for just files/folders that contain model weights"""
-    weight_names = [
-        WEIGHTS_NAME,
-        SAFETENSORS_WEIGHTS_NAME,
-        FLAX_WEIGHTS_NAME,
-        ONNX_WEIGHTS_NAME,
-        ONNX_EXTERNAL_WEIGHTS_NAME,
-    ]
-
-    if is_transformers_available():
-        weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
-
-    allowed_extensions = [wn.split(".")[-1] for wn in weight_names]
-
-    return [f for f in filenames if any(f.endswith(extension) for extension in allowed_extensions)]
-
-
-def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]:
+def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
     weight_names = [
         WEIGHTS_NAME,
         SAFETENSORS_WEIGHTS_NAME,

@@ -187,10 +169,6 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]:
     variant_index_re = re.compile(
         rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
     )
-    legacy_variant_file_re = re.compile(rf".*-{transformers_index_format}\.{variant}\.[a-z]+$")
-    legacy_variant_index_re = re.compile(
-        rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.{variant}\.index\.json$"
-    )

     # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
     non_variant_file_re = re.compile(

@@ -199,68 +177,54 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]:
     # `text_encoder/pytorch_model.bin.index.json`
     non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")

-    def filter_for_compatible_extensions(filenames, ignore_patterns=None):
-        if not ignore_patterns:
-            return filenames
-
-        # ignore patterns uses glob style patterns e.g *.safetensors but we're only
-        # interested in the extension name
-        return {f for f in filenames if not any(f.endswith(pat.lstrip("*.")) for pat in ignore_patterns)}
-
-    def filter_with_regex(filenames, pattern_re):
-        return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None}
-
-    # Group files by component
-    components = {}
-    for filename in filenames:
-        if not len(filename.split("/")) == 2:
-            components.setdefault("", []).append(filename)
-            continue
-
-        component, _ = filename.split("/")
-        components.setdefault(component, []).append(filename)
-
-    usable_filenames = set()
-    variant_filenames = set()
-    for component, component_filenames in components.items():
-        component_filenames = filter_for_compatible_extensions(component_filenames, ignore_patterns=ignore_patterns)
-
-        component_variants = set()
-        component_legacy_variants = set()
-        component_non_variants = set()
-        if variant is not None:
-            component_variants = filter_with_regex(component_filenames, variant_file_re)
-            component_variant_index_files = filter_with_regex(component_filenames, variant_index_re)
-            component_legacy_variants = filter_with_regex(component_filenames, legacy_variant_file_re)
-            component_legacy_variant_index_files = filter_with_regex(component_filenames, legacy_variant_index_re)
-
-        if component_variants or component_legacy_variants:
-            variant_filenames.update(
-                component_variants | component_variant_index_files
-                if component_variants
-                else component_legacy_variants | component_legacy_variant_index_files
-            )
-        else:
-            component_non_variants = filter_with_regex(component_filenames, non_variant_file_re)
-            component_variant_index_files = filter_with_regex(component_filenames, non_variant_index_re)
-
-            usable_filenames.update(component_non_variants | component_variant_index_files)
-
-    usable_filenames.update(variant_filenames)
-
-    if len(variant_filenames) == 0 and variant is not None:
-        error_message = f"You are trying to load model files of the `variant={variant}`, but no such modeling files are available. "
-        raise ValueError(error_message)
-
-    if len(variant_filenames) > 0 and usable_filenames != variant_filenames:
-        logger.warning(
-            f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
-            f"[{', '.join(variant_filenames)}]\nLoaded non-{variant} filenames:\n"
-            f"[{', '.join(usable_filenames - variant_filenames)}\nIf this behavior is not "
-            f"expected, please check your folder structure."
-        )
+    if variant is not None:
+        variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
+        variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None}
+        variant_filenames = variant_weights | variant_indexes
+    else:
+        variant_filenames = set()
+
+    non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None}
+    non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None}
+    non_variant_filenames = non_variant_weights | non_variant_indexes
+
+    # all variant filenames will be used by default
+    usable_filenames = set(variant_filenames)
+
+    def convert_to_variant(filename):
+        if "index" in filename:
+            variant_filename = filename.replace("index", f"index.{variant}")
+        elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
+            variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
+        else:
+            variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
+        return variant_filename
+
+    def find_component(filename):
+        if not len(filename.split("/")) == 2:
+            return
+        component = filename.split("/")[0]
+        return component
+
+    def has_sharded_variant(component, variant, variant_filenames):
+        # If component exists check for sharded variant index filename
+        # If component doesn't exist check main dir for sharded variant index filename
+        component = component + "/" if component else ""
+        variant_index_re = re.compile(
+            rf"{component}({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
+        )
+        return any(f for f in variant_filenames if variant_index_re.match(f) is not None)
+
+    for filename in non_variant_filenames:
+        if convert_to_variant(filename) in variant_filenames:
+            continue
+
+        component = find_component(filename)
+        # If a sharded variant exists skip adding to allowed patterns
+        if has_sharded_variant(component, variant, variant_filenames):
+            continue
+
+        usable_filenames.add(filename)

     return usable_filenames, variant_filenames

@@ -958,6 +922,10 @@ def _get_custom_components_and_folders(
                 f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
             )

+    if len(variant_filenames) == 0 and variant is not None:
+        error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
+        raise ValueError(error_message)
+
     return custom_components, folder_names

@@ -965,6 +933,7 @@ def _get_ignore_patterns(
     passed_components,
     model_folder_names: List[str],
     model_filenames: List[str],
+    variant_filenames: List[str],
     use_safetensors: bool,
     from_flax: bool,
     allow_pickle: bool,

@@ -995,6 +964,16 @@ def _get_ignore_patterns(
         if not use_onnx:
             ignore_patterns += ["*.onnx", "*.pb"]

+        safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
+        safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
+        if len(safetensors_variant_filenames) > 0 and safetensors_model_filenames != safetensors_variant_filenames:
+            logger.warning(
+                f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
+                f"[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
+                f"[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not "
+                f"expected, please check your folder structure."
+            )
+
     else:
         ignore_patterns = ["*.safetensors", "*.msgpack"]

@@ -1002,6 +981,16 @@ def _get_ignore_patterns(
         if not use_onnx:
             ignore_patterns += ["*.onnx", "*.pb"]

+        bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
+        bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
+        if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
+            logger.warning(
+                f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
+                f"[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
+                f"[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check "
+                f"your folder structure."
+            )
+
     return ignore_patterns
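
Editor's note: the rewritten variant_compatible_siblings drops the ignore_patterns and legacy-variant handling and goes back to pairing each non-variant file with its would-be variant name. A standalone sketch of the convert_to_variant logic above, with the transformers shard format inlined (the shard regex mirrors the one used in the diff; the sample filenames are illustrative):

    import re

    variant = "fp16"
    # shard suffix used by transformers-style checkpoints, e.g. "-00001-of-00002"
    transformers_index_format = r"\d{5}-of-\d{5}"

    def convert_to_variant(filename: str) -> str:
        if "index" in filename:
            # model.safetensors.index.json -> model.safetensors.index.fp16.json
            return filename.replace("index", f"index.{variant}")
        elif re.match(f"^(.*?){transformers_index_format}", filename) is not None:
            # model-00001-of-00002.safetensors -> model.fp16-00001-of-00002.safetensors
            return f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
        # model.safetensors -> model.fp16.safetensors
        return f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"

    for name in [
        "diffusion_pytorch_model.safetensors.index.json",
        "diffusion_pytorch_model-00001-of-00002.safetensors",
        "diffusion_pytorch_model.safetensors",
    ]:
        print(name, "->", convert_to_variant(name))

A non-variant file is only added to usable_filenames if neither its converted variant name nor a sharded variant index for its component exists, which is what makes variant files win over their non-variant counterparts.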

View File

@@ -89,7 +89,6 @@ from .pipeline_loading_utils import (
     _resolve_custom_pipeline_and_cls,
     _unwrap_model,
     _update_init_kwargs_with_connected_pipeline,
-    filter_model_files,
     load_sub_model,
     maybe_raise_or_warn,
     variant_compatible_siblings,

@@ -1388,8 +1387,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                 revision=revision,
             )

-            allow_pickle = True if (use_safetensors is None or use_safetensors is False) else False
-            use_safetensors = use_safetensors if use_safetensors is not None else True
+            allow_pickle = False
+            if use_safetensors is None:
+                use_safetensors = True
+                allow_pickle = True

             allow_patterns = None
             ignore_patterns = None

@@ -1404,18 +1405,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                 model_info_call_error = e  # save error to reraise it if model is not cached locally

         if not local_files_only:
-            config_file = hf_hub_download(
-                pretrained_model_name,
-                cls.config_name,
-                cache_dir=cache_dir,
-                revision=revision,
-                proxies=proxies,
-                force_download=force_download,
-                token=token,
-            )
-            config_dict = cls._dict_from_json_file(config_file)
-            ignore_filenames = config_dict.pop("_ignore_files", [])
-
             filenames = {sibling.rfilename for sibling in info.siblings}
             if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant):
                 warn_msg = (

@@ -1430,20 +1419,61 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                 )
                 logger.warning(warn_msg)

-            filenames = set(filenames) - set(ignore_filenames)
+            model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
+
+            config_file = hf_hub_download(
+                pretrained_model_name,
+                cls.config_name,
+                cache_dir=cache_dir,
+                revision=revision,
+                proxies=proxies,
+                force_download=force_download,
+                token=token,
+            )
+            config_dict = cls._dict_from_json_file(config_file)
+            ignore_filenames = config_dict.pop("_ignore_files", [])
+
+            # remove ignored filenames
+            model_filenames = set(model_filenames) - set(ignore_filenames)
+            variant_filenames = set(variant_filenames) - set(ignore_filenames)

             if revision in DEPRECATED_REVISION_ARGS and version.parse(
                 version.parse(__version__).base_version
             ) >= version.parse("0.22.0"):
-                warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, filenames)
+                warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames)

             custom_components, folder_names = _get_custom_components_and_folders(
-                pretrained_model_name, config_dict, filenames, variant
+                pretrained_model_name, config_dict, filenames, variant_filenames, variant
             )
+            model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}

             custom_class_name = None
             if custom_pipeline is None and isinstance(config_dict["_class_name"], (list, tuple)):
                 custom_pipeline = config_dict["_class_name"][0]
                 custom_class_name = config_dict["_class_name"][1]

+            # all filenames compatible with variant will be added
+            allow_patterns = list(model_filenames)
+            # allow all patterns from non-model folders
+            # this enables downloading schedulers, tokenizers, ...
+            allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
+            # add custom component files
+            allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()]
+            # add custom pipeline file
+            allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
+            # also allow downloading config.json files with the model
+            allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
+            # also allow downloading generation_config.json of the transformers model
+            allow_patterns += [os.path.join(k, "generation_config.json") for k in model_folder_names]
+            allow_patterns += [
+                SCHEDULER_CONFIG_NAME,
+                CONFIG_NAME,
+                cls.config_name,
+                CUSTOM_PIPELINE_FILE_NAME,
+            ]
+
             load_pipe_from_hub = custom_pipeline is not None and f"{custom_pipeline}.py" in filenames
             load_components_from_hub = len(custom_components) > 0

@@ -1476,15 +1506,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             expected_components, _ = cls._get_signature_keys(pipeline_class)
             passed_components = [k for k in expected_components if k in kwargs]

-            # retrieve the names of the folders containing model weights
-            model_folder_names = {
-                os.path.split(f)[0] for f in filter_model_files(filenames) if os.path.split(f)[0] in folder_names
-            }
-
             # retrieve all patterns that should not be downloaded and error out when needed
             ignore_patterns = _get_ignore_patterns(
                 passed_components,
                 model_folder_names,
-                filenames,
+                model_filenames,
+                variant_filenames,
                 use_safetensors,
                 from_flax,
                 allow_pickle,

@@ -1493,29 +1520,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                 variant,
             )

-            model_filenames, variant_filenames = variant_compatible_siblings(
-                filenames, variant=variant, ignore_patterns=ignore_patterns
-            )
-
-            # all filenames compatible with variant will be added
-            allow_patterns = list(model_filenames)
-            # allow all patterns from non-model folders
-            # this enables downloading schedulers, tokenizers, ...
-            allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
-            # add custom component files
-            allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()]
-            # add custom pipeline file
-            allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
-            # also allow downloading config.json files with the model
-            allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
-            allow_patterns += [
-                SCHEDULER_CONFIG_NAME,
-                CONFIG_NAME,
-                cls.config_name,
-                CUSTOM_PIPELINE_FILE_NAME,
-            ]
-
             # Don't download any objects that are passed
             allow_patterns = [
                 p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components)
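
Editor's note: after this change, allow_patterns is built from the variant-resolved model_filenames before _get_ignore_patterns runs, and a file is downloaded only if it matches an allow pattern and no ignore pattern. A runnable sketch of that interplay using fnmatch, which is the glob semantics snapshot_download applies (the filenames and pattern values below are examples, not taken from a real repo):

    from fnmatch import fnmatch

    filenames = [
        "model_index.json",
        "scheduler/scheduler_config.json",
        "unet/config.json",
        "unet/diffusion_pytorch_model.fp16.safetensors",
        "unet/diffusion_pytorch_model.bin",
    ]
    allow_patterns = [
        "unet/diffusion_pytorch_model.fp16.safetensors",  # explicit model files
        "scheduler/*",                                    # whole non-model folders
        "unet/config.json",
        "model_index.json",
    ]
    ignore_patterns = ["*.bin", "*.msgpack"]  # e.g. when safetensors is preferred

    selected = [
        f
        for f in filenames
        if any(fnmatch(f, p) for p in allow_patterns)
        and not any(fnmatch(f, p) for p in ignore_patterns)
    ]
    print(selected)  # the .bin weight is excluded even though unet files are allowed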

View File

@@ -1169,16 +1169,17 @@ class ModelTesterMixin:
         base_output = model(**inputs_dict)

         model_size = compute_module_sizes(model)[""]
-        max_size = int(self.model_split_percents[0] * model_size)
-        # Force disk offload by setting very small CPU memory
-        max_memory = {0: max_size, "cpu": int(0.1 * max_size)}
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir, safe_serialization=False)

             with self.assertRaises(ValueError):
+                max_size = int(self.model_split_percents[0] * model_size)
+                max_memory = {0: max_size, "cpu": max_size}
                 # This errors out because it's missing an offload folder
                 new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)

+            max_size = int(self.model_split_percents[0] * model_size)
+            max_memory = {0: max_size, "cpu": max_size}
             new_model = self.model_class.from_pretrained(
                 tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
             )
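
Editor's note: the test no longer forces disk offload by capping CPU memory at 10% of the split size; both from_pretrained calls now use the same budget for GPU 0 and CPU, which is also why the per-class model_split_percents overrides disappear from the transformer test files below. For reference, the max_memory map is plain arithmetic over the computed model size (values here are illustrative):

    model_split_percents = [0.5, 0.7, 0.9]      # example of the mixin's class attribute
    model_size = 1_000_000_000                   # example: ~1 GB of parameters, in bytes
    max_size = int(model_split_percents[0] * model_size)
    max_memory = {0: max_size, "cpu": max_size}  # same cap for GPU 0 and CPU
    print(max_memory)                            # {0: 500000000, 'cpu': 500000000}

Anything that does not fit within these per-device budgets under device_map="auto" spills to the offload folder, hence the expected ValueError when no offload_folder is given.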

View File

@@ -30,7 +30,6 @@ class OmniGenTransformerTests(ModelTesterMixin, unittest.TestCase):
     model_class = OmniGenTransformer2DModel
     main_input_name = "hidden_states"
     uses_custom_attn_processor = True
-    model_split_percents = [0.1, 0.1, 0.1]

     @property
     def dummy_input(self):

@@ -74,9 +73,9 @@ class OmniGenTransformerTests(ModelTesterMixin, unittest.TestCase):
             "num_attention_heads": 4,
             "num_key_value_heads": 4,
             "intermediate_size": 32,
-            "num_layers": 20,
+            "num_layers": 1,
             "pad_token_id": 0,
-            "vocab_size": 1000,
+            "vocab_size": 100,
             "in_channels": 4,
             "time_step_dim": 4,
             "rope_scaling": {"long_factor": list(range(1, 3)), "short_factor": list(range(1, 3))},

View File

@@ -33,7 +33,6 @@ enable_full_determinism()
 class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
     model_class = SD3Transformer2DModel
     main_input_name = "hidden_states"
-    model_split_percents = [0.8, 0.8, 0.9]

     @property
     def dummy_input(self):

@@ -68,7 +67,7 @@ class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
             "sample_size": 32,
             "patch_size": 1,
             "in_channels": 4,
-            "num_layers": 4,
+            "num_layers": 1,
             "attention_head_dim": 8,
             "num_attention_heads": 4,
             "caption_projection_dim": 32,

@@ -108,7 +107,6 @@ class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
 class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
     model_class = SD3Transformer2DModel
     main_input_name = "hidden_states"
-    model_split_percents = [0.8, 0.8, 0.9]

     @property
     def dummy_input(self):

@@ -143,7 +141,7 @@ class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
             "sample_size": 32,
             "patch_size": 1,
             "in_channels": 4,
-            "num_layers": 4,
+            "num_layers": 2,
             "attention_head_dim": 8,
             "num_attention_heads": 4,
             "caption_projection_dim": 32,

View File

@@ -212,7 +212,6 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):

 class VariantCompatibleSiblingsTest(unittest.TestCase):
     def test_only_non_variants_downloaded(self):
-        ignore_patterns = ["*.bin"]
         variant = "fp16"
         filenames = [
             f"vae/diffusion_pytorch_model.{variant}.safetensors",

@@ -223,13 +222,10 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
             "unet/diffusion_pytorch_model.safetensors",
         ]

-        model_filenames, variant_filenames = variant_compatible_siblings(
-            filenames, variant=None, ignore_patterns=ignore_patterns
-        )
+        model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
         assert all(variant not in f for f in model_filenames)

     def test_only_variants_downloaded(self):
-        ignore_patterns = ["*.bin"]
         variant = "fp16"
         filenames = [
             f"vae/diffusion_pytorch_model.{variant}.safetensors",

@@ -240,13 +236,10 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
             "unet/diffusion_pytorch_model.safetensors",
         ]

-        model_filenames, variant_filenames = variant_compatible_siblings(
-            filenames, variant=variant, ignore_patterns=ignore_patterns
-        )
+        model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
         assert all(variant in f for f in model_filenames)

     def test_mixed_variants_downloaded(self):
-        ignore_patterns = ["*.bin"]
         variant = "fp16"
         non_variant_file = "text_encoder/model.safetensors"
         filenames = [

@@ -256,27 +249,23 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
             f"unet/diffusion_pytorch_model.{variant}.safetensors",
             "unet/diffusion_pytorch_model.safetensors",
         ]
-        model_filenames, variant_filenames = variant_compatible_siblings(
-            filenames, variant=variant, ignore_patterns=ignore_patterns
-        )
+        model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
         assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)

     def test_non_variants_in_main_dir_downloaded(self):
-        ignore_patterns = ["*.bin"]
         variant = "fp16"
         filenames = [
             f"diffusion_pytorch_model.{variant}.safetensors",
             "diffusion_pytorch_model.safetensors",
             "model.safetensors",
             f"model.{variant}.safetensors",
-            f"diffusion_pytorch_model.{variant}.safetensors",
-            "diffusion_pytorch_model.safetensors",
         ]
-        model_filenames, variant_filenames = variant_compatible_siblings(
-            filenames, variant=None, ignore_patterns=ignore_patterns
-        )
+        model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
         assert all(variant not in f for f in model_filenames)

     def test_variants_in_main_dir_downloaded(self):
-        ignore_patterns = ["*.bin"]
         variant = "fp16"
         filenames = [
             f"diffusion_pytorch_model.{variant}.safetensors",

@@ -286,76 +275,23 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
             f"diffusion_pytorch_model.{variant}.safetensors",
             "diffusion_pytorch_model.safetensors",
         ]
-        model_filenames, variant_filenames = variant_compatible_siblings(
-            filenames, variant=variant, ignore_patterns=ignore_patterns
-        )
+        model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
         assert all(variant in f for f in model_filenames)

     def test_mixed_variants_in_main_dir_downloaded(self):
-        ignore_patterns = ["*.bin"]
         variant = "fp16"
         non_variant_file = "model.safetensors"
         filenames = [
             f"diffusion_pytorch_model.{variant}.safetensors",
             "diffusion_pytorch_model.safetensors",
             "model.safetensors",
-            f"diffusion_pytorch_model.{variant}.safetensors",
-            "diffusion_pytorch_model.safetensors",
         ]
-        model_filenames, variant_filenames = variant_compatible_siblings(
-            filenames, variant=variant, ignore_patterns=ignore_patterns
-        )
+        model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
         assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)

-    def test_sharded_variants_in_main_dir_downloaded(self):
-        ignore_patterns = ["*.bin"]
-        variant = "fp16"
-        filenames = [
-            "diffusion_pytorch_model.safetensors.index.json",
-            "diffusion_pytorch_model-00001-of-00003.safetensors",
-            "diffusion_pytorch_model-00002-of-00003.safetensors",
-            "diffusion_pytorch_model-00003-of-00003.safetensors",
-            f"diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
-            f"diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
-            f"diffusion_pytorch_model.safetensors.index.{variant}.json",
-        ]
-        model_filenames, variant_filenames = variant_compatible_siblings(
-            filenames, variant=variant, ignore_patterns=ignore_patterns
-        )
-        assert all(variant in f for f in model_filenames)
-
-    def test_mixed_sharded_and_variant_in_main_dir_downloaded(self):
-        ignore_patterns = ["*.bin"]
-        variant = "fp16"
-        filenames = [
-            "diffusion_pytorch_model.safetensors.index.json",
-            "diffusion_pytorch_model-00001-of-00003.safetensors",
-            "diffusion_pytorch_model-00002-of-00003.safetensors",
-            "diffusion_pytorch_model-00003-of-00003.safetensors",
-            f"diffusion_pytorch_model.{variant}.safetensors",
-        ]
-        model_filenames, variant_filenames = variant_compatible_siblings(
-            filenames, variant=variant, ignore_patterns=ignore_patterns
-        )
-        assert all(variant in f for f in model_filenames)
-
-    def test_mixed_sharded_non_variants_in_main_dir_downloaded(self):
-        ignore_patterns = ["*.bin"]
-        variant = "fp16"
-        filenames = [
-            f"diffusion_pytorch_model.safetensors.index.{variant}.json",
-            "diffusion_pytorch_model.safetensors.index.json",
-            "diffusion_pytorch_model-00001-of-00003.safetensors",
-            "diffusion_pytorch_model-00002-of-00003.safetensors",
-            "diffusion_pytorch_model-00003-of-00003.safetensors",
-            f"diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
-            f"diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
-        ]
-        model_filenames, variant_filenames = variant_compatible_siblings(
-            filenames, variant=None, ignore_patterns=ignore_patterns
-        )
-        assert all(variant not in f for f in model_filenames)
-
     def test_sharded_non_variants_downloaded(self):
-        ignore_patterns = ["*.bin"]
         variant = "fp16"
         filenames = [
             f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",

@@ -366,13 +302,10 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
             f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
             f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
         ]
-        model_filenames, variant_filenames = variant_compatible_siblings(
-            filenames, variant=None, ignore_patterns=ignore_patterns
-        )
+        model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
         assert all(variant not in f for f in model_filenames)

     def test_sharded_variants_downloaded(self):
-        ignore_patterns = ["*.bin"]
         variant = "fp16"
         filenames = [
             f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",

@@ -383,49 +316,10 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
             f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
             f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
         ]
-        model_filenames, variant_filenames = variant_compatible_siblings(
-            filenames, variant=variant, ignore_patterns=ignore_patterns
-        )
+        model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
         assert all(variant in f for f in model_filenames)
-        assert model_filenames == variant_filenames
-
-    def test_single_variant_with_sharded_non_variant_downloaded(self):
-        ignore_patterns = ["*.bin"]
-        variant = "fp16"
-        filenames = [
-            "unet/diffusion_pytorch_model.safetensors.index.json",
-            "unet/diffusion_pytorch_model-00001-of-00003.safetensors",
-            "unet/diffusion_pytorch_model-00002-of-00003.safetensors",
-            "unet/diffusion_pytorch_model-00003-of-00003.safetensors",
-            f"unet/diffusion_pytorch_model.{variant}.safetensors",
-        ]
-        model_filenames, variant_filenames = variant_compatible_siblings(
-            filenames, variant=variant, ignore_patterns=ignore_patterns
-        )
-        assert all(variant in f for f in model_filenames)
-
-    def test_mixed_single_variant_with_sharded_non_variant_downloaded(self):
-        ignore_patterns = ["*.bin"]
-        variant = "fp16"
-        allowed_non_variant = "unet"
-        filenames = [
-            "vae/diffusion_pytorch_model.safetensors.index.json",
-            "vae/diffusion_pytorch_model-00001-of-00003.safetensors",
-            "vae/diffusion_pytorch_model-00002-of-00003.safetensors",
-            "vae/diffusion_pytorch_model-00003-of-00003.safetensors",
-            f"vae/diffusion_pytorch_model.{variant}.safetensors",
-            "unet/diffusion_pytorch_model.safetensors.index.json",
-            "unet/diffusion_pytorch_model-00001-of-00003.safetensors",
-            "unet/diffusion_pytorch_model-00002-of-00003.safetensors",
-            "unet/diffusion_pytorch_model-00003-of-00003.safetensors",
-        ]
-        model_filenames, variant_filenames = variant_compatible_siblings(
-            filenames, variant=variant, ignore_patterns=ignore_patterns
-        )
-        assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)

     def test_sharded_mixed_variants_downloaded(self):
-        ignore_patterns = ["*.bin"]
         variant = "fp16"
         allowed_non_variant = "unet"
         filenames = [

@@ -441,144 +335,9 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
             "vae/diffusion_pytorch_model-00002-of-00003.safetensors",
             "vae/diffusion_pytorch_model-00003-of-00003.safetensors",
         ]
-        model_filenames, variant_filenames = variant_compatible_siblings(
-            filenames, variant=variant, ignore_patterns=ignore_patterns
-        )
+        model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
         assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)

-    def test_downloading_when_no_variant_exists(self):
-        ignore_patterns = ["*.bin"]
-        variant = "fp16"
-        filenames = ["model.safetensors", "diffusion_pytorch_model.safetensors"]
-
-        with self.assertRaisesRegex(ValueError, "but no such modeling files are available. "):
-            model_filenames, variant_filenames = variant_compatible_siblings(
-                filenames, variant=variant, ignore_patterns=ignore_patterns
-            )
-
-    def test_downloading_use_safetensors_false(self):
-        ignore_patterns = ["*.safetensors"]
-        filenames = [
-            "text_encoder/model.bin",
-            "unet/diffusion_pytorch_model.bin",
-            "unet/diffusion_pytorch_model.safetensors",
-        ]
-        model_filenames, variant_filenames = variant_compatible_siblings(
-            filenames, variant=None, ignore_patterns=ignore_patterns
-        )
-        assert all(".safetensors" not in f for f in model_filenames)
-
-    def test_non_variant_in_main_dir_with_variant_in_subfolder(self):
-        ignore_patterns = ["*.bin"]
-        variant = "fp16"
-        allowed_non_variant = "diffusion_pytorch_model.safetensors"
-        filenames = [
-            f"unet/diffusion_pytorch_model.{variant}.safetensors",
-            "diffusion_pytorch_model.safetensors",
-        ]
-        model_filenames, variant_filenames = variant_compatible_siblings(
-            filenames, variant=variant, ignore_patterns=ignore_patterns
-        )
-        assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
-
-    def test_download_variants_when_component_has_no_safetensors_variant(self):
-        ignore_patterns = None
-        variant = "fp16"
-        filenames = [
-            f"unet/diffusion_pytorch_model.{variant}.bin",
-            "vae/diffusion_pytorch_model.safetensors",
-            f"vae/diffusion_pytorch_model.{variant}.safetensors",
-        ]
-        model_filenames, variant_filenames = variant_compatible_siblings(
-            filenames, variant=variant, ignore_patterns=ignore_patterns
-        )
-        assert {
-            f"unet/diffusion_pytorch_model.{variant}.bin",
-            f"vae/diffusion_pytorch_model.{variant}.safetensors",
-        } == model_filenames
-
-    def test_error_when_download_sharded_variants_when_component_has_no_safetensors_variant(self):
-        ignore_patterns = ["*.bin"]
-        variant = "fp16"
-        filenames = [
-            f"vae/diffusion_pytorch_model.bin.index.{variant}.json",
-            "vae/diffusion_pytorch_model.safetensors.index.json",
-            f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.bin",
-            "vae/diffusion_pytorch_model-00001-of-00003.safetensors",
-            "vae/diffusion_pytorch_model-00002-of-00003.safetensors",
-            "vae/diffusion_pytorch_model-00003-of-00003.safetensors",
-            "unet/diffusion_pytorch_model.safetensors.index.json",
-            "unet/diffusion_pytorch_model-00001-of-00003.safetensors",
-            "unet/diffusion_pytorch_model-00002-of-00003.safetensors",
-            "unet/diffusion_pytorch_model-00003-of-00003.safetensors",
-            f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.bin",
-        ]
-
-        with self.assertRaisesRegex(ValueError, "but no such modeling files are available. "):
-            model_filenames, variant_filenames = variant_compatible_siblings(
-                filenames, variant=variant, ignore_patterns=ignore_patterns
-            )
-
-    def test_download_sharded_variants_when_component_has_no_safetensors_variant_and_safetensors_false(self):
-        ignore_patterns = ["*.safetensors"]
-        allowed_non_variant = "unet"
-        variant = "fp16"
-        filenames = [
-            f"vae/diffusion_pytorch_model.bin.index.{variant}.json",
-            "vae/diffusion_pytorch_model.safetensors.index.json",
-            f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.bin",
-            "vae/diffusion_pytorch_model-00001-of-00003.safetensors",
-            "vae/diffusion_pytorch_model-00002-of-00003.safetensors",
-            "vae/diffusion_pytorch_model-00003-of-00003.safetensors",
-            "unet/diffusion_pytorch_model.safetensors.index.json",
-            "unet/diffusion_pytorch_model-00001-of-00003.safetensors",
-            "unet/diffusion_pytorch_model-00002-of-00003.safetensors",
-            "unet/diffusion_pytorch_model-00003-of-00003.safetensors",
-            f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.bin",
-        ]
-        model_filenames, variant_filenames = variant_compatible_siblings(
-            filenames, variant=variant, ignore_patterns=ignore_patterns
-        )
-        assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
-
-    def test_download_sharded_legacy_variants(self):
-        ignore_patterns = None
-        variant = "fp16"
-        filenames = [
-            f"vae/transformer/diffusion_pytorch_model.safetensors.{variant}.index.json",
-            "vae/diffusion_pytorch_model.safetensors.index.json",
-            f"vae/diffusion_pytorch_model-00002-of-00002.{variant}.safetensors",
-            "vae/diffusion_pytorch_model-00001-of-00003.safetensors",
-            "vae/diffusion_pytorch_model-00002-of-00003.safetensors",
-            "vae/diffusion_pytorch_model-00003-of-00003.safetensors",
-            f"vae/diffusion_pytorch_model-00001-of-00002.{variant}.safetensors",
-        ]
-        model_filenames, variant_filenames = variant_compatible_siblings(
-            filenames, variant=variant, ignore_patterns=ignore_patterns
-        )
-        assert all(variant in f for f in model_filenames)
-
-    def test_download_onnx_models(self):
-        ignore_patterns = ["*.safetensors"]
-        filenames = [
-            "vae/model.onnx",
-            "unet/model.onnx",
-        ]
-        model_filenames, variant_filenames = variant_compatible_siblings(
-            filenames, variant=None, ignore_patterns=ignore_patterns
-        )
-        assert model_filenames == set(filenames)
-
-    def test_download_flax_models(self):
-        ignore_patterns = ["*.safetensors", "*.bin"]
-        filenames = [
-            "vae/diffusion_flax_model.msgpack",
-            "unet/diffusion_flax_model.msgpack",
-        ]
-        model_filenames, variant_filenames = variant_compatible_siblings(
-            filenames, variant=None, ignore_patterns=ignore_patterns
-        )
-        assert model_filenames == set(filenames)
-

 class ProgressBarTests(unittest.TestCase):
     def get_dummy_components_image_generation(self):
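
Editor's note: the removed tests all exercised the ignore_patterns argument (ONNX, Flax, use_safetensors=False, legacy sharded variants), which no longer exists in the two-argument signature. A short recap of what the surviving tests assert, runnable as a plain script assuming a diffusers checkout with this change applied and that variant_compatible_siblings is importable from diffusers.pipelines.pipeline_loading_utils:

    from diffusers.pipelines.pipeline_loading_utils import variant_compatible_siblings

    variant = "fp16"
    filenames = [
        f"vae/diffusion_pytorch_model.{variant}.safetensors",
        "vae/diffusion_pytorch_model.safetensors",
        f"unet/diffusion_pytorch_model.{variant}.safetensors",
        "unet/diffusion_pytorch_model.safetensors",
    ]

    # With a variant requested, every returned model file is the fp16 flavor.
    model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
    assert all(variant in f for f in model_filenames)

    # Without a variant, only the non-variant files are returned.
    model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
    assert all(variant not in f for f in model_filenames)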