mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-23 04:44:46 +08:00
Compare commits
2 Commits
remove-unn
...
safetensor
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
679e0958a2 | ||
|
|
bfc66f8aa0 |
@@ -146,21 +146,27 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
|
||||
components[component].append(component_filename)
|
||||
|
||||
# If there are no component folders check the main directory for safetensors files
|
||||
filtered_filenames = set()
|
||||
if not components:
|
||||
if variant is not None:
|
||||
filtered_filenames = filter_with_regex(filenames, variant_file_re)
|
||||
else:
|
||||
|
||||
# If no variant filenames exist check if non-variant files are available
|
||||
if not filtered_filenames:
|
||||
filtered_filenames = filter_with_regex(filenames, non_variant_file_re)
|
||||
return any(".safetensors" in filename for filename in filtered_filenames)
|
||||
|
||||
# iterate over all files of a component
|
||||
# check if safetensor files exist for that component
|
||||
# if variant is provided check if the variant of the safetensors exists
|
||||
for component, component_filenames in components.items():
|
||||
matches = []
|
||||
filtered_component_filenames = set()
|
||||
# if variant is provided check if the variant of the safetensors exists
|
||||
if variant is not None:
|
||||
filtered_component_filenames = filter_with_regex(component_filenames, variant_file_re)
|
||||
else:
|
||||
|
||||
# if variant safetensor files do not exist check for non-variants
|
||||
if not filtered_component_filenames:
|
||||
filtered_component_filenames = filter_with_regex(component_filenames, non_variant_file_re)
|
||||
for component_filename in filtered_component_filenames:
|
||||
filename, extension = os.path.splitext(component_filename)
|
||||
|
||||
@@ -217,6 +217,20 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
|
||||
]
|
||||
self.assertFalse(is_safetensors_compatible(filenames))
|
||||
|
||||
def test_is_compatible_mixed_variants(self):
|
||||
filenames = [
|
||||
"unet/diffusion_pytorch_model.fp16.safetensors",
|
||||
"vae/diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
|
||||
|
||||
def test_is_compatible_variant_and_non_safetensors(self):
|
||||
filenames = [
|
||||
"unet/diffusion_pytorch_model.fp16.safetensors",
|
||||
"vae/diffusion_pytorch_model.bin",
|
||||
]
|
||||
self.assertFalse(is_safetensors_compatible(filenames, variant="fp16"))
|
||||
|
||||
|
||||
class VariantCompatibleSiblingsTest(unittest.TestCase):
|
||||
def test_only_non_variants_downloaded(self):
|
||||
|
||||
@@ -538,38 +538,26 @@ class DownloadTests(unittest.TestCase):
|
||||
variant = "no_ema"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
if use_safetensors:
|
||||
with self.assertRaises(OSError) as error_context:
|
||||
tmpdirname = StableDiffusionPipeline.download(
|
||||
"hf-internal-testing/stable-diffusion-all-variants",
|
||||
cache_dir=tmpdirname,
|
||||
variant=variant,
|
||||
use_safetensors=use_safetensors,
|
||||
)
|
||||
assert "Could not find the necessary `safetensors` weights" in str(error_context.exception)
|
||||
else:
|
||||
tmpdirname = StableDiffusionPipeline.download(
|
||||
"hf-internal-testing/stable-diffusion-all-variants",
|
||||
cache_dir=tmpdirname,
|
||||
variant=variant,
|
||||
use_safetensors=use_safetensors,
|
||||
)
|
||||
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
|
||||
files = [item for sublist in all_root_files for item in sublist]
|
||||
tmpdirname = StableDiffusionPipeline.download(
|
||||
"hf-internal-testing/stable-diffusion-all-variants",
|
||||
cache_dir=tmpdirname,
|
||||
variant=variant,
|
||||
use_safetensors=use_safetensors,
|
||||
)
|
||||
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
|
||||
files = [item for sublist in all_root_files for item in sublist]
|
||||
|
||||
unet_files = os.listdir(os.path.join(tmpdirname, "unet"))
|
||||
unet_files = os.listdir(os.path.join(tmpdirname, "unet"))
|
||||
|
||||
# Some of the downloaded files should be a non-variant file, check:
|
||||
# https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
|
||||
assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
|
||||
# only unet has "no_ema" variant
|
||||
assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files
|
||||
assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1
|
||||
# vae, safety_checker and text_encoder should have no variant
|
||||
assert (
|
||||
sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
|
||||
)
|
||||
assert not any(f.endswith(other_format) for f in files)
|
||||
# Some of the downloaded files should be a non-variant file, check:
|
||||
# https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
|
||||
assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
|
||||
# only unet has "no_ema" variant
|
||||
assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files
|
||||
assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1
|
||||
# vae, safety_checker and text_encoder should have no variant
|
||||
assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
|
||||
assert not any(f.endswith(other_format) for f in files)
|
||||
|
||||
def test_download_variants_with_sharded_checkpoints(self):
|
||||
# Here we test for downloading of "variant" files belonging to the `unet` and
|
||||
|
||||
Reference in New Issue
Block a user