Compare commits

...

6 Commits

Author SHA1 Message Date
Dhruv Nair
1ab5c07fcf Merge branch 'shared-variants-downloads' of https://github.com/huggingface/diffusers into shared-variants-downloads 2024-11-08 12:38:17 +01:00
Dhruv Nair
b0caf4169f update 2024-11-08 12:36:29 +01:00
Sayak Paul
1c73b445d1 Merge branch 'main' into shared-variants-downloads 2024-11-07 02:51:36 +01:00
Dhruv Nair
f3b76fb430 update 2024-11-06 19:47:15 +01:00
Dhruv Nair
ed0a6d70c5 update 2024-11-05 20:50:29 +01:00
Dhruv Nair
7b668b1de1 update 2024-11-05 18:54:03 +01:00
2 changed files with 155 additions and 5 deletions

View File

@@ -198,10 +198,31 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
return variant_filename
for f in non_variant_filenames:
variant_filename = convert_to_variant(f)
if variant_filename not in usable_filenames:
usable_filenames.add(f)
def find_component(filename):
if not len(filename.split("/")) == 2:
return
component = filename.split("/")[0]
return component
def has_sharded_variant(component, variant, variant_filenames):
# If component exists check for sharded variant index filename
# If component doesn't exist check main dir for sharded variant index filename
component = component + "/" if component else ""
variant_index_re = re.compile(
rf"{component}({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
)
return any(f for f in variant_filenames if variant_index_re.match(f) is not None)
for filename in non_variant_filenames:
if convert_to_variant(filename) in variant_filenames:
continue
component = find_component(filename)
# If a sharded variant exists skip adding to allowed patterns
if has_sharded_variant(component, variant, variant_filenames):
continue
usable_filenames.add(filename)
return usable_filenames, variant_filenames

View File

@@ -18,7 +18,7 @@ from diffusers import (
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible
from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible, variant_compatible_siblings
from diffusers.utils.testing_utils import torch_device
@@ -210,6 +210,135 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
self.assertFalse(is_safetensors_compatible(filenames))
class VariantCompatibleSiblingsTest(unittest.TestCase):
def test_only_non_variants_downloaded(self):
variant = "fp16"
filenames = [
f"vae/diffusion_pytorch_model.{variant}.safetensors",
"vae/diffusion_pytorch_model.safetensors",
f"text_encoder/model.{variant}.safetensors",
"text_encoder/model.safetensors",
f"unet/diffusion_pytorch_model.{variant}.safetensors",
"unet/diffusion_pytorch_model.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
assert all(variant not in f for f in model_filenames)
def test_only_variants_downloaded(self):
variant = "fp16"
filenames = [
f"vae/diffusion_pytorch_model.{variant}.safetensors",
"vae/diffusion_pytorch_model.safetensors",
f"text_encoder/model.{variant}.safetensors",
"text_encoder/model.safetensors",
f"unet/diffusion_pytorch_model.{variant}.safetensors",
"unet/diffusion_pytorch_model.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
assert all(variant in f for f in model_filenames)
def test_mixed_variants_downloaded(self):
variant = "fp16"
non_variant_file = "text_encoder/model.safetensors"
filenames = [
f"vae/diffusion_pytorch_model.{variant}.safetensors",
"vae/diffusion_pytorch_model.safetensors",
"text_encoder/model.safetensors",
f"unet/diffusion_pytorch_model.{variant}.safetensors",
"unet/diffusion_pytorch_model.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)
def test_non_variants_in_main_dir_downloaded(self):
variant = "fp16"
filenames = [
f"diffusion_pytorch_model.{variant}.safetensors",
"diffusion_pytorch_model.safetensors",
"model.safetensors",
f"model.{variant}.safetensors",
f"diffusion_pytorch_model.{variant}.safetensors",
"diffusion_pytorch_model.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
assert all(variant not in f for f in model_filenames)
def test_variants_in_main_dir_downloaded(self):
variant = "fp16"
filenames = [
f"diffusion_pytorch_model.{variant}.safetensors",
"diffusion_pytorch_model.safetensors",
"model.safetensors",
f"model.{variant}.safetensors",
f"diffusion_pytorch_model.{variant}.safetensors",
"diffusion_pytorch_model.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
assert all(variant in f for f in model_filenames)
def test_mixed_variants_in_main_dir_downloaded(self):
variant = "fp16"
non_variant_file = "model.safetensors"
filenames = [
f"diffusion_pytorch_model.{variant}.safetensors",
"diffusion_pytorch_model.safetensors",
"model.safetensors",
f"diffusion_pytorch_model.{variant}.safetensors",
"diffusion_pytorch_model.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)
def test_sharded_non_variants_downloaded(self):
variant = "fp16"
filenames = [
f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",
"unet/diffusion_pytorch_model.safetensors.index.json",
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
assert all(variant not in f for f in model_filenames)
def test_sharded_variants_downloaded(self):
variant = "fp16"
filenames = [
f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",
"unet/diffusion_pytorch_model.safetensors.index.json",
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
assert all(variant in f for f in model_filenames)
def test_sharded_mixed_variants_downloaded(self):
variant = "fp16"
allowed_non_variant = "unet"
filenames = [
f"vae/diffusion_pytorch_model.safetensors.index.{variant}.json",
"vae/diffusion_pytorch_model.safetensors.index.json",
"unet/diffusion_pytorch_model.safetensors.index.json",
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
"vae/diffusion_pytorch_model-00001-of-00003.safetensors",
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
]
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
class ProgressBarTests(unittest.TestCase):
def get_dummy_components_image_generation(self):
cross_attention_dim = 8