mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-13 16:04:41 +08:00
Compare commits
6 Commits
quantizer-
...
shared-var
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1ab5c07fcf | ||
|
|
b0caf4169f | ||
|
|
1c73b445d1 | ||
|
|
f3b76fb430 | ||
|
|
ed0a6d70c5 | ||
|
|
7b668b1de1 |
@@ -198,10 +198,31 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
|
|||||||
variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
|
variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
|
||||||
return variant_filename
|
return variant_filename
|
||||||
|
|
||||||
for f in non_variant_filenames:
|
def find_component(filename):
|
||||||
variant_filename = convert_to_variant(f)
|
if not len(filename.split("/")) == 2:
|
||||||
if variant_filename not in usable_filenames:
|
return
|
||||||
usable_filenames.add(f)
|
component = filename.split("/")[0]
|
||||||
|
return component
|
||||||
|
|
||||||
|
def has_sharded_variant(component, variant, variant_filenames):
|
||||||
|
# If component exists check for sharded variant index filename
|
||||||
|
# If component doesn't exist check main dir for sharded variant index filename
|
||||||
|
component = component + "/" if component else ""
|
||||||
|
variant_index_re = re.compile(
|
||||||
|
rf"{component}({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
|
||||||
|
)
|
||||||
|
return any(f for f in variant_filenames if variant_index_re.match(f) is not None)
|
||||||
|
|
||||||
|
for filename in non_variant_filenames:
|
||||||
|
if convert_to_variant(filename) in variant_filenames:
|
||||||
|
continue
|
||||||
|
|
||||||
|
component = find_component(filename)
|
||||||
|
# If a sharded variant exists skip adding to allowed patterns
|
||||||
|
if has_sharded_variant(component, variant, variant_filenames):
|
||||||
|
continue
|
||||||
|
|
||||||
|
usable_filenames.add(filename)
|
||||||
|
|
||||||
return usable_filenames, variant_filenames
|
return usable_filenames, variant_filenames
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from diffusers import (
|
|||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
UNet2DConditionModel,
|
UNet2DConditionModel,
|
||||||
)
|
)
|
||||||
from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible
|
from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible, variant_compatible_siblings
|
||||||
from diffusers.utils.testing_utils import torch_device
|
from diffusers.utils.testing_utils import torch_device
|
||||||
|
|
||||||
|
|
||||||
@@ -210,6 +210,135 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
|
|||||||
self.assertFalse(is_safetensors_compatible(filenames))
|
self.assertFalse(is_safetensors_compatible(filenames))
|
||||||
|
|
||||||
|
|
||||||
|
class VariantCompatibleSiblingsTest(unittest.TestCase):
|
||||||
|
def test_only_non_variants_downloaded(self):
|
||||||
|
variant = "fp16"
|
||||||
|
filenames = [
|
||||||
|
f"vae/diffusion_pytorch_model.{variant}.safetensors",
|
||||||
|
"vae/diffusion_pytorch_model.safetensors",
|
||||||
|
f"text_encoder/model.{variant}.safetensors",
|
||||||
|
"text_encoder/model.safetensors",
|
||||||
|
f"unet/diffusion_pytorch_model.{variant}.safetensors",
|
||||||
|
"unet/diffusion_pytorch_model.safetensors",
|
||||||
|
]
|
||||||
|
|
||||||
|
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
|
||||||
|
assert all(variant not in f for f in model_filenames)
|
||||||
|
|
||||||
|
def test_only_variants_downloaded(self):
|
||||||
|
variant = "fp16"
|
||||||
|
filenames = [
|
||||||
|
f"vae/diffusion_pytorch_model.{variant}.safetensors",
|
||||||
|
"vae/diffusion_pytorch_model.safetensors",
|
||||||
|
f"text_encoder/model.{variant}.safetensors",
|
||||||
|
"text_encoder/model.safetensors",
|
||||||
|
f"unet/diffusion_pytorch_model.{variant}.safetensors",
|
||||||
|
"unet/diffusion_pytorch_model.safetensors",
|
||||||
|
]
|
||||||
|
|
||||||
|
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||||
|
assert all(variant in f for f in model_filenames)
|
||||||
|
|
||||||
|
def test_mixed_variants_downloaded(self):
|
||||||
|
variant = "fp16"
|
||||||
|
non_variant_file = "text_encoder/model.safetensors"
|
||||||
|
filenames = [
|
||||||
|
f"vae/diffusion_pytorch_model.{variant}.safetensors",
|
||||||
|
"vae/diffusion_pytorch_model.safetensors",
|
||||||
|
"text_encoder/model.safetensors",
|
||||||
|
f"unet/diffusion_pytorch_model.{variant}.safetensors",
|
||||||
|
"unet/diffusion_pytorch_model.safetensors",
|
||||||
|
]
|
||||||
|
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||||
|
assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)
|
||||||
|
|
||||||
|
def test_non_variants_in_main_dir_downloaded(self):
|
||||||
|
variant = "fp16"
|
||||||
|
filenames = [
|
||||||
|
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||||
|
"diffusion_pytorch_model.safetensors",
|
||||||
|
"model.safetensors",
|
||||||
|
f"model.{variant}.safetensors",
|
||||||
|
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||||
|
"diffusion_pytorch_model.safetensors",
|
||||||
|
]
|
||||||
|
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
|
||||||
|
assert all(variant not in f for f in model_filenames)
|
||||||
|
|
||||||
|
def test_variants_in_main_dir_downloaded(self):
|
||||||
|
variant = "fp16"
|
||||||
|
filenames = [
|
||||||
|
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||||
|
"diffusion_pytorch_model.safetensors",
|
||||||
|
"model.safetensors",
|
||||||
|
f"model.{variant}.safetensors",
|
||||||
|
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||||
|
"diffusion_pytorch_model.safetensors",
|
||||||
|
]
|
||||||
|
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||||
|
assert all(variant in f for f in model_filenames)
|
||||||
|
|
||||||
|
def test_mixed_variants_in_main_dir_downloaded(self):
|
||||||
|
variant = "fp16"
|
||||||
|
non_variant_file = "model.safetensors"
|
||||||
|
filenames = [
|
||||||
|
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||||
|
"diffusion_pytorch_model.safetensors",
|
||||||
|
"model.safetensors",
|
||||||
|
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||||
|
"diffusion_pytorch_model.safetensors",
|
||||||
|
]
|
||||||
|
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||||
|
assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)
|
||||||
|
|
||||||
|
def test_sharded_non_variants_downloaded(self):
|
||||||
|
variant = "fp16"
|
||||||
|
filenames = [
|
||||||
|
f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",
|
||||||
|
"unet/diffusion_pytorch_model.safetensors.index.json",
|
||||||
|
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||||
|
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||||
|
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||||
|
f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
|
||||||
|
f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
|
||||||
|
]
|
||||||
|
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
|
||||||
|
assert all(variant not in f for f in model_filenames)
|
||||||
|
|
||||||
|
def test_sharded_variants_downloaded(self):
|
||||||
|
variant = "fp16"
|
||||||
|
filenames = [
|
||||||
|
f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",
|
||||||
|
"unet/diffusion_pytorch_model.safetensors.index.json",
|
||||||
|
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||||
|
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||||
|
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||||
|
f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
|
||||||
|
f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
|
||||||
|
]
|
||||||
|
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||||
|
assert all(variant in f for f in model_filenames)
|
||||||
|
|
||||||
|
def test_sharded_mixed_variants_downloaded(self):
|
||||||
|
variant = "fp16"
|
||||||
|
allowed_non_variant = "unet"
|
||||||
|
filenames = [
|
||||||
|
f"vae/diffusion_pytorch_model.safetensors.index.{variant}.json",
|
||||||
|
"vae/diffusion_pytorch_model.safetensors.index.json",
|
||||||
|
"unet/diffusion_pytorch_model.safetensors.index.json",
|
||||||
|
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||||
|
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||||
|
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||||
|
f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
|
||||||
|
f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
|
||||||
|
"vae/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||||
|
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||||
|
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||||
|
]
|
||||||
|
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||||
|
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
|
||||||
|
|
||||||
|
|
||||||
class ProgressBarTests(unittest.TestCase):
|
class ProgressBarTests(unittest.TestCase):
|
||||||
def get_dummy_components_image_generation(self):
|
def get_dummy_components_image_generation(self):
|
||||||
cross_attention_dim = 8
|
cross_attention_dim = 8
|
||||||
|
|||||||
Reference in New Issue
Block a user