mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-08 13:34:27 +08:00
Compare commits
4 Commits
variants-f
...
improve-se
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
936e6192b1 | ||
|
|
6a478569f0 | ||
|
|
719d0ce7a7 | ||
|
|
be16b1bcdf |
28
.github/workflows/pr_style_bot.yml
vendored
28
.github/workflows/pr_style_bot.yml
vendored
@@ -64,18 +64,38 @@ jobs:
|
||||
run: |
|
||||
pip install .[quality]
|
||||
|
||||
- name: Download Makefile from main branch
|
||||
- name: Download necessary files from main branch of Diffusers
|
||||
run: |
|
||||
curl -o main_Makefile https://raw.githubusercontent.com/huggingface/diffusers/main/Makefile
|
||||
curl -o main_setup.py https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/setup.py
|
||||
curl -o main_check_doc_toc.py https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/utils/check_doc_toc.py
|
||||
|
||||
- name: Compare Makefiles
|
||||
- name: Compare the files and raise error if needed
|
||||
run: |
|
||||
diff_failed=0
|
||||
|
||||
if ! diff -q main_Makefile Makefile; then
|
||||
echo "Error: The Makefile has changed. Please ensure it matches the main branch."
|
||||
diff_failed=1
|
||||
fi
|
||||
|
||||
if ! diff -q main_setup.py setup.py; then
|
||||
echo "Error: The setup.py has changed. Please ensure it matches the main branch."
|
||||
diff_failed=1
|
||||
fi
|
||||
|
||||
if ! diff -q main_check_doc_toc.py utils/check_doc_toc.py; then
|
||||
echo "Error: The utils/check_doc_toc.py has changed. Please ensure it matches the main branch."
|
||||
diff_failed=1
|
||||
fi
|
||||
|
||||
if [ $diff_failed -eq 1 ]; then
|
||||
echo "❌ Error happened as we detected changes in the files that should not be changed ❌"
|
||||
exit 1
|
||||
fi
|
||||
echo "No changes in Makefile. Proceeding..."
|
||||
rm -rf main_Makefile
|
||||
|
||||
echo "No changes in the files. Proceeding..."
|
||||
rm -rf main_Makefile main_setup.py main_check_doc_toc.py
|
||||
|
||||
- name: Run make style and make quality
|
||||
run: |
|
||||
|
||||
19
.github/workflows/pr_tests_gpu.yml
vendored
19
.github/workflows/pr_tests_gpu.yml
vendored
@@ -11,8 +11,6 @@ on:
|
||||
- "src/diffusers/loaders/lora_base.py"
|
||||
- "src/diffusers/loaders/lora_pipeline.py"
|
||||
- "src/diffusers/loaders/peft.py"
|
||||
- "tests/pipelines/test_pipelines_common.py"
|
||||
- "tests/models/test_modeling_common.py"
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
@@ -106,18 +104,11 @@ jobs:
|
||||
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
if [ "${{ matrix.module }}" = "ip_adapters" ]; then
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
|
||||
tests/pipelines/${{ matrix.module }}
|
||||
else
|
||||
pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx and $pattern" \
|
||||
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
|
||||
tests/pipelines/${{ matrix.module }}
|
||||
fi
|
||||
pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx and $pattern" \
|
||||
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
|
||||
tests/pipelines/${{ matrix.module }}
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
|
||||
@@ -104,7 +104,7 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
|
||||
extension is replaced with ".safetensors"
|
||||
"""
|
||||
passed_components = passed_components or []
|
||||
if folder_names:
|
||||
if folder_names is not None:
|
||||
filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}
|
||||
|
||||
# extract all components of the pipeline and their associated files
|
||||
@@ -141,25 +141,7 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
|
||||
return True
|
||||
|
||||
|
||||
def filter_model_files(filenames):
|
||||
"""Filter model repo files for just files/folders that contain model weights"""
|
||||
weight_names = [
|
||||
WEIGHTS_NAME,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
ONNX_WEIGHTS_NAME,
|
||||
ONNX_EXTERNAL_WEIGHTS_NAME,
|
||||
]
|
||||
|
||||
if is_transformers_available():
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
|
||||
|
||||
allowed_extensions = [wn.split(".")[-1] for wn in weight_names]
|
||||
|
||||
return [f for f in filenames if any(f.endswith(extension) for extension in allowed_extensions)]
|
||||
|
||||
|
||||
def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]:
|
||||
def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
|
||||
weight_names = [
|
||||
WEIGHTS_NAME,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
@@ -187,10 +169,6 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -
|
||||
variant_index_re = re.compile(
|
||||
rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
|
||||
)
|
||||
legacy_variant_file_re = re.compile(rf".*-{transformers_index_format}\.{variant}\.[a-z]+$")
|
||||
legacy_variant_index_re = re.compile(
|
||||
rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.{variant}\.index\.json$"
|
||||
)
|
||||
|
||||
# `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
|
||||
non_variant_file_re = re.compile(
|
||||
@@ -199,68 +177,54 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -
|
||||
# `text_encoder/pytorch_model.bin.index.json`
|
||||
non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")
|
||||
|
||||
def filter_for_compatible_extensions(filenames, ignore_patterns=None):
|
||||
if not ignore_patterns:
|
||||
return filenames
|
||||
if variant is not None:
|
||||
variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
|
||||
variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None}
|
||||
variant_filenames = variant_weights | variant_indexes
|
||||
else:
|
||||
variant_filenames = set()
|
||||
|
||||
# ignore patterns uses glob style patterns e.g *.safetensors but we're only
|
||||
# interested in the extension name
|
||||
return {f for f in filenames if not any(f.endswith(pat.lstrip("*.")) for pat in ignore_patterns)}
|
||||
non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None}
|
||||
non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None}
|
||||
non_variant_filenames = non_variant_weights | non_variant_indexes
|
||||
|
||||
def filter_with_regex(filenames, pattern_re):
|
||||
return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None}
|
||||
# all variant filenames will be used by default
|
||||
usable_filenames = set(variant_filenames)
|
||||
|
||||
# Group files by component
|
||||
components = {}
|
||||
for filename in filenames:
|
||||
def convert_to_variant(filename):
|
||||
if "index" in filename:
|
||||
variant_filename = filename.replace("index", f"index.{variant}")
|
||||
elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
|
||||
variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
|
||||
else:
|
||||
variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
|
||||
return variant_filename
|
||||
|
||||
def find_component(filename):
|
||||
if not len(filename.split("/")) == 2:
|
||||
components.setdefault("", []).append(filename)
|
||||
return
|
||||
component = filename.split("/")[0]
|
||||
return component
|
||||
|
||||
def has_sharded_variant(component, variant, variant_filenames):
|
||||
# If component exists check for sharded variant index filename
|
||||
# If component doesn't exist check main dir for sharded variant index filename
|
||||
component = component + "/" if component else ""
|
||||
variant_index_re = re.compile(
|
||||
rf"{component}({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
|
||||
)
|
||||
return any(f for f in variant_filenames if variant_index_re.match(f) is not None)
|
||||
|
||||
for filename in non_variant_filenames:
|
||||
if convert_to_variant(filename) in variant_filenames:
|
||||
continue
|
||||
|
||||
component, _ = filename.split("/")
|
||||
components.setdefault(component, []).append(filename)
|
||||
component = find_component(filename)
|
||||
# If a sharded variant exists skip adding to allowed patterns
|
||||
if has_sharded_variant(component, variant, variant_filenames):
|
||||
continue
|
||||
|
||||
usable_filenames = set()
|
||||
variant_filenames = set()
|
||||
for component, component_filenames in components.items():
|
||||
component_filenames = filter_for_compatible_extensions(component_filenames, ignore_patterns=ignore_patterns)
|
||||
|
||||
component_variants = set()
|
||||
component_legacy_variants = set()
|
||||
component_non_variants = set()
|
||||
if variant is not None:
|
||||
component_variants = filter_with_regex(component_filenames, variant_file_re)
|
||||
component_variant_index_files = filter_with_regex(component_filenames, variant_index_re)
|
||||
|
||||
component_legacy_variants = filter_with_regex(component_filenames, legacy_variant_file_re)
|
||||
component_legacy_variant_index_files = filter_with_regex(component_filenames, legacy_variant_index_re)
|
||||
|
||||
if component_variants or component_legacy_variants:
|
||||
variant_filenames.update(
|
||||
component_variants | component_variant_index_files
|
||||
if component_variants
|
||||
else component_legacy_variants | component_legacy_variant_index_files
|
||||
)
|
||||
|
||||
else:
|
||||
component_non_variants = filter_with_regex(component_filenames, non_variant_file_re)
|
||||
component_variant_index_files = filter_with_regex(component_filenames, non_variant_index_re)
|
||||
|
||||
usable_filenames.update(component_non_variants | component_variant_index_files)
|
||||
|
||||
usable_filenames.update(variant_filenames)
|
||||
|
||||
if len(variant_filenames) == 0 and variant is not None:
|
||||
error_message = f"You are trying to load model files of the `variant={variant}`, but no such modeling files are available. "
|
||||
raise ValueError(error_message)
|
||||
|
||||
if len(variant_filenames) > 0 and usable_filenames != variant_filenames:
|
||||
logger.warning(
|
||||
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
|
||||
f"[{', '.join(variant_filenames)}]\nLoaded non-{variant} filenames:\n"
|
||||
f"[{', '.join(usable_filenames - variant_filenames)}\nIf this behavior is not "
|
||||
f"expected, please check your folder structure."
|
||||
)
|
||||
usable_filenames.add(filename)
|
||||
|
||||
return usable_filenames, variant_filenames
|
||||
|
||||
@@ -958,6 +922,10 @@ def _get_custom_components_and_folders(
|
||||
f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
|
||||
)
|
||||
|
||||
if len(variant_filenames) == 0 and variant is not None:
|
||||
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
|
||||
raise ValueError(error_message)
|
||||
|
||||
return custom_components, folder_names
|
||||
|
||||
|
||||
@@ -965,6 +933,7 @@ def _get_ignore_patterns(
|
||||
passed_components,
|
||||
model_folder_names: List[str],
|
||||
model_filenames: List[str],
|
||||
variant_filenames: List[str],
|
||||
use_safetensors: bool,
|
||||
from_flax: bool,
|
||||
allow_pickle: bool,
|
||||
@@ -995,6 +964,16 @@ def _get_ignore_patterns(
|
||||
if not use_onnx:
|
||||
ignore_patterns += ["*.onnx", "*.pb"]
|
||||
|
||||
safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
|
||||
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
|
||||
if len(safetensors_variant_filenames) > 0 and safetensors_model_filenames != safetensors_variant_filenames:
|
||||
logger.warning(
|
||||
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
|
||||
f"[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
|
||||
f"[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not "
|
||||
f"expected, please check your folder structure."
|
||||
)
|
||||
|
||||
else:
|
||||
ignore_patterns = ["*.safetensors", "*.msgpack"]
|
||||
|
||||
@@ -1002,6 +981,16 @@ def _get_ignore_patterns(
|
||||
if not use_onnx:
|
||||
ignore_patterns += ["*.onnx", "*.pb"]
|
||||
|
||||
bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
|
||||
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
|
||||
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
|
||||
logger.warning(
|
||||
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
|
||||
f"[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
|
||||
f"[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check "
|
||||
f"your folder structure."
|
||||
)
|
||||
|
||||
return ignore_patterns
|
||||
|
||||
|
||||
|
||||
@@ -89,7 +89,6 @@ from .pipeline_loading_utils import (
|
||||
_resolve_custom_pipeline_and_cls,
|
||||
_unwrap_model,
|
||||
_update_init_kwargs_with_connected_pipeline,
|
||||
filter_model_files,
|
||||
load_sub_model,
|
||||
maybe_raise_or_warn,
|
||||
variant_compatible_siblings,
|
||||
@@ -1388,8 +1387,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
allow_pickle = True if (use_safetensors is None or use_safetensors is False) else False
|
||||
use_safetensors = use_safetensors if use_safetensors is not None else True
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = True
|
||||
allow_pickle = True
|
||||
|
||||
allow_patterns = None
|
||||
ignore_patterns = None
|
||||
@@ -1404,18 +1405,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
model_info_call_error = e # save error to reraise it if model is not cached locally
|
||||
|
||||
if not local_files_only:
|
||||
config_file = hf_hub_download(
|
||||
pretrained_model_name,
|
||||
cls.config_name,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
proxies=proxies,
|
||||
force_download=force_download,
|
||||
token=token,
|
||||
)
|
||||
config_dict = cls._dict_from_json_file(config_file)
|
||||
ignore_filenames = config_dict.pop("_ignore_files", [])
|
||||
|
||||
filenames = {sibling.rfilename for sibling in info.siblings}
|
||||
if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant):
|
||||
warn_msg = (
|
||||
@@ -1430,20 +1419,61 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
)
|
||||
logger.warning(warn_msg)
|
||||
|
||||
filenames = set(filenames) - set(ignore_filenames)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
|
||||
config_file = hf_hub_download(
|
||||
pretrained_model_name,
|
||||
cls.config_name,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
proxies=proxies,
|
||||
force_download=force_download,
|
||||
token=token,
|
||||
)
|
||||
|
||||
config_dict = cls._dict_from_json_file(config_file)
|
||||
ignore_filenames = config_dict.pop("_ignore_files", [])
|
||||
|
||||
# remove ignored filenames
|
||||
model_filenames = set(model_filenames) - set(ignore_filenames)
|
||||
variant_filenames = set(variant_filenames) - set(ignore_filenames)
|
||||
|
||||
if revision in DEPRECATED_REVISION_ARGS and version.parse(
|
||||
version.parse(__version__).base_version
|
||||
) >= version.parse("0.22.0"):
|
||||
warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, filenames)
|
||||
warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames)
|
||||
|
||||
custom_components, folder_names = _get_custom_components_and_folders(
|
||||
pretrained_model_name, config_dict, filenames, variant
|
||||
pretrained_model_name, config_dict, filenames, variant_filenames, variant
|
||||
)
|
||||
model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
|
||||
|
||||
custom_class_name = None
|
||||
if custom_pipeline is None and isinstance(config_dict["_class_name"], (list, tuple)):
|
||||
custom_pipeline = config_dict["_class_name"][0]
|
||||
custom_class_name = config_dict["_class_name"][1]
|
||||
|
||||
# all filenames compatible with variant will be added
|
||||
allow_patterns = list(model_filenames)
|
||||
|
||||
# allow all patterns from non-model folders
|
||||
# this enables downloading schedulers, tokenizers, ...
|
||||
allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
|
||||
# add custom component files
|
||||
allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()]
|
||||
# add custom pipeline file
|
||||
allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
|
||||
# also allow downloading config.json files with the model
|
||||
allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
|
||||
# also allow downloading generation_config.json of the transformers model
|
||||
allow_patterns += [os.path.join(k, "generation_config.json") for k in model_folder_names]
|
||||
allow_patterns += [
|
||||
SCHEDULER_CONFIG_NAME,
|
||||
CONFIG_NAME,
|
||||
cls.config_name,
|
||||
CUSTOM_PIPELINE_FILE_NAME,
|
||||
]
|
||||
|
||||
load_pipe_from_hub = custom_pipeline is not None and f"{custom_pipeline}.py" in filenames
|
||||
load_components_from_hub = len(custom_components) > 0
|
||||
|
||||
@@ -1476,15 +1506,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
expected_components, _ = cls._get_signature_keys(pipeline_class)
|
||||
passed_components = [k for k in expected_components if k in kwargs]
|
||||
|
||||
# retrieve the names of the folders containing model weights
|
||||
model_folder_names = {
|
||||
os.path.split(f)[0] for f in filter_model_files(filenames) if os.path.split(f)[0] in folder_names
|
||||
}
|
||||
# retrieve all patterns that should not be downloaded and error out when needed
|
||||
ignore_patterns = _get_ignore_patterns(
|
||||
passed_components,
|
||||
model_folder_names,
|
||||
filenames,
|
||||
model_filenames,
|
||||
variant_filenames,
|
||||
use_safetensors,
|
||||
from_flax,
|
||||
allow_pickle,
|
||||
@@ -1493,29 +1520,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
variant,
|
||||
)
|
||||
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
|
||||
# all filenames compatible with variant will be added
|
||||
allow_patterns = list(model_filenames)
|
||||
|
||||
# allow all patterns from non-model folders
|
||||
# this enables downloading schedulers, tokenizers, ...
|
||||
allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
|
||||
# add custom component files
|
||||
allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()]
|
||||
# add custom pipeline file
|
||||
allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
|
||||
# also allow downloading config.json files with the model
|
||||
allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
|
||||
allow_patterns += [
|
||||
SCHEDULER_CONFIG_NAME,
|
||||
CONFIG_NAME,
|
||||
cls.config_name,
|
||||
CUSTOM_PIPELINE_FILE_NAME,
|
||||
]
|
||||
|
||||
# Don't download any objects that are passed
|
||||
allow_patterns = [
|
||||
p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components)
|
||||
|
||||
@@ -1169,16 +1169,17 @@ class ModelTesterMixin:
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
max_size = int(self.model_split_percents[0] * model_size)
|
||||
# Force disk offload by setting very small CPU memory
|
||||
max_memory = {0: max_size, "cpu": int(0.1 * max_size)}
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
max_size = int(self.model_split_percents[0] * model_size)
|
||||
max_memory = {0: max_size, "cpu": max_size}
|
||||
# This errors out because it's missing an offload folder
|
||||
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
|
||||
|
||||
max_size = int(self.model_split_percents[0] * model_size)
|
||||
max_memory = {0: max_size, "cpu": max_size}
|
||||
new_model = self.model_class.from_pretrained(
|
||||
tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
|
||||
)
|
||||
|
||||
@@ -30,7 +30,6 @@ class OmniGenTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = OmniGenTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
model_split_percents = [0.1, 0.1, 0.1]
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
@@ -74,9 +73,9 @@ class OmniGenTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
"num_attention_heads": 4,
|
||||
"num_key_value_heads": 4,
|
||||
"intermediate_size": 32,
|
||||
"num_layers": 20,
|
||||
"num_layers": 1,
|
||||
"pad_token_id": 0,
|
||||
"vocab_size": 1000,
|
||||
"vocab_size": 100,
|
||||
"in_channels": 4,
|
||||
"time_step_dim": 4,
|
||||
"rope_scaling": {"long_factor": list(range(1, 3)), "short_factor": list(range(1, 3))},
|
||||
|
||||
@@ -33,7 +33,6 @@ enable_full_determinism()
|
||||
class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = SD3Transformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
model_split_percents = [0.8, 0.8, 0.9]
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
@@ -68,7 +67,7 @@ class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
"sample_size": 32,
|
||||
"patch_size": 1,
|
||||
"in_channels": 4,
|
||||
"num_layers": 4,
|
||||
"num_layers": 1,
|
||||
"attention_head_dim": 8,
|
||||
"num_attention_heads": 4,
|
||||
"caption_projection_dim": 32,
|
||||
@@ -108,7 +107,6 @@ class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = SD3Transformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
model_split_percents = [0.8, 0.8, 0.9]
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
@@ -143,7 +141,7 @@ class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
"sample_size": 32,
|
||||
"patch_size": 1,
|
||||
"in_channels": 4,
|
||||
"num_layers": 4,
|
||||
"num_layers": 2,
|
||||
"attention_head_dim": 8,
|
||||
"num_attention_heads": 4,
|
||||
"caption_projection_dim": 32,
|
||||
|
||||
@@ -212,7 +212,6 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
|
||||
|
||||
class VariantCompatibleSiblingsTest(unittest.TestCase):
|
||||
def test_only_non_variants_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"vae/diffusion_pytorch_model.{variant}.safetensors",
|
||||
@@ -223,13 +222,10 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
|
||||
"unet/diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=None, ignore_patterns=ignore_patterns
|
||||
)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
|
||||
assert all(variant not in f for f in model_filenames)
|
||||
|
||||
def test_only_variants_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"vae/diffusion_pytorch_model.{variant}.safetensors",
|
||||
@@ -240,13 +236,10 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
|
||||
"unet/diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
assert all(variant in f for f in model_filenames)
|
||||
|
||||
def test_mixed_variants_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
non_variant_file = "text_encoder/model.safetensors"
|
||||
filenames = [
|
||||
@@ -256,27 +249,23 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
|
||||
f"unet/diffusion_pytorch_model.{variant}.safetensors",
|
||||
"unet/diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)
|
||||
|
||||
def test_non_variants_in_main_dir_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
"model.safetensors",
|
||||
f"model.{variant}.safetensors",
|
||||
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=None, ignore_patterns=ignore_patterns
|
||||
)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
|
||||
assert all(variant not in f for f in model_filenames)
|
||||
|
||||
def test_variants_in_main_dir_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||
@@ -286,76 +275,23 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
|
||||
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
assert all(variant in f for f in model_filenames)
|
||||
|
||||
def test_mixed_variants_in_main_dir_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
non_variant_file = "model.safetensors"
|
||||
filenames = [
|
||||
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
"model.safetensors",
|
||||
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)
|
||||
|
||||
def test_sharded_variants_in_main_dir_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
"diffusion_pytorch_model.safetensors.index.json",
|
||||
"diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
f"diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
|
||||
f"diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
|
||||
f"diffusion_pytorch_model.safetensors.index.{variant}.json",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant in f for f in model_filenames)
|
||||
|
||||
def test_mixed_sharded_and_variant_in_main_dir_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
"diffusion_pytorch_model.safetensors.index.json",
|
||||
"diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant in f for f in model_filenames)
|
||||
|
||||
def test_mixed_sharded_non_variants_in_main_dir_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"diffusion_pytorch_model.safetensors.index.{variant}.json",
|
||||
"diffusion_pytorch_model.safetensors.index.json",
|
||||
"diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
f"diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
|
||||
f"diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=None, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant not in f for f in model_filenames)
|
||||
|
||||
def test_sharded_non_variants_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",
|
||||
@@ -366,13 +302,10 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
|
||||
f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
|
||||
f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=None, ignore_patterns=ignore_patterns
|
||||
)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
|
||||
assert all(variant not in f for f in model_filenames)
|
||||
|
||||
def test_sharded_variants_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",
|
||||
@@ -383,49 +316,10 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
|
||||
f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
|
||||
f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
assert all(variant in f for f in model_filenames)
|
||||
assert model_filenames == variant_filenames
|
||||
|
||||
def test_single_variant_with_sharded_non_variant_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
"unet/diffusion_pytorch_model.safetensors.index.json",
|
||||
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
f"unet/diffusion_pytorch_model.{variant}.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant in f for f in model_filenames)
|
||||
|
||||
def test_mixed_single_variant_with_sharded_non_variant_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
allowed_non_variant = "unet"
|
||||
filenames = [
|
||||
"vae/diffusion_pytorch_model.safetensors.index.json",
|
||||
"vae/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
f"vae/diffusion_pytorch_model.{variant}.safetensors",
|
||||
"unet/diffusion_pytorch_model.safetensors.index.json",
|
||||
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
|
||||
|
||||
def test_sharded_mixed_variants_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
allowed_non_variant = "unet"
|
||||
filenames = [
|
||||
@@ -441,144 +335,9 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
|
||||
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
|
||||
|
||||
def test_downloading_when_no_variant_exists(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = ["model.safetensors", "diffusion_pytorch_model.safetensors"]
|
||||
with self.assertRaisesRegex(ValueError, "but no such modeling files are available. "):
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
|
||||
def test_downloading_use_safetensors_false(self):
|
||||
ignore_patterns = ["*.safetensors"]
|
||||
filenames = [
|
||||
"text_encoder/model.bin",
|
||||
"unet/diffusion_pytorch_model.bin",
|
||||
"unet/diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=None, ignore_patterns=ignore_patterns
|
||||
)
|
||||
|
||||
assert all(".safetensors" not in f for f in model_filenames)
|
||||
|
||||
def test_non_variant_in_main_dir_with_variant_in_subfolder(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
allowed_non_variant = "diffusion_pytorch_model.safetensors"
|
||||
filenames = [
|
||||
f"unet/diffusion_pytorch_model.{variant}.safetensors",
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
|
||||
|
||||
def test_download_variants_when_component_has_no_safetensors_variant(self):
|
||||
ignore_patterns = None
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"unet/diffusion_pytorch_model.{variant}.bin",
|
||||
"vae/diffusion_pytorch_model.safetensors",
|
||||
f"vae/diffusion_pytorch_model.{variant}.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert {
|
||||
f"unet/diffusion_pytorch_model.{variant}.bin",
|
||||
f"vae/diffusion_pytorch_model.{variant}.safetensors",
|
||||
} == model_filenames
|
||||
|
||||
def test_error_when_download_sharded_variants_when_component_has_no_safetensors_variant(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"vae/diffusion_pytorch_model.bin.index.{variant}.json",
|
||||
"vae/diffusion_pytorch_model.safetensors.index.json",
|
||||
f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.bin",
|
||||
"vae/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model.safetensors.index.json",
|
||||
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.bin",
|
||||
]
|
||||
with self.assertRaisesRegex(ValueError, "but no such modeling files are available. "):
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
|
||||
def test_download_sharded_variants_when_component_has_no_safetensors_variant_and_safetensors_false(self):
|
||||
ignore_patterns = ["*.safetensors"]
|
||||
allowed_non_variant = "unet"
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"vae/diffusion_pytorch_model.bin.index.{variant}.json",
|
||||
"vae/diffusion_pytorch_model.safetensors.index.json",
|
||||
f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.bin",
|
||||
"vae/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model.safetensors.index.json",
|
||||
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.bin",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
|
||||
|
||||
def test_download_sharded_legacy_variants(self):
|
||||
ignore_patterns = None
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"vae/transformer/diffusion_pytorch_model.safetensors.{variant}.index.json",
|
||||
"vae/diffusion_pytorch_model.safetensors.index.json",
|
||||
f"vae/diffusion_pytorch_model-00002-of-00002.{variant}.safetensors",
|
||||
"vae/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
f"vae/diffusion_pytorch_model-00001-of-00002.{variant}.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant in f for f in model_filenames)
|
||||
|
||||
def test_download_onnx_models(self):
|
||||
ignore_patterns = ["*.safetensors"]
|
||||
filenames = [
|
||||
"vae/model.onnx",
|
||||
"unet/model.onnx",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=None, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert model_filenames == set(filenames)
|
||||
|
||||
def test_download_flax_models(self):
|
||||
ignore_patterns = ["*.safetensors", "*.bin"]
|
||||
filenames = [
|
||||
"vae/diffusion_flax_model.msgpack",
|
||||
"unet/diffusion_flax_model.msgpack",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=None, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert model_filenames == set(filenames)
|
||||
|
||||
|
||||
class ProgressBarTests(unittest.TestCase):
|
||||
def get_dummy_components_image_generation(self):
|
||||
|
||||
Reference in New Issue
Block a user