mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-07 04:54:47 +08:00
Compare commits
2 Commits
deprecated
...
v0.7.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2ea5eebd7c | ||
|
|
7c2a58fd4d |
2
setup.py
2
setup.py
@@ -210,7 +210,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.7.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="0.7.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="Diffusers",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from .utils import (
|
||||
is_accelerate_available,
|
||||
is_flax_available,
|
||||
is_inflect_available,
|
||||
is_onnx_available,
|
||||
@@ -10,20 +9,13 @@ from .utils import (
|
||||
)
|
||||
|
||||
|
||||
__version__ = "0.7.0"
|
||||
__version__ = "0.7.1"
|
||||
|
||||
from .configuration_utils import ConfigMixin
|
||||
from .onnx_utils import OnnxRuntimeModel
|
||||
from .utils import logging
|
||||
|
||||
|
||||
# This will create an extra dummy file "dummy_torch_and_accelerate_objects.py"
|
||||
# TODO: (patil-suraj, anton-l) maybe import everything under is_torch_and_accelerate_available
|
||||
if is_torch_available() and not is_accelerate_available():
|
||||
error_msg = "Please install the `accelerate` library to use Diffusers with PyTorch. You can do so by running `pip install diffusers[torch]`. Or if torch is already installed, you can run `pip install accelerate`." # noqa: E501
|
||||
raise ImportError(error_msg)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_utils import ModelMixin
|
||||
from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
|
||||
|
||||
@@ -21,15 +21,20 @@ from typing import Callable, List, Optional, Tuple, Union
|
||||
import torch
|
||||
from torch import Tensor, device
|
||||
|
||||
import accelerate
|
||||
from accelerate.utils import set_module_tensor_to_device
|
||||
from accelerate.utils.versions import is_torch_version
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
||||
from requests import HTTPError
|
||||
|
||||
from . import __version__
|
||||
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, WEIGHTS_NAME, logging
|
||||
from .utils import (
|
||||
CONFIG_NAME,
|
||||
DIFFUSERS_CACHE,
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||
WEIGHTS_NAME,
|
||||
is_accelerate_available,
|
||||
is_torch_version,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -41,6 +46,12 @@ else:
|
||||
_LOW_CPU_MEM_USAGE_DEFAULT = False
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
import accelerate
|
||||
from accelerate.utils import set_module_tensor_to_device
|
||||
from accelerate.utils.versions import is_torch_version
|
||||
|
||||
|
||||
def get_parameter_device(parameter: torch.nn.Module):
|
||||
try:
|
||||
return next(parameter.parameters()).device
|
||||
@@ -319,6 +330,21 @@ class ModelMixin(torch.nn.Module):
|
||||
device_map = kwargs.pop("device_map", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
logger.warn(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if device_map is not None and not is_accelerate_available():
|
||||
raise NotImplementedError(
|
||||
"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
|
||||
" `device_map=None`. You can install accelerate with `pip install accelerate`."
|
||||
)
|
||||
|
||||
# Check if we can handle device_map and dispatching the weights
|
||||
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
|
||||
@@ -25,7 +25,6 @@ import torch
|
||||
|
||||
import diffusers
|
||||
import PIL
|
||||
from accelerate.utils.versions import is_torch_version
|
||||
from huggingface_hub import snapshot_download
|
||||
from packaging import version
|
||||
from PIL import Image
|
||||
@@ -43,6 +42,8 @@ from .utils import (
|
||||
WEIGHTS_NAME,
|
||||
BaseOutput,
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_torch_version,
|
||||
is_transformers_available,
|
||||
logging,
|
||||
)
|
||||
@@ -397,6 +398,15 @@ class DiffusionPipeline(ConfigMixin):
|
||||
device_map = kwargs.pop("device_map", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
logger.warn(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
|
||||
@@ -31,6 +31,7 @@ from .import_utils import (
|
||||
is_scipy_available,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
is_torch_version,
|
||||
is_transformers_available,
|
||||
is_unidecode_available,
|
||||
requires_backends,
|
||||
|
||||
@@ -272,21 +272,6 @@ class ScoreSdeVePipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class VQDiffusionPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class DDIMScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -1,452 +0,0 @@
|
||||
# This file is autogenerated by the command `make fix-copies`, do not edit.
|
||||
# flake8: noqa
|
||||
|
||||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
class ModelMixin(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class AutoencoderKL(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class Transformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class UNet1DModel(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class UNet2DConditionModel(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class UNet2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class VQModel(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
def get_constant_schedule(*args, **kwargs):
|
||||
requires_backends(get_constant_schedule, ["torch", "accelerate"])
|
||||
|
||||
|
||||
def get_constant_schedule_with_warmup(*args, **kwargs):
|
||||
requires_backends(get_constant_schedule_with_warmup, ["torch", "accelerate"])
|
||||
|
||||
|
||||
def get_cosine_schedule_with_warmup(*args, **kwargs):
|
||||
requires_backends(get_cosine_schedule_with_warmup, ["torch", "accelerate"])
|
||||
|
||||
|
||||
def get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs):
|
||||
requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, ["torch", "accelerate"])
|
||||
|
||||
|
||||
def get_linear_schedule_with_warmup(*args, **kwargs):
|
||||
requires_backends(get_linear_schedule_with_warmup, ["torch", "accelerate"])
|
||||
|
||||
|
||||
def get_polynomial_decay_schedule_with_warmup(*args, **kwargs):
|
||||
requires_backends(get_polynomial_decay_schedule_with_warmup, ["torch", "accelerate"])
|
||||
|
||||
|
||||
def get_scheduler(*args, **kwargs):
|
||||
requires_backends(get_scheduler, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class DiffusionPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class DanceDiffusionPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class DDIMPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class DDPMPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class KarrasVePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class LDMPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class PNDMPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class RePaintPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class ScoreSdeVePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class DDIMScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class DDPMScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class EulerAncestralDiscreteScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class EulerDiscreteScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class IPNDMScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class KarrasVeScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class PNDMScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class RePaintScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class SchedulerMixin(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class ScoreSdeVeScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class VQDiffusionScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class EMAModel(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
@@ -15,11 +15,14 @@
|
||||
Import utilities: Utilities related to imports and our lazy inits.
|
||||
"""
|
||||
import importlib.util
|
||||
import operator as op
|
||||
import os
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
from typing import Union
|
||||
|
||||
from packaging import version
|
||||
from packaging.version import Version, parse
|
||||
|
||||
from . import logging
|
||||
|
||||
@@ -40,6 +43,8 @@ USE_TF = os.environ.get("USE_TF", "AUTO").upper()
|
||||
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
|
||||
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
|
||||
|
||||
STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
|
||||
|
||||
_torch_version = "N/A"
|
||||
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
|
||||
_torch_available = importlib.util.find_spec("torch") is not None
|
||||
@@ -309,3 +314,36 @@ class DummyObject(type):
|
||||
if key.startswith("_"):
|
||||
return super().__getattr__(cls, key)
|
||||
requires_backends(cls, cls._backends)
|
||||
|
||||
|
||||
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
|
||||
def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
|
||||
"""
|
||||
Args:
|
||||
Compares a library version to some requirement using a given operation.
|
||||
library_or_version (`str` or `packaging.version.Version`):
|
||||
A library name or a version to check.
|
||||
operation (`str`):
|
||||
A string representation of an operator, such as `">"` or `"<="`.
|
||||
requirement_version (`str`):
|
||||
The version to compare the library version against
|
||||
"""
|
||||
if operation not in STR_OPERATION_TO_FUNC.keys():
|
||||
raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}")
|
||||
operation = STR_OPERATION_TO_FUNC[operation]
|
||||
if isinstance(library_or_version, str):
|
||||
library_or_version = parse(importlib_metadata.version(library_or_version))
|
||||
return operation(library_or_version, parse(requirement_version))
|
||||
|
||||
|
||||
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
|
||||
def is_torch_version(operation: str, version: str):
|
||||
"""
|
||||
Args:
|
||||
Compares the current PyTorch version to a given reference with an operation.
|
||||
operation (`str`):
|
||||
A string representation of an operator, such as `">"` or `"<="`
|
||||
version (`str`):
|
||||
A string version of PyTorch
|
||||
"""
|
||||
return compare_versions(parse(_torch_version), operation, version)
|
||||
|
||||
@@ -52,13 +52,13 @@ class CheckDummiesTester(unittest.TestCase):
|
||||
def test_read_init(self):
|
||||
objects = read_init()
|
||||
# We don't assert on the exact list of keys to allow for smooth grow of backend-specific objects
|
||||
self.assertIn("torch_and_accelerate", objects)
|
||||
self.assertIn("torch", objects)
|
||||
self.assertIn("torch_and_transformers", objects)
|
||||
self.assertIn("flax_and_transformers", objects)
|
||||
self.assertIn("torch_and_transformers_and_onnx", objects)
|
||||
|
||||
# Likewise, we can't assert on the exact content of a key
|
||||
self.assertIn("UNet2DModel", objects["torch_and_accelerate"])
|
||||
self.assertIn("UNet2DModel", objects["torch"])
|
||||
self.assertIn("FlaxUNet2DConditionModel", objects["flax"])
|
||||
self.assertIn("StableDiffusionPipeline", objects["torch_and_transformers"])
|
||||
self.assertIn("FlaxStableDiffusionPipeline", objects["flax_and_transformers"])
|
||||
|
||||
Reference in New Issue
Block a user