mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-08 13:34:27 +08:00
Compare commits
2 Commits
unload-sin
...
v0.7.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2ea5eebd7c | ||
|
|
7c2a58fd4d |
2
setup.py
2
setup.py
@@ -210,7 +210,7 @@ install_requires = [
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="diffusers",
|
name="diffusers",
|
||||||
version="0.7.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
version="0.7.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||||
description="Diffusers",
|
description="Diffusers",
|
||||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
from .utils import (
|
from .utils import (
|
||||||
is_accelerate_available,
|
|
||||||
is_flax_available,
|
is_flax_available,
|
||||||
is_inflect_available,
|
is_inflect_available,
|
||||||
is_onnx_available,
|
is_onnx_available,
|
||||||
@@ -10,20 +9,13 @@ from .utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
__version__ = "0.7.0"
|
__version__ = "0.7.1"
|
||||||
|
|
||||||
from .configuration_utils import ConfigMixin
|
from .configuration_utils import ConfigMixin
|
||||||
from .onnx_utils import OnnxRuntimeModel
|
from .onnx_utils import OnnxRuntimeModel
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
# This will create an extra dummy file "dummy_torch_and_accelerate_objects.py"
|
|
||||||
# TODO: (patil-suraj, anton-l) maybe import everything under is_torch_and_accelerate_available
|
|
||||||
if is_torch_available() and not is_accelerate_available():
|
|
||||||
error_msg = "Please install the `accelerate` library to use Diffusers with PyTorch. You can do so by running `pip install diffusers[torch]`. Or if torch is already installed, you can run `pip install accelerate`." # noqa: E501
|
|
||||||
raise ImportError(error_msg)
|
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from .modeling_utils import ModelMixin
|
from .modeling_utils import ModelMixin
|
||||||
from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
|
from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
|
||||||
|
|||||||
@@ -21,15 +21,20 @@ from typing import Callable, List, Optional, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
from torch import Tensor, device
|
from torch import Tensor, device
|
||||||
|
|
||||||
import accelerate
|
|
||||||
from accelerate.utils import set_module_tensor_to_device
|
|
||||||
from accelerate.utils.versions import is_torch_version
|
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
||||||
from requests import HTTPError
|
from requests import HTTPError
|
||||||
|
|
||||||
from . import __version__
|
from . import __version__
|
||||||
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, WEIGHTS_NAME, logging
|
from .utils import (
|
||||||
|
CONFIG_NAME,
|
||||||
|
DIFFUSERS_CACHE,
|
||||||
|
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||||
|
WEIGHTS_NAME,
|
||||||
|
is_accelerate_available,
|
||||||
|
is_torch_version,
|
||||||
|
logging,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -41,6 +46,12 @@ else:
|
|||||||
_LOW_CPU_MEM_USAGE_DEFAULT = False
|
_LOW_CPU_MEM_USAGE_DEFAULT = False
|
||||||
|
|
||||||
|
|
||||||
|
if is_accelerate_available():
|
||||||
|
import accelerate
|
||||||
|
from accelerate.utils import set_module_tensor_to_device
|
||||||
|
from accelerate.utils.versions import is_torch_version
|
||||||
|
|
||||||
|
|
||||||
def get_parameter_device(parameter: torch.nn.Module):
|
def get_parameter_device(parameter: torch.nn.Module):
|
||||||
try:
|
try:
|
||||||
return next(parameter.parameters()).device
|
return next(parameter.parameters()).device
|
||||||
@@ -319,6 +330,21 @@ class ModelMixin(torch.nn.Module):
|
|||||||
device_map = kwargs.pop("device_map", None)
|
device_map = kwargs.pop("device_map", None)
|
||||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||||
|
|
||||||
|
if low_cpu_mem_usage and not is_accelerate_available():
|
||||||
|
low_cpu_mem_usage = False
|
||||||
|
logger.warn(
|
||||||
|
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||||
|
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||||
|
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||||
|
" install accelerate\n```\n."
|
||||||
|
)
|
||||||
|
|
||||||
|
if device_map is not None and not is_accelerate_available():
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
|
||||||
|
" `device_map=None`. You can install accelerate with `pip install accelerate`."
|
||||||
|
)
|
||||||
|
|
||||||
# Check if we can handle device_map and dispatching the weights
|
# Check if we can handle device_map and dispatching the weights
|
||||||
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ import torch
|
|||||||
|
|
||||||
import diffusers
|
import diffusers
|
||||||
import PIL
|
import PIL
|
||||||
from accelerate.utils.versions import is_torch_version
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@@ -43,6 +42,8 @@ from .utils import (
|
|||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
BaseOutput,
|
BaseOutput,
|
||||||
deprecate,
|
deprecate,
|
||||||
|
is_accelerate_available,
|
||||||
|
is_torch_version,
|
||||||
is_transformers_available,
|
is_transformers_available,
|
||||||
logging,
|
logging,
|
||||||
)
|
)
|
||||||
@@ -397,6 +398,15 @@ class DiffusionPipeline(ConfigMixin):
|
|||||||
device_map = kwargs.pop("device_map", None)
|
device_map = kwargs.pop("device_map", None)
|
||||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||||
|
|
||||||
|
if low_cpu_mem_usage and not is_accelerate_available():
|
||||||
|
low_cpu_mem_usage = False
|
||||||
|
logger.warn(
|
||||||
|
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||||
|
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||||
|
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||||
|
" install accelerate\n```\n."
|
||||||
|
)
|
||||||
|
|
||||||
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from .import_utils import (
|
|||||||
is_scipy_available,
|
is_scipy_available,
|
||||||
is_tf_available,
|
is_tf_available,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
|
is_torch_version,
|
||||||
is_transformers_available,
|
is_transformers_available,
|
||||||
is_unidecode_available,
|
is_unidecode_available,
|
||||||
requires_backends,
|
requires_backends,
|
||||||
|
|||||||
@@ -272,21 +272,6 @@ class ScoreSdeVePipeline(metaclass=DummyObject):
|
|||||||
requires_backends(cls, ["torch"])
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class VQDiffusionPipeline(metaclass=DummyObject):
|
|
||||||
_backends = ["torch"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["torch"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch"])
|
|
||||||
|
|
||||||
|
|
||||||
class DDIMScheduler(metaclass=DummyObject):
|
class DDIMScheduler(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
|||||||
@@ -1,452 +0,0 @@
|
|||||||
# This file is autogenerated by the command `make fix-copies`, do not edit.
|
|
||||||
# flake8: noqa
|
|
||||||
|
|
||||||
from ..utils import DummyObject, requires_backends
|
|
||||||
|
|
||||||
|
|
||||||
class ModelMixin(metaclass=DummyObject):
|
|
||||||
_backends = ["torch", "accelerate"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
class AutoencoderKL(metaclass=DummyObject):
|
|
||||||
_backends = ["torch", "accelerate"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
class Transformer2DModel(metaclass=DummyObject):
|
|
||||||
_backends = ["torch", "accelerate"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
class UNet1DModel(metaclass=DummyObject):
|
|
||||||
_backends = ["torch", "accelerate"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
class UNet2DConditionModel(metaclass=DummyObject):
|
|
||||||
_backends = ["torch", "accelerate"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
class UNet2DModel(metaclass=DummyObject):
|
|
||||||
_backends = ["torch", "accelerate"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
class VQModel(metaclass=DummyObject):
|
|
||||||
_backends = ["torch", "accelerate"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
def get_constant_schedule(*args, **kwargs):
|
|
||||||
requires_backends(get_constant_schedule, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
def get_constant_schedule_with_warmup(*args, **kwargs):
|
|
||||||
requires_backends(get_constant_schedule_with_warmup, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
def get_cosine_schedule_with_warmup(*args, **kwargs):
|
|
||||||
requires_backends(get_cosine_schedule_with_warmup, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
def get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs):
|
|
||||||
requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
def get_linear_schedule_with_warmup(*args, **kwargs):
|
|
||||||
requires_backends(get_linear_schedule_with_warmup, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
def get_polynomial_decay_schedule_with_warmup(*args, **kwargs):
|
|
||||||
requires_backends(get_polynomial_decay_schedule_with_warmup, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
def get_scheduler(*args, **kwargs):
|
|
||||||
requires_backends(get_scheduler, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
class DiffusionPipeline(metaclass=DummyObject):
|
|
||||||
_backends = ["torch", "accelerate"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
class DanceDiffusionPipeline(metaclass=DummyObject):
|
|
||||||
_backends = ["torch", "accelerate"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
class DDIMPipeline(metaclass=DummyObject):
|
|
||||||
_backends = ["torch", "accelerate"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
class DDPMPipeline(metaclass=DummyObject):
|
|
||||||
_backends = ["torch", "accelerate"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
class KarrasVePipeline(metaclass=DummyObject):
|
|
||||||
_backends = ["torch", "accelerate"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
class LDMPipeline(metaclass=DummyObject):
|
|
||||||
_backends = ["torch", "accelerate"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
class PNDMPipeline(metaclass=DummyObject):
|
|
||||||
_backends = ["torch", "accelerate"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
class RePaintPipeline(metaclass=DummyObject):
|
|
||||||
_backends = ["torch", "accelerate"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
class ScoreSdeVePipeline(metaclass=DummyObject):
|
|
||||||
_backends = ["torch", "accelerate"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
class DDIMScheduler(metaclass=DummyObject):
|
|
||||||
_backends = ["torch", "accelerate"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
class DDPMScheduler(metaclass=DummyObject):
|
|
||||||
_backends = ["torch", "accelerate"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
class EulerAncestralDiscreteScheduler(metaclass=DummyObject):
|
|
||||||
_backends = ["torch", "accelerate"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
class EulerDiscreteScheduler(metaclass=DummyObject):
|
|
||||||
_backends = ["torch", "accelerate"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
class IPNDMScheduler(metaclass=DummyObject):
|
|
||||||
_backends = ["torch", "accelerate"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
class KarrasVeScheduler(metaclass=DummyObject):
|
|
||||||
_backends = ["torch", "accelerate"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
class PNDMScheduler(metaclass=DummyObject):
|
|
||||||
_backends = ["torch", "accelerate"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
class RePaintScheduler(metaclass=DummyObject):
|
|
||||||
_backends = ["torch", "accelerate"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
class SchedulerMixin(metaclass=DummyObject):
|
|
||||||
_backends = ["torch", "accelerate"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
class ScoreSdeVeScheduler(metaclass=DummyObject):
|
|
||||||
_backends = ["torch", "accelerate"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
class VQDiffusionScheduler(metaclass=DummyObject):
|
|
||||||
_backends = ["torch", "accelerate"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
|
|
||||||
class EMAModel(metaclass=DummyObject):
|
|
||||||
_backends = ["torch", "accelerate"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["torch", "accelerate"])
|
|
||||||
@@ -15,11 +15,14 @@
|
|||||||
Import utilities: Utilities related to imports and our lazy inits.
|
Import utilities: Utilities related to imports and our lazy inits.
|
||||||
"""
|
"""
|
||||||
import importlib.util
|
import importlib.util
|
||||||
|
import operator as op
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
from packaging.version import Version, parse
|
||||||
|
|
||||||
from . import logging
|
from . import logging
|
||||||
|
|
||||||
@@ -40,6 +43,8 @@ USE_TF = os.environ.get("USE_TF", "AUTO").upper()
|
|||||||
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
|
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
|
||||||
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
|
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
|
||||||
|
|
||||||
|
STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
|
||||||
|
|
||||||
_torch_version = "N/A"
|
_torch_version = "N/A"
|
||||||
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
|
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
|
||||||
_torch_available = importlib.util.find_spec("torch") is not None
|
_torch_available = importlib.util.find_spec("torch") is not None
|
||||||
@@ -309,3 +314,36 @@ class DummyObject(type):
|
|||||||
if key.startswith("_"):
|
if key.startswith("_"):
|
||||||
return super().__getattr__(cls, key)
|
return super().__getattr__(cls, key)
|
||||||
requires_backends(cls, cls._backends)
|
requires_backends(cls, cls._backends)
|
||||||
|
|
||||||
|
|
||||||
|
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
|
||||||
|
def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
Compares a library version to some requirement using a given operation.
|
||||||
|
library_or_version (`str` or `packaging.version.Version`):
|
||||||
|
A library name or a version to check.
|
||||||
|
operation (`str`):
|
||||||
|
A string representation of an operator, such as `">"` or `"<="`.
|
||||||
|
requirement_version (`str`):
|
||||||
|
The version to compare the library version against
|
||||||
|
"""
|
||||||
|
if operation not in STR_OPERATION_TO_FUNC.keys():
|
||||||
|
raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}")
|
||||||
|
operation = STR_OPERATION_TO_FUNC[operation]
|
||||||
|
if isinstance(library_or_version, str):
|
||||||
|
library_or_version = parse(importlib_metadata.version(library_or_version))
|
||||||
|
return operation(library_or_version, parse(requirement_version))
|
||||||
|
|
||||||
|
|
||||||
|
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
|
||||||
|
def is_torch_version(operation: str, version: str):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
Compares the current PyTorch version to a given reference with an operation.
|
||||||
|
operation (`str`):
|
||||||
|
A string representation of an operator, such as `">"` or `"<="`
|
||||||
|
version (`str`):
|
||||||
|
A string version of PyTorch
|
||||||
|
"""
|
||||||
|
return compare_versions(parse(_torch_version), operation, version)
|
||||||
|
|||||||
@@ -52,13 +52,13 @@ class CheckDummiesTester(unittest.TestCase):
|
|||||||
def test_read_init(self):
|
def test_read_init(self):
|
||||||
objects = read_init()
|
objects = read_init()
|
||||||
# We don't assert on the exact list of keys to allow for smooth grow of backend-specific objects
|
# We don't assert on the exact list of keys to allow for smooth grow of backend-specific objects
|
||||||
self.assertIn("torch_and_accelerate", objects)
|
self.assertIn("torch", objects)
|
||||||
self.assertIn("torch_and_transformers", objects)
|
self.assertIn("torch_and_transformers", objects)
|
||||||
self.assertIn("flax_and_transformers", objects)
|
self.assertIn("flax_and_transformers", objects)
|
||||||
self.assertIn("torch_and_transformers_and_onnx", objects)
|
self.assertIn("torch_and_transformers_and_onnx", objects)
|
||||||
|
|
||||||
# Likewise, we can't assert on the exact content of a key
|
# Likewise, we can't assert on the exact content of a key
|
||||||
self.assertIn("UNet2DModel", objects["torch_and_accelerate"])
|
self.assertIn("UNet2DModel", objects["torch"])
|
||||||
self.assertIn("FlaxUNet2DConditionModel", objects["flax"])
|
self.assertIn("FlaxUNet2DConditionModel", objects["flax"])
|
||||||
self.assertIn("StableDiffusionPipeline", objects["torch_and_transformers"])
|
self.assertIn("StableDiffusionPipeline", objects["torch_and_transformers"])
|
||||||
self.assertIn("FlaxStableDiffusionPipeline", objects["flax_and_transformers"])
|
self.assertIn("FlaxStableDiffusionPipeline", objects["flax_and_transformers"])
|
||||||
|
|||||||
Reference in New Issue
Block a user