mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 20:44:33 +08:00
Compare commits
6 Commits
modular-re
...
safetensor
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ac1580a48d | ||
|
|
cdbc3317c6 | ||
|
|
d8287a198e | ||
|
|
43e993d470 | ||
|
|
24b7bcc468 | ||
|
|
7fde4a2460 |
@@ -2,14 +2,8 @@ import glob
|
||||
import os
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
|
||||
from diffusers.utils import is_safetensors_available
|
||||
|
||||
|
||||
if is_safetensors_available():
|
||||
import safetensors.torch
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from diffusers import DiffusionPipeline, __version__
|
||||
@@ -229,14 +223,14 @@ class CheckpointMergerPipeline(DiffusionPipeline):
|
||||
update_theta_0 = getattr(module, "load_state_dict")
|
||||
theta_1 = (
|
||||
safetensors.torch.load_file(checkpoint_path_1)
|
||||
if (is_safetensors_available() and checkpoint_path_1.endswith(".safetensors"))
|
||||
if (checkpoint_path_1.endswith(".safetensors"))
|
||||
else torch.load(checkpoint_path_1, map_location="cpu")
|
||||
)
|
||||
theta_2 = None
|
||||
if checkpoint_path_2:
|
||||
theta_2 = (
|
||||
safetensors.torch.load_file(checkpoint_path_2)
|
||||
if (is_safetensors_available() and checkpoint_path_2.endswith(".safetensors"))
|
||||
if (checkpoint_path_2.endswith(".safetensors"))
|
||||
else torch.load(checkpoint_path_2, map_location="cpu")
|
||||
)
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ from diffusers import (
|
||||
PNDMScheduler,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils import is_omegaconf_available, is_safetensors_available
|
||||
from diffusers.utils import is_omegaconf_available
|
||||
from diffusers.utils.import_utils import BACKENDS_MAPPING
|
||||
|
||||
|
||||
@@ -824,9 +824,6 @@ def load_pipeline_from_original_audioldm_ckpt(
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
if from_safetensors:
|
||||
if not is_safetensors_available():
|
||||
raise ValueError(BACKENDS_MAPPING["safetensors"][1])
|
||||
|
||||
from safetensors import safe_open
|
||||
|
||||
checkpoint = {}
|
||||
|
||||
@@ -1,12 +1,6 @@
|
||||
import argparse
|
||||
|
||||
from diffusers.utils import is_safetensors_available
|
||||
|
||||
|
||||
if is_safetensors_available():
|
||||
import safetensors.torch
|
||||
else:
|
||||
raise ImportError("Please install `safetensors`.")
|
||||
import safetensors.torch
|
||||
|
||||
from diffusers import AutoencoderTiny
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from packaging import version
|
||||
|
||||
from ..utils import is_safetensors_available, logging
|
||||
from ..utils import logging
|
||||
from . import BaseDiffusersCLICommand
|
||||
|
||||
|
||||
@@ -68,12 +68,7 @@ class FP16SafetensorsCommand(BaseDiffusersCLICommand):
|
||||
self.local_ckpt_dir = f"/tmp/{ckpt_id}"
|
||||
self.fp16 = fp16
|
||||
|
||||
if is_safetensors_available():
|
||||
self.use_safetensors = use_safetensors
|
||||
else:
|
||||
raise ImportError(
|
||||
"When `use_safetensors` is set to True, the `safetensors` library needs to be installed. Install it via `pip install safetensors`."
|
||||
)
|
||||
self.use_safetensors = use_safetensors
|
||||
|
||||
if not self.use_safetensors and not self.fp16:
|
||||
raise NotImplementedError(
|
||||
|
||||
@@ -22,6 +22,7 @@ from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import requests
|
||||
import safetensors
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from huggingface_hub import hf_hub_download
|
||||
@@ -34,16 +35,12 @@ from .utils import (
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_omegaconf_available,
|
||||
is_safetensors_available,
|
||||
is_transformers_available,
|
||||
logging,
|
||||
)
|
||||
from .utils.import_utils import BACKENDS_MAPPING
|
||||
|
||||
|
||||
if is_safetensors_available():
|
||||
import safetensors
|
||||
|
||||
if is_transformers_available():
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel, PreTrainedTokenizer
|
||||
|
||||
@@ -261,14 +258,10 @@ class UNet2DConditionLoadersMixin:
|
||||
network_alphas = kwargs.pop("network_alphas", None)
|
||||
is_network_alphas_none = network_alphas is None
|
||||
|
||||
if use_safetensors and not is_safetensors_available():
|
||||
raise ValueError(
|
||||
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
|
||||
)
|
||||
|
||||
allow_pickle = False
|
||||
|
||||
if use_safetensors is None:
|
||||
use_safetensors = is_safetensors_available()
|
||||
use_safetensors = True
|
||||
allow_pickle = True
|
||||
|
||||
user_agent = {
|
||||
@@ -757,14 +750,9 @@ class TextualInversionLoaderMixin:
|
||||
weight_name = kwargs.pop("weight_name", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
|
||||
if use_safetensors and not is_safetensors_available():
|
||||
raise ValueError(
|
||||
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
|
||||
)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = is_safetensors_available()
|
||||
use_safetensors = True
|
||||
allow_pickle = True
|
||||
|
||||
user_agent = {
|
||||
@@ -1014,14 +1002,9 @@ class LoraLoaderMixin:
|
||||
unet_config = kwargs.pop("unet_config", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
|
||||
if use_safetensors and not is_safetensors_available():
|
||||
raise ValueError(
|
||||
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
|
||||
)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = is_safetensors_available()
|
||||
use_safetensors = True
|
||||
allow_pickle = True
|
||||
|
||||
user_agent = {
|
||||
@@ -1853,7 +1836,7 @@ class FromSingleFileMixin:
|
||||
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
|
||||
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
|
||||
pipeline_name = cls.__name__
|
||||
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
|
||||
@@ -2050,7 +2033,7 @@ class FromOriginalVAEMixin:
|
||||
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
|
||||
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
|
||||
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
|
||||
from_safetensors = file_extension == "safetensors"
|
||||
@@ -2223,7 +2206,7 @@ class FromOriginalControlnetMixin:
|
||||
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
|
||||
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
|
||||
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
|
||||
from_safetensors = file_extension == "safetensors"
|
||||
|
||||
@@ -21,6 +21,7 @@ import re
|
||||
from functools import partial
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
from torch import Tensor, device, nn
|
||||
|
||||
@@ -36,7 +37,6 @@ from ..utils import (
|
||||
_get_model_file,
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_safetensors_available,
|
||||
is_torch_version,
|
||||
logging,
|
||||
)
|
||||
@@ -56,9 +56,6 @@ if is_accelerate_available():
|
||||
from accelerate.utils import set_module_tensor_to_device
|
||||
from accelerate.utils.versions import is_torch_version
|
||||
|
||||
if is_safetensors_available():
|
||||
import safetensors
|
||||
|
||||
|
||||
def get_parameter_device(parameter: torch.nn.Module):
|
||||
try:
|
||||
@@ -296,9 +293,6 @@ class ModelMixin(torch.nn.Module):
|
||||
variant (`str`, *optional*):
|
||||
If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
|
||||
"""
|
||||
if safe_serialization and not is_safetensors_available():
|
||||
raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
|
||||
|
||||
if os.path.isfile(save_directory):
|
||||
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
return
|
||||
@@ -454,14 +448,9 @@ class ModelMixin(torch.nn.Module):
|
||||
variant = kwargs.pop("variant", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
|
||||
if use_safetensors and not is_safetensors_available():
|
||||
raise ValueError(
|
||||
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
|
||||
)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = is_safetensors_available()
|
||||
use_safetensors = True
|
||||
allow_pickle = True
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
|
||||
@@ -52,7 +52,6 @@ from ..utils import (
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_compiled_module,
|
||||
is_safetensors_available,
|
||||
is_torch_version,
|
||||
is_transformers_available,
|
||||
logging,
|
||||
@@ -899,7 +898,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
offload_state_dict = kwargs.pop("offload_state_dict", False)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
variant = kwargs.pop("variant", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
|
||||
|
||||
# 1. Download the checkpoints and configs
|
||||
@@ -1311,14 +1310,9 @@ class DiffusionPipeline(ConfigMixin):
|
||||
use_onnx = kwargs.pop("use_onnx", None)
|
||||
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
|
||||
|
||||
if use_safetensors and not is_safetensors_available():
|
||||
raise ValueError(
|
||||
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
|
||||
)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = is_safetensors_available()
|
||||
use_safetensors = True
|
||||
allow_pickle = True
|
||||
|
||||
allow_patterns = None
|
||||
|
||||
@@ -50,7 +50,7 @@ from ...schedulers import (
|
||||
PNDMScheduler,
|
||||
UnCLIPScheduler,
|
||||
)
|
||||
from ...utils import is_accelerate_available, is_omegaconf_available, is_safetensors_available, logging
|
||||
from ...utils import is_accelerate_available, is_omegaconf_available, logging
|
||||
from ...utils.import_utils import BACKENDS_MAPPING
|
||||
from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
|
||||
from ..paint_by_example import PaintByExampleImageEncoder
|
||||
@@ -1225,9 +1225,6 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
if from_safetensors:
|
||||
if not is_safetensors_available():
|
||||
raise ValueError(BACKENDS_MAPPING["safetensors"][1])
|
||||
|
||||
from safetensors.torch import load_file as safe_load
|
||||
|
||||
checkpoint = safe_load(checkpoint_path, device="cpu")
|
||||
@@ -1650,9 +1647,6 @@ def download_controlnet_from_original_ckpt(
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
if from_safetensors:
|
||||
if not is_safetensors_available():
|
||||
raise ValueError(BACKENDS_MAPPING["safetensors"][1])
|
||||
|
||||
from safetensors import safe_open
|
||||
|
||||
checkpoint = {}
|
||||
|
||||
@@ -64,7 +64,6 @@ from .import_utils import (
|
||||
is_note_seq_available,
|
||||
is_omegaconf_available,
|
||||
is_onnx_available,
|
||||
is_safetensors_available,
|
||||
is_scipy_available,
|
||||
is_tensorboard_available,
|
||||
is_tf_available,
|
||||
|
||||
@@ -306,10 +306,6 @@ def is_torch_available():
|
||||
return _torch_available
|
||||
|
||||
|
||||
def is_safetensors_available():
|
||||
return _safetensors_available
|
||||
|
||||
|
||||
def is_tf_available():
|
||||
return _tf_available
|
||||
|
||||
|
||||
@@ -100,14 +100,15 @@ class ModelUtilsTest(unittest.TestCase):
|
||||
if torch_device == "mps":
|
||||
return
|
||||
|
||||
import diffusers
|
||||
|
||||
diffusers.utils.import_utils._safetensors_available = False
|
||||
use_safetensors = False
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
with requests_mock.mock(real_http=True) as m:
|
||||
UNet2DConditionModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", cache_dir=tmpdirname
|
||||
"hf-internal-testing/tiny-stable-diffusion-torch",
|
||||
subfolder="unet",
|
||||
cache_dir=tmpdirname,
|
||||
use_safetensors=use_safetensors,
|
||||
)
|
||||
|
||||
download_requests = [r.method for r in m.request_history]
|
||||
@@ -116,7 +117,10 @@ class ModelUtilsTest(unittest.TestCase):
|
||||
|
||||
with requests_mock.mock(real_http=True) as m:
|
||||
UNet2DConditionModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", cache_dir=tmpdirname
|
||||
"hf-internal-testing/tiny-stable-diffusion-torch",
|
||||
subfolder="unet",
|
||||
cache_dir=tmpdirname,
|
||||
use_safetensors=use_safetensors,
|
||||
)
|
||||
|
||||
cache_requests = [r.method for r in m.request_history]
|
||||
@@ -124,8 +128,6 @@ class ModelUtilsTest(unittest.TestCase):
|
||||
"HEAD" == cache_requests[0] and len(cache_requests) == 1
|
||||
), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
|
||||
|
||||
diffusers.utils.import_utils._safetensors_available = True
|
||||
|
||||
def test_weight_overwrite(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context:
|
||||
UNet2DConditionModel.from_pretrained(
|
||||
|
||||
@@ -472,15 +472,13 @@ class DownloadTests(unittest.TestCase):
|
||||
assert False, "Parameters not the same!"
|
||||
|
||||
def test_download_from_variant_folder(self):
|
||||
for safe_avail in [False, True]:
|
||||
import diffusers
|
||||
|
||||
diffusers.utils.import_utils._safetensors_available = safe_avail
|
||||
|
||||
other_format = ".bin" if safe_avail else ".safetensors"
|
||||
for use_safetensors in [False, True]:
|
||||
other_format = ".bin" if use_safetensors else ".safetensors"
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
tmpdirname = StableDiffusionPipeline.download(
|
||||
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname
|
||||
"hf-internal-testing/stable-diffusion-all-variants",
|
||||
cache_dir=tmpdirname,
|
||||
use_safetensors=use_safetensors,
|
||||
)
|
||||
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
|
||||
files = [item for sublist in all_root_files for item in sublist]
|
||||
@@ -492,21 +490,18 @@ class DownloadTests(unittest.TestCase):
|
||||
# no variants
|
||||
assert not any(len(f.split(".")) == 3 for f in files)
|
||||
|
||||
diffusers.utils.import_utils._safetensors_available = True
|
||||
|
||||
def test_download_variant_all(self):
|
||||
for safe_avail in [False, True]:
|
||||
import diffusers
|
||||
|
||||
diffusers.utils.import_utils._safetensors_available = safe_avail
|
||||
|
||||
other_format = ".bin" if safe_avail else ".safetensors"
|
||||
this_format = ".safetensors" if safe_avail else ".bin"
|
||||
for use_safetensors in [False, True]:
|
||||
other_format = ".bin" if use_safetensors else ".safetensors"
|
||||
this_format = ".safetensors" if use_safetensors else ".bin"
|
||||
variant = "fp16"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
tmpdirname = StableDiffusionPipeline.download(
|
||||
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant
|
||||
"hf-internal-testing/stable-diffusion-all-variants",
|
||||
cache_dir=tmpdirname,
|
||||
variant=variant,
|
||||
use_safetensors=use_safetensors,
|
||||
)
|
||||
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
|
||||
files = [item for sublist in all_root_files for item in sublist]
|
||||
@@ -520,21 +515,18 @@ class DownloadTests(unittest.TestCase):
|
||||
assert not any(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files)
|
||||
assert not any(f.endswith(other_format) for f in files)
|
||||
|
||||
diffusers.utils.import_utils._safetensors_available = True
|
||||
|
||||
def test_download_variant_partly(self):
|
||||
for safe_avail in [False, True]:
|
||||
import diffusers
|
||||
|
||||
diffusers.utils.import_utils._safetensors_available = safe_avail
|
||||
|
||||
other_format = ".bin" if safe_avail else ".safetensors"
|
||||
this_format = ".safetensors" if safe_avail else ".bin"
|
||||
for use_safetensors in [False, True]:
|
||||
other_format = ".bin" if use_safetensors else ".safetensors"
|
||||
this_format = ".safetensors" if use_safetensors else ".bin"
|
||||
variant = "no_ema"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
tmpdirname = StableDiffusionPipeline.download(
|
||||
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant
|
||||
"hf-internal-testing/stable-diffusion-all-variants",
|
||||
cache_dir=tmpdirname,
|
||||
variant=variant,
|
||||
use_safetensors=use_safetensors,
|
||||
)
|
||||
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
|
||||
files = [item for sublist in all_root_files for item in sublist]
|
||||
@@ -551,13 +543,8 @@ class DownloadTests(unittest.TestCase):
|
||||
assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
|
||||
assert not any(f.endswith(other_format) for f in files)
|
||||
|
||||
diffusers.utils.import_utils._safetensors_available = True
|
||||
|
||||
def test_download_broken_variant(self):
|
||||
for safe_avail in [False, True]:
|
||||
import diffusers
|
||||
|
||||
diffusers.utils.import_utils._safetensors_available = safe_avail
|
||||
for use_safetensors in [False, True]:
|
||||
# text encoder is missing no variant and "no_ema" variant weights, so the following can't work
|
||||
for variant in [None, "no_ema"]:
|
||||
with self.assertRaises(OSError) as error_context:
|
||||
@@ -566,6 +553,7 @@ class DownloadTests(unittest.TestCase):
|
||||
"hf-internal-testing/stable-diffusion-broken-variants",
|
||||
cache_dir=tmpdirname,
|
||||
variant=variant,
|
||||
use_safetensors=use_safetensors,
|
||||
)
|
||||
|
||||
assert "Error no file name" in str(error_context.exception)
|
||||
@@ -573,7 +561,10 @@ class DownloadTests(unittest.TestCase):
|
||||
# text encoder has fp16 variants so we can load it
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
tmpdirname = StableDiffusionPipeline.download(
|
||||
"hf-internal-testing/stable-diffusion-broken-variants", cache_dir=tmpdirname, variant="fp16"
|
||||
"hf-internal-testing/stable-diffusion-broken-variants",
|
||||
use_safetensors=use_safetensors,
|
||||
cache_dir=tmpdirname,
|
||||
variant="fp16",
|
||||
)
|
||||
|
||||
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
|
||||
@@ -584,8 +575,6 @@ class DownloadTests(unittest.TestCase):
|
||||
assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
|
||||
# only unet has "no_ema" variant
|
||||
|
||||
diffusers.utils.import_utils._safetensors_available = True
|
||||
|
||||
def test_local_save_load_index(self):
|
||||
prompt = "hello"
|
||||
for variant in [None, "fp16"]:
|
||||
@@ -961,10 +950,6 @@ class PipelineFastTests(unittest.TestCase):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
import diffusers
|
||||
|
||||
diffusers.utils.import_utils._safetensors_available = True
|
||||
|
||||
def dummy_image(self):
|
||||
batch_size = 1
|
||||
num_channels = 3
|
||||
@@ -1319,14 +1304,13 @@ class PipelineFastTests(unittest.TestCase):
|
||||
assert not os.path.exists(os.path.join(path, "diffusion_pytorch_model.bin"))
|
||||
|
||||
def test_no_safetensors_download_when_doing_pytorch(self):
|
||||
# mock diffusers safetensors not available
|
||||
import diffusers
|
||||
|
||||
diffusers.utils.import_utils._safetensors_available = False
|
||||
use_safetensors = False
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
_ = StableDiffusionPipeline.from_pretrained(
|
||||
"hf-internal-testing/diffusers-stable-diffusion-tiny-all", cache_dir=tmpdirname
|
||||
"hf-internal-testing/diffusers-stable-diffusion-tiny-all",
|
||||
cache_dir=tmpdirname,
|
||||
use_safetensors=use_safetensors,
|
||||
)
|
||||
|
||||
path = os.path.join(
|
||||
@@ -1341,8 +1325,6 @@ class PipelineFastTests(unittest.TestCase):
|
||||
# pytorch does
|
||||
assert os.path.exists(os.path.join(path, "diffusion_pytorch_model.bin"))
|
||||
|
||||
diffusers.utils.import_utils._safetensors_available = True
|
||||
|
||||
def test_optional_components(self):
|
||||
unet = self.dummy_cond_unet()
|
||||
pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
|
||||
|
||||
Reference in New Issue
Block a user