mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 20:44:33 +08:00
Compare commits
6 Commits
pr-tests-f
...
shared-var
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1ab5c07fcf | ||
|
|
b0caf4169f | ||
|
|
1c73b445d1 | ||
|
|
f3b76fb430 | ||
|
|
ed0a6d70c5 | ||
|
|
7b668b1de1 |
@@ -198,10 +198,31 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
|
||||
variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
|
||||
return variant_filename
|
||||
|
||||
for f in non_variant_filenames:
|
||||
variant_filename = convert_to_variant(f)
|
||||
if variant_filename not in usable_filenames:
|
||||
usable_filenames.add(f)
|
||||
def find_component(filename):
|
||||
if not len(filename.split("/")) == 2:
|
||||
return
|
||||
component = filename.split("/")[0]
|
||||
return component
|
||||
|
||||
def has_sharded_variant(component, variant, variant_filenames):
|
||||
# If component exists check for sharded variant index filename
|
||||
# If component doesn't exist check main dir for sharded variant index filename
|
||||
component = component + "/" if component else ""
|
||||
variant_index_re = re.compile(
|
||||
rf"{component}({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
|
||||
)
|
||||
return any(f for f in variant_filenames if variant_index_re.match(f) is not None)
|
||||
|
||||
for filename in non_variant_filenames:
|
||||
if convert_to_variant(filename) in variant_filenames:
|
||||
continue
|
||||
|
||||
component = find_component(filename)
|
||||
# If a sharded variant exists skip adding to allowed patterns
|
||||
if has_sharded_variant(component, variant, variant_filenames):
|
||||
continue
|
||||
|
||||
usable_filenames.add(filename)
|
||||
|
||||
return usable_filenames, variant_filenames
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from diffusers import (
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible
|
||||
from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible, variant_compatible_siblings
|
||||
from diffusers.utils.testing_utils import torch_device
|
||||
|
||||
|
||||
@@ -210,6 +210,135 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
|
||||
self.assertFalse(is_safetensors_compatible(filenames))
|
||||
|
||||
|
||||
class VariantCompatibleSiblingsTest(unittest.TestCase):
|
||||
def test_only_non_variants_downloaded(self):
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"vae/diffusion_pytorch_model.{variant}.safetensors",
|
||||
"vae/diffusion_pytorch_model.safetensors",
|
||||
f"text_encoder/model.{variant}.safetensors",
|
||||
"text_encoder/model.safetensors",
|
||||
f"unet/diffusion_pytorch_model.{variant}.safetensors",
|
||||
"unet/diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
|
||||
assert all(variant not in f for f in model_filenames)
|
||||
|
||||
def test_only_variants_downloaded(self):
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"vae/diffusion_pytorch_model.{variant}.safetensors",
|
||||
"vae/diffusion_pytorch_model.safetensors",
|
||||
f"text_encoder/model.{variant}.safetensors",
|
||||
"text_encoder/model.safetensors",
|
||||
f"unet/diffusion_pytorch_model.{variant}.safetensors",
|
||||
"unet/diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
assert all(variant in f for f in model_filenames)
|
||||
|
||||
def test_mixed_variants_downloaded(self):
|
||||
variant = "fp16"
|
||||
non_variant_file = "text_encoder/model.safetensors"
|
||||
filenames = [
|
||||
f"vae/diffusion_pytorch_model.{variant}.safetensors",
|
||||
"vae/diffusion_pytorch_model.safetensors",
|
||||
"text_encoder/model.safetensors",
|
||||
f"unet/diffusion_pytorch_model.{variant}.safetensors",
|
||||
"unet/diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)
|
||||
|
||||
def test_non_variants_in_main_dir_downloaded(self):
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
"model.safetensors",
|
||||
f"model.{variant}.safetensors",
|
||||
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
|
||||
assert all(variant not in f for f in model_filenames)
|
||||
|
||||
def test_variants_in_main_dir_downloaded(self):
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
"model.safetensors",
|
||||
f"model.{variant}.safetensors",
|
||||
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
assert all(variant in f for f in model_filenames)
|
||||
|
||||
def test_mixed_variants_in_main_dir_downloaded(self):
|
||||
variant = "fp16"
|
||||
non_variant_file = "model.safetensors"
|
||||
filenames = [
|
||||
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
"model.safetensors",
|
||||
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)
|
||||
|
||||
def test_sharded_non_variants_downloaded(self):
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",
|
||||
"unet/diffusion_pytorch_model.safetensors.index.json",
|
||||
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
|
||||
f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
|
||||
assert all(variant not in f for f in model_filenames)
|
||||
|
||||
def test_sharded_variants_downloaded(self):
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",
|
||||
"unet/diffusion_pytorch_model.safetensors.index.json",
|
||||
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
|
||||
f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
assert all(variant in f for f in model_filenames)
|
||||
|
||||
def test_sharded_mixed_variants_downloaded(self):
|
||||
variant = "fp16"
|
||||
allowed_non_variant = "unet"
|
||||
filenames = [
|
||||
f"vae/diffusion_pytorch_model.safetensors.index.{variant}.json",
|
||||
"vae/diffusion_pytorch_model.safetensors.index.json",
|
||||
"unet/diffusion_pytorch_model.safetensors.index.json",
|
||||
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
|
||||
f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
|
||||
"vae/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
|
||||
|
||||
|
||||
class ProgressBarTests(unittest.TestCase):
|
||||
def get_dummy_components_image_generation(self):
|
||||
cross_attention_dim = 8
|
||||
|
||||
Reference in New Issue
Block a user