Compare commits

...

6 Commits

Author SHA1 Message Date
Dhruv Nair
ac1580a48d fix test related to cached requests 2023-08-09 15:56:12 +00:00
Dhruv Nair
cdbc3317c6 Merge branch 'main' into safetensors-default 2023-08-09 12:22:21 +00:00
Dhruv Nair
d8287a198e Merge branch 'main' into safetensors-default 2023-08-09 08:56:48 +00:00
Dhruv Nair
43e993d470 update pipeline tests for safetensor default 2023-08-08 16:39:55 +00:00
Abhipsha Das
24b7bcc468 Modifying import_utils.py 2023-08-07 19:10:26 -04:00
Abhipsha Das
7fde4a2460 [WIP] Remove code snippets containing is_safetensors_available() 2023-08-07 18:19:27 -04:00
12 changed files with 58 additions and 139 deletions

View File

@@ -2,14 +2,8 @@ import glob
import os import os
from typing import Dict, List, Union from typing import Dict, List, Union
import safetensors.torch
import torch import torch
from diffusers.utils import is_safetensors_available
if is_safetensors_available():
import safetensors.torch
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from diffusers import DiffusionPipeline, __version__ from diffusers import DiffusionPipeline, __version__
@@ -229,14 +223,14 @@ class CheckpointMergerPipeline(DiffusionPipeline):
update_theta_0 = getattr(module, "load_state_dict") update_theta_0 = getattr(module, "load_state_dict")
theta_1 = ( theta_1 = (
safetensors.torch.load_file(checkpoint_path_1) safetensors.torch.load_file(checkpoint_path_1)
if (is_safetensors_available() and checkpoint_path_1.endswith(".safetensors")) if (checkpoint_path_1.endswith(".safetensors"))
else torch.load(checkpoint_path_1, map_location="cpu") else torch.load(checkpoint_path_1, map_location="cpu")
) )
theta_2 = None theta_2 = None
if checkpoint_path_2: if checkpoint_path_2:
theta_2 = ( theta_2 = (
safetensors.torch.load_file(checkpoint_path_2) safetensors.torch.load_file(checkpoint_path_2)
if (is_safetensors_available() and checkpoint_path_2.endswith(".safetensors")) if (checkpoint_path_2.endswith(".safetensors"))
else torch.load(checkpoint_path_2, map_location="cpu") else torch.load(checkpoint_path_2, map_location="cpu")
) )

View File

@@ -38,7 +38,7 @@ from diffusers import (
PNDMScheduler, PNDMScheduler,
UNet2DConditionModel, UNet2DConditionModel,
) )
from diffusers.utils import is_omegaconf_available, is_safetensors_available from diffusers.utils import is_omegaconf_available
from diffusers.utils.import_utils import BACKENDS_MAPPING from diffusers.utils.import_utils import BACKENDS_MAPPING
@@ -824,9 +824,6 @@ def load_pipeline_from_original_audioldm_ckpt(
from omegaconf import OmegaConf from omegaconf import OmegaConf
if from_safetensors: if from_safetensors:
if not is_safetensors_available():
raise ValueError(BACKENDS_MAPPING["safetensors"][1])
from safetensors import safe_open from safetensors import safe_open
checkpoint = {} checkpoint = {}

View File

@@ -1,12 +1,6 @@
import argparse import argparse
from diffusers.utils import is_safetensors_available import safetensors.torch
if is_safetensors_available():
import safetensors.torch
else:
raise ImportError("Please install `safetensors`.")
from diffusers import AutoencoderTiny from diffusers import AutoencoderTiny

View File

@@ -27,7 +27,7 @@ import torch
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from packaging import version from packaging import version
from ..utils import is_safetensors_available, logging from ..utils import logging
from . import BaseDiffusersCLICommand from . import BaseDiffusersCLICommand
@@ -68,12 +68,7 @@ class FP16SafetensorsCommand(BaseDiffusersCLICommand):
self.local_ckpt_dir = f"/tmp/{ckpt_id}" self.local_ckpt_dir = f"/tmp/{ckpt_id}"
self.fp16 = fp16 self.fp16 = fp16
if is_safetensors_available():
self.use_safetensors = use_safetensors self.use_safetensors = use_safetensors
else:
raise ImportError(
"When `use_safetensors` is set to True, the `safetensors` library needs to be installed. Install it via `pip install safetensors`."
)
if not self.use_safetensors and not self.fp16: if not self.use_safetensors and not self.fp16:
raise NotImplementedError( raise NotImplementedError(

View File

@@ -22,6 +22,7 @@ from pathlib import Path
from typing import Callable, Dict, List, Optional, Union from typing import Callable, Dict, List, Optional, Union
import requests import requests
import safetensors
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
@@ -34,16 +35,12 @@ from .utils import (
deprecate, deprecate,
is_accelerate_available, is_accelerate_available,
is_omegaconf_available, is_omegaconf_available,
is_safetensors_available,
is_transformers_available, is_transformers_available,
logging, logging,
) )
from .utils.import_utils import BACKENDS_MAPPING from .utils.import_utils import BACKENDS_MAPPING
if is_safetensors_available():
import safetensors
if is_transformers_available(): if is_transformers_available():
from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel, PreTrainedTokenizer from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel, PreTrainedTokenizer
@@ -261,14 +258,10 @@ class UNet2DConditionLoadersMixin:
network_alphas = kwargs.pop("network_alphas", None) network_alphas = kwargs.pop("network_alphas", None)
is_network_alphas_none = network_alphas is None is_network_alphas_none = network_alphas is None
if use_safetensors and not is_safetensors_available():
raise ValueError(
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
)
allow_pickle = False allow_pickle = False
if use_safetensors is None: if use_safetensors is None:
use_safetensors = is_safetensors_available() use_safetensors = True
allow_pickle = True allow_pickle = True
user_agent = { user_agent = {
@@ -757,14 +750,9 @@ class TextualInversionLoaderMixin:
weight_name = kwargs.pop("weight_name", None) weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None) use_safetensors = kwargs.pop("use_safetensors", None)
if use_safetensors and not is_safetensors_available():
raise ValueError(
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
)
allow_pickle = False allow_pickle = False
if use_safetensors is None: if use_safetensors is None:
use_safetensors = is_safetensors_available() use_safetensors = True
allow_pickle = True allow_pickle = True
user_agent = { user_agent = {
@@ -1014,14 +1002,9 @@ class LoraLoaderMixin:
unet_config = kwargs.pop("unet_config", None) unet_config = kwargs.pop("unet_config", None)
use_safetensors = kwargs.pop("use_safetensors", None) use_safetensors = kwargs.pop("use_safetensors", None)
if use_safetensors and not is_safetensors_available():
raise ValueError(
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
)
allow_pickle = False allow_pickle = False
if use_safetensors is None: if use_safetensors is None:
use_safetensors = is_safetensors_available() use_safetensors = True
allow_pickle = True allow_pickle = True
user_agent = { user_agent = {
@@ -1853,7 +1836,7 @@ class FromSingleFileMixin:
torch_dtype = kwargs.pop("torch_dtype", None) torch_dtype = kwargs.pop("torch_dtype", None)
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False) use_safetensors = kwargs.pop("use_safetensors", None)
pipeline_name = cls.__name__ pipeline_name = cls.__name__
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1] file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
@@ -2050,7 +2033,7 @@ class FromOriginalVAEMixin:
torch_dtype = kwargs.pop("torch_dtype", None) torch_dtype = kwargs.pop("torch_dtype", None)
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False) use_safetensors = kwargs.pop("use_safetensors", None)
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1] file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
from_safetensors = file_extension == "safetensors" from_safetensors = file_extension == "safetensors"
@@ -2223,7 +2206,7 @@ class FromOriginalControlnetMixin:
torch_dtype = kwargs.pop("torch_dtype", None) torch_dtype = kwargs.pop("torch_dtype", None)
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False) use_safetensors = kwargs.pop("use_safetensors", None)
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1] file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
from_safetensors = file_extension == "safetensors" from_safetensors = file_extension == "safetensors"

View File

@@ -21,6 +21,7 @@ import re
from functools import partial from functools import partial
from typing import Any, Callable, List, Optional, Tuple, Union from typing import Any, Callable, List, Optional, Tuple, Union
import safetensors
import torch import torch
from torch import Tensor, device, nn from torch import Tensor, device, nn
@@ -36,7 +37,6 @@ from ..utils import (
_get_model_file, _get_model_file,
deprecate, deprecate,
is_accelerate_available, is_accelerate_available,
is_safetensors_available,
is_torch_version, is_torch_version,
logging, logging,
) )
@@ -56,9 +56,6 @@ if is_accelerate_available():
from accelerate.utils import set_module_tensor_to_device from accelerate.utils import set_module_tensor_to_device
from accelerate.utils.versions import is_torch_version from accelerate.utils.versions import is_torch_version
if is_safetensors_available():
import safetensors
def get_parameter_device(parameter: torch.nn.Module): def get_parameter_device(parameter: torch.nn.Module):
try: try:
@@ -296,9 +293,6 @@ class ModelMixin(torch.nn.Module):
variant (`str`, *optional*): variant (`str`, *optional*):
If specified, weights are saved in the format `pytorch_model.<variant>.bin`. If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
""" """
if safe_serialization and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
if os.path.isfile(save_directory): if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file") logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return return
@@ -454,14 +448,9 @@ class ModelMixin(torch.nn.Module):
variant = kwargs.pop("variant", None) variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None) use_safetensors = kwargs.pop("use_safetensors", None)
if use_safetensors and not is_safetensors_available():
raise ValueError(
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
)
allow_pickle = False allow_pickle = False
if use_safetensors is None: if use_safetensors is None:
use_safetensors = is_safetensors_available() use_safetensors = True
allow_pickle = True allow_pickle = True
if low_cpu_mem_usage and not is_accelerate_available(): if low_cpu_mem_usage and not is_accelerate_available():

View File

@@ -52,7 +52,6 @@ from ..utils import (
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
is_compiled_module, is_compiled_module,
is_safetensors_available,
is_torch_version, is_torch_version,
is_transformers_available, is_transformers_available,
logging, logging,
@@ -899,7 +898,7 @@ class DiffusionPipeline(ConfigMixin):
offload_state_dict = kwargs.pop("offload_state_dict", False) offload_state_dict = kwargs.pop("offload_state_dict", False)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None) variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False) use_safetensors = kwargs.pop("use_safetensors", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
# 1. Download the checkpoints and configs # 1. Download the checkpoints and configs
@@ -1311,14 +1310,9 @@ class DiffusionPipeline(ConfigMixin):
use_onnx = kwargs.pop("use_onnx", None) use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
if use_safetensors and not is_safetensors_available():
raise ValueError(
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
)
allow_pickle = False allow_pickle = False
if use_safetensors is None: if use_safetensors is None:
use_safetensors = is_safetensors_available() use_safetensors = True
allow_pickle = True allow_pickle = True
allow_patterns = None allow_patterns = None

View File

@@ -50,7 +50,7 @@ from ...schedulers import (
PNDMScheduler, PNDMScheduler,
UnCLIPScheduler, UnCLIPScheduler,
) )
from ...utils import is_accelerate_available, is_omegaconf_available, is_safetensors_available, logging from ...utils import is_accelerate_available, is_omegaconf_available, logging
from ...utils.import_utils import BACKENDS_MAPPING from ...utils.import_utils import BACKENDS_MAPPING
from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from ..paint_by_example import PaintByExampleImageEncoder from ..paint_by_example import PaintByExampleImageEncoder
@@ -1225,9 +1225,6 @@ def download_from_original_stable_diffusion_ckpt(
from omegaconf import OmegaConf from omegaconf import OmegaConf
if from_safetensors: if from_safetensors:
if not is_safetensors_available():
raise ValueError(BACKENDS_MAPPING["safetensors"][1])
from safetensors.torch import load_file as safe_load from safetensors.torch import load_file as safe_load
checkpoint = safe_load(checkpoint_path, device="cpu") checkpoint = safe_load(checkpoint_path, device="cpu")
@@ -1650,9 +1647,6 @@ def download_controlnet_from_original_ckpt(
from omegaconf import OmegaConf from omegaconf import OmegaConf
if from_safetensors: if from_safetensors:
if not is_safetensors_available():
raise ValueError(BACKENDS_MAPPING["safetensors"][1])
from safetensors import safe_open from safetensors import safe_open
checkpoint = {} checkpoint = {}

View File

@@ -64,7 +64,6 @@ from .import_utils import (
is_note_seq_available, is_note_seq_available,
is_omegaconf_available, is_omegaconf_available,
is_onnx_available, is_onnx_available,
is_safetensors_available,
is_scipy_available, is_scipy_available,
is_tensorboard_available, is_tensorboard_available,
is_tf_available, is_tf_available,

View File

@@ -306,10 +306,6 @@ def is_torch_available():
return _torch_available return _torch_available
def is_safetensors_available():
return _safetensors_available
def is_tf_available(): def is_tf_available():
return _tf_available return _tf_available

View File

@@ -100,14 +100,15 @@ class ModelUtilsTest(unittest.TestCase):
if torch_device == "mps": if torch_device == "mps":
return return
import diffusers use_safetensors = False
diffusers.utils.import_utils._safetensors_available = False
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
with requests_mock.mock(real_http=True) as m: with requests_mock.mock(real_http=True) as m:
UNet2DConditionModel.from_pretrained( UNet2DConditionModel.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", cache_dir=tmpdirname "hf-internal-testing/tiny-stable-diffusion-torch",
subfolder="unet",
cache_dir=tmpdirname,
use_safetensors=use_safetensors,
) )
download_requests = [r.method for r in m.request_history] download_requests = [r.method for r in m.request_history]
@@ -116,7 +117,10 @@ class ModelUtilsTest(unittest.TestCase):
with requests_mock.mock(real_http=True) as m: with requests_mock.mock(real_http=True) as m:
UNet2DConditionModel.from_pretrained( UNet2DConditionModel.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", cache_dir=tmpdirname "hf-internal-testing/tiny-stable-diffusion-torch",
subfolder="unet",
cache_dir=tmpdirname,
use_safetensors=use_safetensors,
) )
cache_requests = [r.method for r in m.request_history] cache_requests = [r.method for r in m.request_history]
@@ -124,8 +128,6 @@ class ModelUtilsTest(unittest.TestCase):
"HEAD" == cache_requests[0] and len(cache_requests) == 1 "HEAD" == cache_requests[0] and len(cache_requests) == 1
), "We should call only `model_info` to check for _commit hash and `send_telemetry`" ), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
diffusers.utils.import_utils._safetensors_available = True
def test_weight_overwrite(self): def test_weight_overwrite(self):
with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context: with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context:
UNet2DConditionModel.from_pretrained( UNet2DConditionModel.from_pretrained(

View File

@@ -472,15 +472,13 @@ class DownloadTests(unittest.TestCase):
assert False, "Parameters not the same!" assert False, "Parameters not the same!"
def test_download_from_variant_folder(self): def test_download_from_variant_folder(self):
for safe_avail in [False, True]: for use_safetensors in [False, True]:
import diffusers other_format = ".bin" if use_safetensors else ".safetensors"
diffusers.utils.import_utils._safetensors_available = safe_avail
other_format = ".bin" if safe_avail else ".safetensors"
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = StableDiffusionPipeline.download( tmpdirname = StableDiffusionPipeline.download(
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname "hf-internal-testing/stable-diffusion-all-variants",
cache_dir=tmpdirname,
use_safetensors=use_safetensors,
) )
all_root_files = [t[-1] for t in os.walk(tmpdirname)] all_root_files = [t[-1] for t in os.walk(tmpdirname)]
files = [item for sublist in all_root_files for item in sublist] files = [item for sublist in all_root_files for item in sublist]
@@ -492,21 +490,18 @@ class DownloadTests(unittest.TestCase):
# no variants # no variants
assert not any(len(f.split(".")) == 3 for f in files) assert not any(len(f.split(".")) == 3 for f in files)
diffusers.utils.import_utils._safetensors_available = True
def test_download_variant_all(self): def test_download_variant_all(self):
for safe_avail in [False, True]: for use_safetensors in [False, True]:
import diffusers other_format = ".bin" if use_safetensors else ".safetensors"
this_format = ".safetensors" if use_safetensors else ".bin"
diffusers.utils.import_utils._safetensors_available = safe_avail
other_format = ".bin" if safe_avail else ".safetensors"
this_format = ".safetensors" if safe_avail else ".bin"
variant = "fp16" variant = "fp16"
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = StableDiffusionPipeline.download( tmpdirname = StableDiffusionPipeline.download(
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant "hf-internal-testing/stable-diffusion-all-variants",
cache_dir=tmpdirname,
variant=variant,
use_safetensors=use_safetensors,
) )
all_root_files = [t[-1] for t in os.walk(tmpdirname)] all_root_files = [t[-1] for t in os.walk(tmpdirname)]
files = [item for sublist in all_root_files for item in sublist] files = [item for sublist in all_root_files for item in sublist]
@@ -520,21 +515,18 @@ class DownloadTests(unittest.TestCase):
assert not any(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) assert not any(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files)
assert not any(f.endswith(other_format) for f in files) assert not any(f.endswith(other_format) for f in files)
diffusers.utils.import_utils._safetensors_available = True
def test_download_variant_partly(self): def test_download_variant_partly(self):
for safe_avail in [False, True]: for use_safetensors in [False, True]:
import diffusers other_format = ".bin" if use_safetensors else ".safetensors"
this_format = ".safetensors" if use_safetensors else ".bin"
diffusers.utils.import_utils._safetensors_available = safe_avail
other_format = ".bin" if safe_avail else ".safetensors"
this_format = ".safetensors" if safe_avail else ".bin"
variant = "no_ema" variant = "no_ema"
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = StableDiffusionPipeline.download( tmpdirname = StableDiffusionPipeline.download(
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant "hf-internal-testing/stable-diffusion-all-variants",
cache_dir=tmpdirname,
variant=variant,
use_safetensors=use_safetensors,
) )
all_root_files = [t[-1] for t in os.walk(tmpdirname)] all_root_files = [t[-1] for t in os.walk(tmpdirname)]
files = [item for sublist in all_root_files for item in sublist] files = [item for sublist in all_root_files for item in sublist]
@@ -551,13 +543,8 @@ class DownloadTests(unittest.TestCase):
assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3 assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
assert not any(f.endswith(other_format) for f in files) assert not any(f.endswith(other_format) for f in files)
diffusers.utils.import_utils._safetensors_available = True
def test_download_broken_variant(self): def test_download_broken_variant(self):
for safe_avail in [False, True]: for use_safetensors in [False, True]:
import diffusers
diffusers.utils.import_utils._safetensors_available = safe_avail
# text encoder is missing no variant and "no_ema" variant weights, so the following can't work # text encoder is missing no variant and "no_ema" variant weights, so the following can't work
for variant in [None, "no_ema"]: for variant in [None, "no_ema"]:
with self.assertRaises(OSError) as error_context: with self.assertRaises(OSError) as error_context:
@@ -566,6 +553,7 @@ class DownloadTests(unittest.TestCase):
"hf-internal-testing/stable-diffusion-broken-variants", "hf-internal-testing/stable-diffusion-broken-variants",
cache_dir=tmpdirname, cache_dir=tmpdirname,
variant=variant, variant=variant,
use_safetensors=use_safetensors,
) )
assert "Error no file name" in str(error_context.exception) assert "Error no file name" in str(error_context.exception)
@@ -573,7 +561,10 @@ class DownloadTests(unittest.TestCase):
# text encoder has fp16 variants so we can load it # text encoder has fp16 variants so we can load it
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = StableDiffusionPipeline.download( tmpdirname = StableDiffusionPipeline.download(
"hf-internal-testing/stable-diffusion-broken-variants", cache_dir=tmpdirname, variant="fp16" "hf-internal-testing/stable-diffusion-broken-variants",
use_safetensors=use_safetensors,
cache_dir=tmpdirname,
variant="fp16",
) )
all_root_files = [t[-1] for t in os.walk(tmpdirname)] all_root_files = [t[-1] for t in os.walk(tmpdirname)]
@@ -584,8 +575,6 @@ class DownloadTests(unittest.TestCase):
assert len(files) == 15, f"We should only download 15 files, not {len(files)}" assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
# only unet has "no_ema" variant # only unet has "no_ema" variant
diffusers.utils.import_utils._safetensors_available = True
def test_local_save_load_index(self): def test_local_save_load_index(self):
prompt = "hello" prompt = "hello"
for variant in [None, "fp16"]: for variant in [None, "fp16"]:
@@ -961,10 +950,6 @@ class PipelineFastTests(unittest.TestCase):
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
import diffusers
diffusers.utils.import_utils._safetensors_available = True
def dummy_image(self): def dummy_image(self):
batch_size = 1 batch_size = 1
num_channels = 3 num_channels = 3
@@ -1319,14 +1304,13 @@ class PipelineFastTests(unittest.TestCase):
assert not os.path.exists(os.path.join(path, "diffusion_pytorch_model.bin")) assert not os.path.exists(os.path.join(path, "diffusion_pytorch_model.bin"))
def test_no_safetensors_download_when_doing_pytorch(self): def test_no_safetensors_download_when_doing_pytorch(self):
# mock diffusers safetensors not available use_safetensors = False
import diffusers
diffusers.utils.import_utils._safetensors_available = False
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
_ = StableDiffusionPipeline.from_pretrained( _ = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/diffusers-stable-diffusion-tiny-all", cache_dir=tmpdirname "hf-internal-testing/diffusers-stable-diffusion-tiny-all",
cache_dir=tmpdirname,
use_safetensors=use_safetensors,
) )
path = os.path.join( path = os.path.join(
@@ -1341,8 +1325,6 @@ class PipelineFastTests(unittest.TestCase):
# pytorch does # pytorch does
assert os.path.exists(os.path.join(path, "diffusion_pytorch_model.bin")) assert os.path.exists(os.path.join(path, "diffusion_pytorch_model.bin"))
diffusers.utils.import_utils._safetensors_available = True
def test_optional_components(self): def test_optional_components(self):
unet = self.dummy_cond_unet() unet = self.dummy_cond_unet()
pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler") pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")