Compare commits

...

8 Commits

Author      SHA1        Message                                                                                                                   Date
Dhruv Nair  ca827244c1  Merge branch 'safetensors-variant-loading' of https://github.com/huggingface/diffusers into safetensors-variant-loading  2024-08-07 05:05:40 +00:00
Dhruv Nair  7d6f0444d6  Merge https://github.com/huggingface/diffusers into safetensors-variant-loading                                          2024-08-07 05:04:15 +00:00
Dhruv Nair  8acb5e90b4  update                                                                                                                    2024-08-07 05:03:56 +00:00
Dhruv Nair  7229d46041  Merge branch 'main' into safetensors-variant-loading                                                                     2024-07-30 18:24:15 +05:30
Dhruv Nair  9675137bd2  update                                                                                                                    2024-07-29 12:53:50 +00:00
Dhruv Nair  3a41ab315a  update                                                                                                                    2024-07-26 16:12:19 +00:00
Dhruv Nair  922689bc0a  update                                                                                                                    2024-07-26 14:31:56 +00:00
Dhruv Nair  fc8d362e50  update                                                                                                                    2024-07-26 08:30:10 +00:00
4 changed files with 152 additions and 83 deletions

View File

@@ -89,49 +89,44 @@ for library in LOADABLE_CLASSES:
     ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])

-def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
+def is_safetensors_compatible(filenames, passed_components=None) -> bool:
     """
     Checking for safetensors compatibility:
-    - By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
-      files to know which safetensors files are needed.
-    - The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file.
+    - The model is safetensors compatible only if there is a safetensors file for each model component present in
+      filenames.

     Converting default pytorch serialized filenames to safetensors serialized filenames:
     - For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
     - For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
       extension is replaced with ".safetensors"
     """
-    pt_filenames = []
-
-    sf_filenames = set()
-
     passed_components = passed_components or []

+    # extract all components of the pipeline and their associated files
+    components = {}
     for filename in filenames:
-        _, extension = os.path.splitext(filename)
-
-        if len(filename.split("/")) == 2 and filename.split("/")[0] in passed_components:
+        if not len(filename.split("/")) == 2:
             continue

-        if extension == ".bin":
-            pt_filenames.append(os.path.normpath(filename))
-        elif extension == ".safetensors":
-            sf_filenames.add(os.path.normpath(filename))
+        component, component_filename = filename.split("/")
+        if component in passed_components:
+            continue

-    for filename in pt_filenames:
-        # filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extension = '.bam'
-        path, filename = os.path.split(filename)
-        filename, extension = os.path.splitext(filename)
+        components.setdefault(component, [])
+        components[component].append(component_filename)

-        if filename.startswith("pytorch_model"):
-            filename = filename.replace("pytorch_model", "model")
-        else:
-            filename = filename
+    # iterate over all files of a component
+    # check if safetensor files exist for that component
+    # if variant is provided check if the variant of the safetensors exists
+    for component, component_filenames in components.items():
+        matches = []
+        for component_filename in component_filenames:
+            filename, extension = os.path.splitext(component_filename)

-        expected_sf_filename = os.path.normpath(os.path.join(path, filename))
-        expected_sf_filename = f"{expected_sf_filename}.safetensors"
-        if expected_sf_filename not in sf_filenames:
-            logger.warning(f"{expected_sf_filename} not found")
+            match_exists = extension == ".safetensors"
+            matches.append(match_exists)
+
+        if not any(matches):
             return False

     return True
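As a reading aid (not part of the diff): a minimal, self-contained sketch of the per-component check introduced above. The helper name and example filenames are hypothetical, and the logic is restated locally rather than imported from diffusers:

import os

def has_safetensors_per_component(filenames, passed_components=None):
    # Group files by their top-level component folder, skipping components the
    # caller already supplied, then require at least one ".safetensors" file per
    # remaining component (mirrors the rewritten is_safetensors_compatible).
    passed_components = passed_components or []
    components = {}
    for filename in filenames:
        parts = filename.split("/")
        if len(parts) != 2 or parts[0] in passed_components:
            continue
        components.setdefault(parts[0], []).append(parts[1])

    for component, component_filenames in components.items():
        if not any(os.path.splitext(f)[1] == ".safetensors" for f in component_filenames):
            return False
    return True

# Hypothetical repo listing: an fp16-only unet and a sharded text encoder both pass.
print(has_safetensors_per_component([
    "unet/diffusion_pytorch_model.fp16.safetensors",
    "text_encoder/pytorch_model.bin",
    "text_encoder/model-00001-of-00002.safetensors",
]))  # True
print(has_safetensors_per_component([
    "vae/diffusion_pytorch_model.bin",
]))  # False: vae has no safetensors file at all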

View File

@@ -1416,18 +1416,14 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             if (
                 use_safetensors
                 and not allow_pickle
-                and not is_safetensors_compatible(
-                    model_filenames, variant=variant, passed_components=passed_components
-                )
+                and not is_safetensors_compatible(model_filenames, passed_components=passed_components)
             ):
                 raise EnvironmentError(
                     f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
                 )
             if from_flax:
                 ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
-            elif use_safetensors and is_safetensors_compatible(
-                model_filenames, variant=variant, passed_components=passed_components
-            ):
+            elif use_safetensors and is_safetensors_compatible(model_filenames, passed_components=passed_components):
                 ignore_patterns = ["*.bin", "*.msgpack"]

                 use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
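For orientation (not part of the diff): the call sites above feed the download filter that decides which weight formats to skip. A condensed, hypothetical summary of that branching, where the fallback branch outside this hunk is assumed rather than quoted, might look like:

def select_ignore_patterns(from_flax, use_safetensors, safetensors_compatible):
    # Simplified sketch of the branch above: pick which file patterns to ignore
    # when downloading a pipeline snapshot.
    if from_flax:
        return ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
    if use_safetensors and safetensors_compatible:
        return ["*.bin", "*.msgpack"]
    # assumed fallback (not shown in the hunk): keep .bin weights, skip safetensors
    return ["*.safetensors", "*.msgpack"]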

View File

@@ -68,25 +68,21 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
variant = "fp16"
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
self.assertTrue(is_safetensors_compatible(filenames))
def test_diffusers_model_is_compatible_variant(self):
filenames = [
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
variant = "fp16"
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
self.assertTrue(is_safetensors_compatible(filenames))
def test_diffusers_model_is_compatible_variant_partial(self):
# pass variant but use the non-variant filenames
def test_diffusers_model_is_compatible_variant_mixed(self):
filenames = [
"unet/diffusion_pytorch_model.bin",
"unet/diffusion_pytorch_model.safetensors",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
variant = "fp16"
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
self.assertTrue(is_safetensors_compatible(filenames))
def test_diffusers_model_is_not_compatible_variant(self):
filenames = [
@@ -99,25 +95,14 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
"unet/diffusion_pytorch_model.fp16.bin",
# Removed: 'unet/diffusion_pytorch_model.fp16.safetensors',
]
variant = "fp16"
self.assertFalse(is_safetensors_compatible(filenames, variant=variant))
self.assertFalse(is_safetensors_compatible(filenames))
def test_transformer_model_is_compatible_variant(self):
filenames = [
"text_encoder/pytorch_model.fp16.bin",
"text_encoder/model.fp16.safetensors",
]
variant = "fp16"
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
def test_transformer_model_is_compatible_variant_partial(self):
# pass variant but use the non-variant filenames
filenames = [
"text_encoder/pytorch_model.bin",
"text_encoder/model.safetensors",
]
variant = "fp16"
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
self.assertTrue(is_safetensors_compatible(filenames))
def test_transformer_model_is_not_compatible_variant(self):
filenames = [
@@ -126,9 +111,45 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
"vae/diffusion_pytorch_model.fp16.bin",
"vae/diffusion_pytorch_model.fp16.safetensors",
"text_encoder/pytorch_model.fp16.bin",
# 'text_encoder/model.fp16.safetensors',
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
variant = "fp16"
self.assertFalse(is_safetensors_compatible(filenames, variant=variant))
self.assertFalse(is_safetensors_compatible(filenames))
def test_transformers_is_compatible_sharded(self):
filenames = [
"text_encoder/pytorch_model.bin",
"text_encoder/model-00001-of-00002.safetensors",
"text_encoder/model-00002-of-00002.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames))
def test_transformers_is_compatible_variant_sharded(self):
filenames = [
"text_encoder/pytorch_model.bin",
"text_encoder/model.fp16-00001-of-00002.safetensors",
"text_encoder/model.fp16-00001-of-00002.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames))
def test_diffusers_is_compatible_sharded(self):
filenames = [
"unet/diffusion_pytorch_model.bin",
"unet/diffusion_pytorch_model-00001-of-00002.safetensors",
"unet/diffusion_pytorch_model-00002-of-00002.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames))
def test_diffusers_is_compatible_variant_sharded(self):
filenames = [
"unet/diffusion_pytorch_model.bin",
"unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors",
"unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames))
def test_diffusers_is_compatible_only_variants(self):
filenames = [
"unet/diffusion_pytorch_model.fp16.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames))
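A quick illustration of why the variant and sharded filenames in the new tests satisfy an extension-only check; this is plain standard-library behavior, shown for reference:

import os

# The rewritten check only looks at the final extension, so plain, variant,
# and sharded safetensors filenames are all accepted for a component.
for name in [
    "diffusion_pytorch_model.safetensors",
    "diffusion_pytorch_model.fp16.safetensors",
    "model.fp16-00001-of-00002.safetensors",
]:
    print(os.path.splitext(name)[1])  # ".safetensors" in every case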

View File

@@ -551,37 +551,94 @@ class DownloadTests(unittest.TestCase):
         assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
         assert not any(f.endswith(other_format) for f in files)

-    def test_download_broken_variant(self):
-        for use_safetensors in [False, True]:
-            # text encoder is missing no variant and "no_ema" variant weights, so the following can't work
-            for variant in [None, "no_ema"]:
-                with self.assertRaises(OSError) as error_context:
-                    with tempfile.TemporaryDirectory() as tmpdirname:
-                        tmpdirname = StableDiffusionPipeline.from_pretrained(
-                            "hf-internal-testing/stable-diffusion-broken-variants",
-                            cache_dir=tmpdirname,
-                            variant=variant,
-                            use_safetensors=use_safetensors,
-                        )
-
-                assert "Error no file name" in str(error_context.exception)
-
-            # text encoder has fp16 variants so we can load it
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                tmpdirname = StableDiffusionPipeline.download(
-                    "hf-internal-testing/stable-diffusion-broken-variants",
-                    use_safetensors=use_safetensors,
-                    cache_dir=tmpdirname,
-                    variant="fp16",
-                )
-
-                all_root_files = [t[-1] for t in os.walk(tmpdirname)]
-                files = [item for sublist in all_root_files for item in sublist]
-
-                # None of the downloaded files should be a non-variant file even if we have some here:
-                # https://huggingface.co/hf-internal-testing/stable-diffusion-broken-variants/tree/main/unet
-                assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
+    def test_download_safetensors_only_variant_exists_for_model(self):
+        variant = None
+        use_safetensors = True

+        # text encoder is missing no variant weights, so the following can't work
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            with self.assertRaises(OSError) as error_context:
+                tmpdirname = StableDiffusionPipeline.from_pretrained(
+                    "hf-internal-testing/stable-diffusion-broken-variants",
+                    cache_dir=tmpdirname,
+                    variant=variant,
+                    use_safetensors=use_safetensors,
+                )
+            assert "Error no file name" in str(error_context.exception)

+        # text encoder has fp16 variants so we can load it
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            tmpdirname = StableDiffusionPipeline.download(
+                "hf-internal-testing/stable-diffusion-broken-variants",
+                use_safetensors=use_safetensors,
+                cache_dir=tmpdirname,
+                variant="fp16",
+            )
+            all_root_files = [t[-1] for t in os.walk(tmpdirname)]
+            files = [item for sublist in all_root_files for item in sublist]
+            # None of the downloaded files should be a non-variant file even if we have some here:
+            # https://huggingface.co/hf-internal-testing/stable-diffusion-broken-variants/tree/main/unet
+            assert len(files) == 15, f"We should only download 15 files, not {len(files)}"

+    def test_download_bin_only_variant_exists_for_model(self):
+        variant = None
+        use_safetensors = False

+        # text encoder is missing Non-variant weights, so the following can't work
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            with self.assertRaises(OSError) as error_context:
+                tmpdirname = StableDiffusionPipeline.from_pretrained(
+                    "hf-internal-testing/stable-diffusion-broken-variants",
+                    cache_dir=tmpdirname,
+                    variant=variant,
+                    use_safetensors=use_safetensors,
+                )
+            assert "Error no file name" in str(error_context.exception)

+        # text encoder has fp16 variants so we can load it
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            tmpdirname = StableDiffusionPipeline.download(
+                "hf-internal-testing/stable-diffusion-broken-variants",
+                use_safetensors=use_safetensors,
+                cache_dir=tmpdirname,
+                variant="fp16",
+            )
+            all_root_files = [t[-1] for t in os.walk(tmpdirname)]
+            files = [item for sublist in all_root_files for item in sublist]
+            # None of the downloaded files should be a non-variant file even if we have some here:
+            # https://huggingface.co/hf-internal-testing/stable-diffusion-broken-variants/tree/main/unet
+            assert len(files) == 15, f"We should only download 15 files, not {len(files)}"

+    def test_download_safetensors_variant_does_not_exist_for_model(self):
+        variant = "no_ema"
+        use_safetensors = True

+        # text encoder is missing no_ema variant weights, so the following can't work
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            with self.assertRaises(OSError) as error_context:
+                tmpdirname = StableDiffusionPipeline.from_pretrained(
+                    "hf-internal-testing/stable-diffusion-broken-variants",
+                    cache_dir=tmpdirname,
+                    variant=variant,
+                    use_safetensors=use_safetensors,
+                )
+            assert "Error no file name" in str(error_context.exception)

+    # only unet has "no_ema" variant
+    def test_download_bin_variant_does_not_exist_for_model(self):
+        variant = "no_ema"
+        use_safetensors = False

+        # text encoder is missing no_ema variant weights, so the following can't work
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            with self.assertRaises(OSError) as error_context:
+                tmpdirname = StableDiffusionPipeline.from_pretrained(
+                    "hf-internal-testing/stable-diffusion-broken-variants",
+                    cache_dir=tmpdirname,
+                    variant=variant,
+                    use_safetensors=use_safetensors,
+                )
+            assert "Error no file name" in str(error_context.exception)

     def test_local_save_load_index(self):
         prompt = "hello"
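From the user side, the happy path these tests pin down looks roughly like the following illustrative sketch, which mirrors the tests' download call against the same test fixture repo:

from diffusers import StableDiffusionPipeline

# Only the fp16 variant exists for some components in this repo; with the
# reworked check the snapshot still resolves when asking for safetensors.
local_dir = StableDiffusionPipeline.download(
    "hf-internal-testing/stable-diffusion-broken-variants",
    variant="fp16",
    use_safetensors=True,
)
print(local_dir)  # path to the cached snapshot containing only fp16 variant files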