mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-16 17:34:44 +08:00
Compare commits
6 Commits
modular-re
...
safetensor
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ac1580a48d | ||
|
|
cdbc3317c6 | ||
|
|
d8287a198e | ||
|
|
43e993d470 | ||
|
|
24b7bcc468 | ||
|
|
7fde4a2460 |
@@ -2,14 +2,8 @@ import glob
|
|||||||
import os
|
import os
|
||||||
from typing import Dict, List, Union
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
|
import safetensors.torch
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from diffusers.utils import is_safetensors_available
|
|
||||||
|
|
||||||
|
|
||||||
if is_safetensors_available():
|
|
||||||
import safetensors.torch
|
|
||||||
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
from diffusers import DiffusionPipeline, __version__
|
from diffusers import DiffusionPipeline, __version__
|
||||||
@@ -229,14 +223,14 @@ class CheckpointMergerPipeline(DiffusionPipeline):
|
|||||||
update_theta_0 = getattr(module, "load_state_dict")
|
update_theta_0 = getattr(module, "load_state_dict")
|
||||||
theta_1 = (
|
theta_1 = (
|
||||||
safetensors.torch.load_file(checkpoint_path_1)
|
safetensors.torch.load_file(checkpoint_path_1)
|
||||||
if (is_safetensors_available() and checkpoint_path_1.endswith(".safetensors"))
|
if (checkpoint_path_1.endswith(".safetensors"))
|
||||||
else torch.load(checkpoint_path_1, map_location="cpu")
|
else torch.load(checkpoint_path_1, map_location="cpu")
|
||||||
)
|
)
|
||||||
theta_2 = None
|
theta_2 = None
|
||||||
if checkpoint_path_2:
|
if checkpoint_path_2:
|
||||||
theta_2 = (
|
theta_2 = (
|
||||||
safetensors.torch.load_file(checkpoint_path_2)
|
safetensors.torch.load_file(checkpoint_path_2)
|
||||||
if (is_safetensors_available() and checkpoint_path_2.endswith(".safetensors"))
|
if (checkpoint_path_2.endswith(".safetensors"))
|
||||||
else torch.load(checkpoint_path_2, map_location="cpu")
|
else torch.load(checkpoint_path_2, map_location="cpu")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ from diffusers import (
|
|||||||
PNDMScheduler,
|
PNDMScheduler,
|
||||||
UNet2DConditionModel,
|
UNet2DConditionModel,
|
||||||
)
|
)
|
||||||
from diffusers.utils import is_omegaconf_available, is_safetensors_available
|
from diffusers.utils import is_omegaconf_available
|
||||||
from diffusers.utils.import_utils import BACKENDS_MAPPING
|
from diffusers.utils.import_utils import BACKENDS_MAPPING
|
||||||
|
|
||||||
|
|
||||||
@@ -824,9 +824,6 @@ def load_pipeline_from_original_audioldm_ckpt(
|
|||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
if from_safetensors:
|
if from_safetensors:
|
||||||
if not is_safetensors_available():
|
|
||||||
raise ValueError(BACKENDS_MAPPING["safetensors"][1])
|
|
||||||
|
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
|
|
||||||
checkpoint = {}
|
checkpoint = {}
|
||||||
|
|||||||
@@ -1,12 +1,6 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from diffusers.utils import is_safetensors_available
|
import safetensors.torch
|
||||||
|
|
||||||
|
|
||||||
if is_safetensors_available():
|
|
||||||
import safetensors.torch
|
|
||||||
else:
|
|
||||||
raise ImportError("Please install `safetensors`.")
|
|
||||||
|
|
||||||
from diffusers import AutoencoderTiny
|
from diffusers import AutoencoderTiny
|
||||||
|
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ import torch
|
|||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from ..utils import is_safetensors_available, logging
|
from ..utils import logging
|
||||||
from . import BaseDiffusersCLICommand
|
from . import BaseDiffusersCLICommand
|
||||||
|
|
||||||
|
|
||||||
@@ -68,12 +68,7 @@ class FP16SafetensorsCommand(BaseDiffusersCLICommand):
|
|||||||
self.local_ckpt_dir = f"/tmp/{ckpt_id}"
|
self.local_ckpt_dir = f"/tmp/{ckpt_id}"
|
||||||
self.fp16 = fp16
|
self.fp16 = fp16
|
||||||
|
|
||||||
if is_safetensors_available():
|
|
||||||
self.use_safetensors = use_safetensors
|
self.use_safetensors = use_safetensors
|
||||||
else:
|
|
||||||
raise ImportError(
|
|
||||||
"When `use_safetensors` is set to True, the `safetensors` library needs to be installed. Install it via `pip install safetensors`."
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self.use_safetensors and not self.fp16:
|
if not self.use_safetensors and not self.fp16:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from pathlib import Path
|
|||||||
from typing import Callable, Dict, List, Optional, Union
|
from typing import Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
import safetensors
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
@@ -34,16 +35,12 @@ from .utils import (
|
|||||||
deprecate,
|
deprecate,
|
||||||
is_accelerate_available,
|
is_accelerate_available,
|
||||||
is_omegaconf_available,
|
is_omegaconf_available,
|
||||||
is_safetensors_available,
|
|
||||||
is_transformers_available,
|
is_transformers_available,
|
||||||
logging,
|
logging,
|
||||||
)
|
)
|
||||||
from .utils.import_utils import BACKENDS_MAPPING
|
from .utils.import_utils import BACKENDS_MAPPING
|
||||||
|
|
||||||
|
|
||||||
if is_safetensors_available():
|
|
||||||
import safetensors
|
|
||||||
|
|
||||||
if is_transformers_available():
|
if is_transformers_available():
|
||||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel, PreTrainedTokenizer
|
from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel, PreTrainedTokenizer
|
||||||
|
|
||||||
@@ -261,14 +258,10 @@ class UNet2DConditionLoadersMixin:
|
|||||||
network_alphas = kwargs.pop("network_alphas", None)
|
network_alphas = kwargs.pop("network_alphas", None)
|
||||||
is_network_alphas_none = network_alphas is None
|
is_network_alphas_none = network_alphas is None
|
||||||
|
|
||||||
if use_safetensors and not is_safetensors_available():
|
|
||||||
raise ValueError(
|
|
||||||
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
|
|
||||||
)
|
|
||||||
|
|
||||||
allow_pickle = False
|
allow_pickle = False
|
||||||
|
|
||||||
if use_safetensors is None:
|
if use_safetensors is None:
|
||||||
use_safetensors = is_safetensors_available()
|
use_safetensors = True
|
||||||
allow_pickle = True
|
allow_pickle = True
|
||||||
|
|
||||||
user_agent = {
|
user_agent = {
|
||||||
@@ -757,14 +750,9 @@ class TextualInversionLoaderMixin:
|
|||||||
weight_name = kwargs.pop("weight_name", None)
|
weight_name = kwargs.pop("weight_name", None)
|
||||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||||
|
|
||||||
if use_safetensors and not is_safetensors_available():
|
|
||||||
raise ValueError(
|
|
||||||
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
|
|
||||||
)
|
|
||||||
|
|
||||||
allow_pickle = False
|
allow_pickle = False
|
||||||
if use_safetensors is None:
|
if use_safetensors is None:
|
||||||
use_safetensors = is_safetensors_available()
|
use_safetensors = True
|
||||||
allow_pickle = True
|
allow_pickle = True
|
||||||
|
|
||||||
user_agent = {
|
user_agent = {
|
||||||
@@ -1014,14 +1002,9 @@ class LoraLoaderMixin:
|
|||||||
unet_config = kwargs.pop("unet_config", None)
|
unet_config = kwargs.pop("unet_config", None)
|
||||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||||
|
|
||||||
if use_safetensors and not is_safetensors_available():
|
|
||||||
raise ValueError(
|
|
||||||
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
|
|
||||||
)
|
|
||||||
|
|
||||||
allow_pickle = False
|
allow_pickle = False
|
||||||
if use_safetensors is None:
|
if use_safetensors is None:
|
||||||
use_safetensors = is_safetensors_available()
|
use_safetensors = True
|
||||||
allow_pickle = True
|
allow_pickle = True
|
||||||
|
|
||||||
user_agent = {
|
user_agent = {
|
||||||
@@ -1853,7 +1836,7 @@ class FromSingleFileMixin:
|
|||||||
|
|
||||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||||
|
|
||||||
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
|
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||||
|
|
||||||
pipeline_name = cls.__name__
|
pipeline_name = cls.__name__
|
||||||
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
|
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
|
||||||
@@ -2050,7 +2033,7 @@ class FromOriginalVAEMixin:
|
|||||||
|
|
||||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||||
|
|
||||||
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
|
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||||
|
|
||||||
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
|
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
|
||||||
from_safetensors = file_extension == "safetensors"
|
from_safetensors = file_extension == "safetensors"
|
||||||
@@ -2223,7 +2206,7 @@ class FromOriginalControlnetMixin:
|
|||||||
|
|
||||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||||
|
|
||||||
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
|
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||||
|
|
||||||
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
|
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
|
||||||
from_safetensors = file_extension == "safetensors"
|
from_safetensors = file_extension == "safetensors"
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import re
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import safetensors
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, device, nn
|
from torch import Tensor, device, nn
|
||||||
|
|
||||||
@@ -36,7 +37,6 @@ from ..utils import (
|
|||||||
_get_model_file,
|
_get_model_file,
|
||||||
deprecate,
|
deprecate,
|
||||||
is_accelerate_available,
|
is_accelerate_available,
|
||||||
is_safetensors_available,
|
|
||||||
is_torch_version,
|
is_torch_version,
|
||||||
logging,
|
logging,
|
||||||
)
|
)
|
||||||
@@ -56,9 +56,6 @@ if is_accelerate_available():
|
|||||||
from accelerate.utils import set_module_tensor_to_device
|
from accelerate.utils import set_module_tensor_to_device
|
||||||
from accelerate.utils.versions import is_torch_version
|
from accelerate.utils.versions import is_torch_version
|
||||||
|
|
||||||
if is_safetensors_available():
|
|
||||||
import safetensors
|
|
||||||
|
|
||||||
|
|
||||||
def get_parameter_device(parameter: torch.nn.Module):
|
def get_parameter_device(parameter: torch.nn.Module):
|
||||||
try:
|
try:
|
||||||
@@ -296,9 +293,6 @@ class ModelMixin(torch.nn.Module):
|
|||||||
variant (`str`, *optional*):
|
variant (`str`, *optional*):
|
||||||
If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
|
If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
|
||||||
"""
|
"""
|
||||||
if safe_serialization and not is_safetensors_available():
|
|
||||||
raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
|
|
||||||
|
|
||||||
if os.path.isfile(save_directory):
|
if os.path.isfile(save_directory):
|
||||||
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||||
return
|
return
|
||||||
@@ -454,14 +448,9 @@ class ModelMixin(torch.nn.Module):
|
|||||||
variant = kwargs.pop("variant", None)
|
variant = kwargs.pop("variant", None)
|
||||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||||
|
|
||||||
if use_safetensors and not is_safetensors_available():
|
|
||||||
raise ValueError(
|
|
||||||
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
|
|
||||||
)
|
|
||||||
|
|
||||||
allow_pickle = False
|
allow_pickle = False
|
||||||
if use_safetensors is None:
|
if use_safetensors is None:
|
||||||
use_safetensors = is_safetensors_available()
|
use_safetensors = True
|
||||||
allow_pickle = True
|
allow_pickle = True
|
||||||
|
|
||||||
if low_cpu_mem_usage and not is_accelerate_available():
|
if low_cpu_mem_usage and not is_accelerate_available():
|
||||||
|
|||||||
@@ -52,7 +52,6 @@ from ..utils import (
|
|||||||
is_accelerate_available,
|
is_accelerate_available,
|
||||||
is_accelerate_version,
|
is_accelerate_version,
|
||||||
is_compiled_module,
|
is_compiled_module,
|
||||||
is_safetensors_available,
|
|
||||||
is_torch_version,
|
is_torch_version,
|
||||||
is_transformers_available,
|
is_transformers_available,
|
||||||
logging,
|
logging,
|
||||||
@@ -899,7 +898,7 @@ class DiffusionPipeline(ConfigMixin):
|
|||||||
offload_state_dict = kwargs.pop("offload_state_dict", False)
|
offload_state_dict = kwargs.pop("offload_state_dict", False)
|
||||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||||
variant = kwargs.pop("variant", None)
|
variant = kwargs.pop("variant", None)
|
||||||
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
|
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||||
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
|
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
|
||||||
|
|
||||||
# 1. Download the checkpoints and configs
|
# 1. Download the checkpoints and configs
|
||||||
@@ -1311,14 +1310,9 @@ class DiffusionPipeline(ConfigMixin):
|
|||||||
use_onnx = kwargs.pop("use_onnx", None)
|
use_onnx = kwargs.pop("use_onnx", None)
|
||||||
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
|
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
|
||||||
|
|
||||||
if use_safetensors and not is_safetensors_available():
|
|
||||||
raise ValueError(
|
|
||||||
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
|
|
||||||
)
|
|
||||||
|
|
||||||
allow_pickle = False
|
allow_pickle = False
|
||||||
if use_safetensors is None:
|
if use_safetensors is None:
|
||||||
use_safetensors = is_safetensors_available()
|
use_safetensors = True
|
||||||
allow_pickle = True
|
allow_pickle = True
|
||||||
|
|
||||||
allow_patterns = None
|
allow_patterns = None
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ from ...schedulers import (
|
|||||||
PNDMScheduler,
|
PNDMScheduler,
|
||||||
UnCLIPScheduler,
|
UnCLIPScheduler,
|
||||||
)
|
)
|
||||||
from ...utils import is_accelerate_available, is_omegaconf_available, is_safetensors_available, logging
|
from ...utils import is_accelerate_available, is_omegaconf_available, logging
|
||||||
from ...utils.import_utils import BACKENDS_MAPPING
|
from ...utils.import_utils import BACKENDS_MAPPING
|
||||||
from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
|
from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
|
||||||
from ..paint_by_example import PaintByExampleImageEncoder
|
from ..paint_by_example import PaintByExampleImageEncoder
|
||||||
@@ -1225,9 +1225,6 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
if from_safetensors:
|
if from_safetensors:
|
||||||
if not is_safetensors_available():
|
|
||||||
raise ValueError(BACKENDS_MAPPING["safetensors"][1])
|
|
||||||
|
|
||||||
from safetensors.torch import load_file as safe_load
|
from safetensors.torch import load_file as safe_load
|
||||||
|
|
||||||
checkpoint = safe_load(checkpoint_path, device="cpu")
|
checkpoint = safe_load(checkpoint_path, device="cpu")
|
||||||
@@ -1650,9 +1647,6 @@ def download_controlnet_from_original_ckpt(
|
|||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
if from_safetensors:
|
if from_safetensors:
|
||||||
if not is_safetensors_available():
|
|
||||||
raise ValueError(BACKENDS_MAPPING["safetensors"][1])
|
|
||||||
|
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
|
|
||||||
checkpoint = {}
|
checkpoint = {}
|
||||||
|
|||||||
@@ -64,7 +64,6 @@ from .import_utils import (
|
|||||||
is_note_seq_available,
|
is_note_seq_available,
|
||||||
is_omegaconf_available,
|
is_omegaconf_available,
|
||||||
is_onnx_available,
|
is_onnx_available,
|
||||||
is_safetensors_available,
|
|
||||||
is_scipy_available,
|
is_scipy_available,
|
||||||
is_tensorboard_available,
|
is_tensorboard_available,
|
||||||
is_tf_available,
|
is_tf_available,
|
||||||
|
|||||||
@@ -306,10 +306,6 @@ def is_torch_available():
|
|||||||
return _torch_available
|
return _torch_available
|
||||||
|
|
||||||
|
|
||||||
def is_safetensors_available():
|
|
||||||
return _safetensors_available
|
|
||||||
|
|
||||||
|
|
||||||
def is_tf_available():
|
def is_tf_available():
|
||||||
return _tf_available
|
return _tf_available
|
||||||
|
|
||||||
|
|||||||
@@ -100,14 +100,15 @@ class ModelUtilsTest(unittest.TestCase):
|
|||||||
if torch_device == "mps":
|
if torch_device == "mps":
|
||||||
return
|
return
|
||||||
|
|
||||||
import diffusers
|
use_safetensors = False
|
||||||
|
|
||||||
diffusers.utils.import_utils._safetensors_available = False
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
with requests_mock.mock(real_http=True) as m:
|
with requests_mock.mock(real_http=True) as m:
|
||||||
UNet2DConditionModel.from_pretrained(
|
UNet2DConditionModel.from_pretrained(
|
||||||
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", cache_dir=tmpdirname
|
"hf-internal-testing/tiny-stable-diffusion-torch",
|
||||||
|
subfolder="unet",
|
||||||
|
cache_dir=tmpdirname,
|
||||||
|
use_safetensors=use_safetensors,
|
||||||
)
|
)
|
||||||
|
|
||||||
download_requests = [r.method for r in m.request_history]
|
download_requests = [r.method for r in m.request_history]
|
||||||
@@ -116,7 +117,10 @@ class ModelUtilsTest(unittest.TestCase):
|
|||||||
|
|
||||||
with requests_mock.mock(real_http=True) as m:
|
with requests_mock.mock(real_http=True) as m:
|
||||||
UNet2DConditionModel.from_pretrained(
|
UNet2DConditionModel.from_pretrained(
|
||||||
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", cache_dir=tmpdirname
|
"hf-internal-testing/tiny-stable-diffusion-torch",
|
||||||
|
subfolder="unet",
|
||||||
|
cache_dir=tmpdirname,
|
||||||
|
use_safetensors=use_safetensors,
|
||||||
)
|
)
|
||||||
|
|
||||||
cache_requests = [r.method for r in m.request_history]
|
cache_requests = [r.method for r in m.request_history]
|
||||||
@@ -124,8 +128,6 @@ class ModelUtilsTest(unittest.TestCase):
|
|||||||
"HEAD" == cache_requests[0] and len(cache_requests) == 1
|
"HEAD" == cache_requests[0] and len(cache_requests) == 1
|
||||||
), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
|
), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
|
||||||
|
|
||||||
diffusers.utils.import_utils._safetensors_available = True
|
|
||||||
|
|
||||||
def test_weight_overwrite(self):
|
def test_weight_overwrite(self):
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context:
|
with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context:
|
||||||
UNet2DConditionModel.from_pretrained(
|
UNet2DConditionModel.from_pretrained(
|
||||||
|
|||||||
@@ -472,15 +472,13 @@ class DownloadTests(unittest.TestCase):
|
|||||||
assert False, "Parameters not the same!"
|
assert False, "Parameters not the same!"
|
||||||
|
|
||||||
def test_download_from_variant_folder(self):
|
def test_download_from_variant_folder(self):
|
||||||
for safe_avail in [False, True]:
|
for use_safetensors in [False, True]:
|
||||||
import diffusers
|
other_format = ".bin" if use_safetensors else ".safetensors"
|
||||||
|
|
||||||
diffusers.utils.import_utils._safetensors_available = safe_avail
|
|
||||||
|
|
||||||
other_format = ".bin" if safe_avail else ".safetensors"
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
tmpdirname = StableDiffusionPipeline.download(
|
tmpdirname = StableDiffusionPipeline.download(
|
||||||
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname
|
"hf-internal-testing/stable-diffusion-all-variants",
|
||||||
|
cache_dir=tmpdirname,
|
||||||
|
use_safetensors=use_safetensors,
|
||||||
)
|
)
|
||||||
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
|
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
|
||||||
files = [item for sublist in all_root_files for item in sublist]
|
files = [item for sublist in all_root_files for item in sublist]
|
||||||
@@ -492,21 +490,18 @@ class DownloadTests(unittest.TestCase):
|
|||||||
# no variants
|
# no variants
|
||||||
assert not any(len(f.split(".")) == 3 for f in files)
|
assert not any(len(f.split(".")) == 3 for f in files)
|
||||||
|
|
||||||
diffusers.utils.import_utils._safetensors_available = True
|
|
||||||
|
|
||||||
def test_download_variant_all(self):
|
def test_download_variant_all(self):
|
||||||
for safe_avail in [False, True]:
|
for use_safetensors in [False, True]:
|
||||||
import diffusers
|
other_format = ".bin" if use_safetensors else ".safetensors"
|
||||||
|
this_format = ".safetensors" if use_safetensors else ".bin"
|
||||||
diffusers.utils.import_utils._safetensors_available = safe_avail
|
|
||||||
|
|
||||||
other_format = ".bin" if safe_avail else ".safetensors"
|
|
||||||
this_format = ".safetensors" if safe_avail else ".bin"
|
|
||||||
variant = "fp16"
|
variant = "fp16"
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
tmpdirname = StableDiffusionPipeline.download(
|
tmpdirname = StableDiffusionPipeline.download(
|
||||||
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant
|
"hf-internal-testing/stable-diffusion-all-variants",
|
||||||
|
cache_dir=tmpdirname,
|
||||||
|
variant=variant,
|
||||||
|
use_safetensors=use_safetensors,
|
||||||
)
|
)
|
||||||
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
|
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
|
||||||
files = [item for sublist in all_root_files for item in sublist]
|
files = [item for sublist in all_root_files for item in sublist]
|
||||||
@@ -520,21 +515,18 @@ class DownloadTests(unittest.TestCase):
|
|||||||
assert not any(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files)
|
assert not any(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files)
|
||||||
assert not any(f.endswith(other_format) for f in files)
|
assert not any(f.endswith(other_format) for f in files)
|
||||||
|
|
||||||
diffusers.utils.import_utils._safetensors_available = True
|
|
||||||
|
|
||||||
def test_download_variant_partly(self):
|
def test_download_variant_partly(self):
|
||||||
for safe_avail in [False, True]:
|
for use_safetensors in [False, True]:
|
||||||
import diffusers
|
other_format = ".bin" if use_safetensors else ".safetensors"
|
||||||
|
this_format = ".safetensors" if use_safetensors else ".bin"
|
||||||
diffusers.utils.import_utils._safetensors_available = safe_avail
|
|
||||||
|
|
||||||
other_format = ".bin" if safe_avail else ".safetensors"
|
|
||||||
this_format = ".safetensors" if safe_avail else ".bin"
|
|
||||||
variant = "no_ema"
|
variant = "no_ema"
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
tmpdirname = StableDiffusionPipeline.download(
|
tmpdirname = StableDiffusionPipeline.download(
|
||||||
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant
|
"hf-internal-testing/stable-diffusion-all-variants",
|
||||||
|
cache_dir=tmpdirname,
|
||||||
|
variant=variant,
|
||||||
|
use_safetensors=use_safetensors,
|
||||||
)
|
)
|
||||||
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
|
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
|
||||||
files = [item for sublist in all_root_files for item in sublist]
|
files = [item for sublist in all_root_files for item in sublist]
|
||||||
@@ -551,13 +543,8 @@ class DownloadTests(unittest.TestCase):
|
|||||||
assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
|
assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
|
||||||
assert not any(f.endswith(other_format) for f in files)
|
assert not any(f.endswith(other_format) for f in files)
|
||||||
|
|
||||||
diffusers.utils.import_utils._safetensors_available = True
|
|
||||||
|
|
||||||
def test_download_broken_variant(self):
|
def test_download_broken_variant(self):
|
||||||
for safe_avail in [False, True]:
|
for use_safetensors in [False, True]:
|
||||||
import diffusers
|
|
||||||
|
|
||||||
diffusers.utils.import_utils._safetensors_available = safe_avail
|
|
||||||
# text encoder is missing no variant and "no_ema" variant weights, so the following can't work
|
# text encoder is missing no variant and "no_ema" variant weights, so the following can't work
|
||||||
for variant in [None, "no_ema"]:
|
for variant in [None, "no_ema"]:
|
||||||
with self.assertRaises(OSError) as error_context:
|
with self.assertRaises(OSError) as error_context:
|
||||||
@@ -566,6 +553,7 @@ class DownloadTests(unittest.TestCase):
|
|||||||
"hf-internal-testing/stable-diffusion-broken-variants",
|
"hf-internal-testing/stable-diffusion-broken-variants",
|
||||||
cache_dir=tmpdirname,
|
cache_dir=tmpdirname,
|
||||||
variant=variant,
|
variant=variant,
|
||||||
|
use_safetensors=use_safetensors,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert "Error no file name" in str(error_context.exception)
|
assert "Error no file name" in str(error_context.exception)
|
||||||
@@ -573,7 +561,10 @@ class DownloadTests(unittest.TestCase):
|
|||||||
# text encoder has fp16 variants so we can load it
|
# text encoder has fp16 variants so we can load it
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
tmpdirname = StableDiffusionPipeline.download(
|
tmpdirname = StableDiffusionPipeline.download(
|
||||||
"hf-internal-testing/stable-diffusion-broken-variants", cache_dir=tmpdirname, variant="fp16"
|
"hf-internal-testing/stable-diffusion-broken-variants",
|
||||||
|
use_safetensors=use_safetensors,
|
||||||
|
cache_dir=tmpdirname,
|
||||||
|
variant="fp16",
|
||||||
)
|
)
|
||||||
|
|
||||||
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
|
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
|
||||||
@@ -584,8 +575,6 @@ class DownloadTests(unittest.TestCase):
|
|||||||
assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
|
assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
|
||||||
# only unet has "no_ema" variant
|
# only unet has "no_ema" variant
|
||||||
|
|
||||||
diffusers.utils.import_utils._safetensors_available = True
|
|
||||||
|
|
||||||
def test_local_save_load_index(self):
|
def test_local_save_load_index(self):
|
||||||
prompt = "hello"
|
prompt = "hello"
|
||||||
for variant in [None, "fp16"]:
|
for variant in [None, "fp16"]:
|
||||||
@@ -961,10 +950,6 @@ class PipelineFastTests(unittest.TestCase):
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
import diffusers
|
|
||||||
|
|
||||||
diffusers.utils.import_utils._safetensors_available = True
|
|
||||||
|
|
||||||
def dummy_image(self):
|
def dummy_image(self):
|
||||||
batch_size = 1
|
batch_size = 1
|
||||||
num_channels = 3
|
num_channels = 3
|
||||||
@@ -1319,14 +1304,13 @@ class PipelineFastTests(unittest.TestCase):
|
|||||||
assert not os.path.exists(os.path.join(path, "diffusion_pytorch_model.bin"))
|
assert not os.path.exists(os.path.join(path, "diffusion_pytorch_model.bin"))
|
||||||
|
|
||||||
def test_no_safetensors_download_when_doing_pytorch(self):
|
def test_no_safetensors_download_when_doing_pytorch(self):
|
||||||
# mock diffusers safetensors not available
|
use_safetensors = False
|
||||||
import diffusers
|
|
||||||
|
|
||||||
diffusers.utils.import_utils._safetensors_available = False
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
_ = StableDiffusionPipeline.from_pretrained(
|
_ = StableDiffusionPipeline.from_pretrained(
|
||||||
"hf-internal-testing/diffusers-stable-diffusion-tiny-all", cache_dir=tmpdirname
|
"hf-internal-testing/diffusers-stable-diffusion-tiny-all",
|
||||||
|
cache_dir=tmpdirname,
|
||||||
|
use_safetensors=use_safetensors,
|
||||||
)
|
)
|
||||||
|
|
||||||
path = os.path.join(
|
path = os.path.join(
|
||||||
@@ -1341,8 +1325,6 @@ class PipelineFastTests(unittest.TestCase):
|
|||||||
# pytorch does
|
# pytorch does
|
||||||
assert os.path.exists(os.path.join(path, "diffusion_pytorch_model.bin"))
|
assert os.path.exists(os.path.join(path, "diffusion_pytorch_model.bin"))
|
||||||
|
|
||||||
diffusers.utils.import_utils._safetensors_available = True
|
|
||||||
|
|
||||||
def test_optional_components(self):
|
def test_optional_components(self):
|
||||||
unet = self.dummy_cond_unet()
|
unet = self.dummy_cond_unet()
|
||||||
pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
|
pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
|
||||||
|
|||||||
Reference in New Issue
Block a user