Compare commits

...

6 Commits

Author SHA1 Message Date
sayakpaul
7d52558c15 Release: v0.26.2-patch 2024-02-06 07:36:31 +05:30
YiYi Xu
3efe355d52 add self.use_ada_layer_norm_* params back to BasicTransformerBlock (#6841)
fix sd reference community pipeline

Co-authored-by: yiyixuxu <yixu310@gmail.com>
2024-02-06 07:34:36 +05:30
sayakpaul
08e6558ab8 Release: v0.26.1-patch 2024-02-02 14:42:23 +05:30
YiYi Xu
1547720209 add is_torchvision_available (#6800)
* add

* remove transformer

---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
2024-02-02 14:40:36 +05:30
Patrick von Platen
674d43fd68 fix torchvision import (#6796) 2024-02-01 00:15:09 +02:00
yiyixuxu
e7a16666ea Release: v0.26.0 2024-01-31 11:31:57 -10:00
42 changed files with 71 additions and 42 deletions

View File

@@ -70,7 +70,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0") check_min_version("0.26.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -72,7 +72,7 @@ from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -40,8 +40,7 @@ from diffusers.utils import BaseOutput, check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
class MarigoldDepthOutput(BaseOutput): class MarigoldDepthOutput(BaseOutput):
""" """

View File

@@ -538,7 +538,7 @@ class StableDiffusionReferencePipeline(StableDiffusionPipeline):
return hidden_states, output_states return hidden_states, output_states
def hacked_DownBlock2D_forward(self, hidden_states, temb=None): def hacked_DownBlock2D_forward(self, hidden_states, temb=None, **kwargs):
eps = 1e-6 eps = 1e-6
output_states = () output_states = ()
@@ -634,7 +634,9 @@ class StableDiffusionReferencePipeline(StableDiffusionPipeline):
return hidden_states return hidden_states
def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): def hacked_UpBlock2D_forward(
self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, **kwargs
):
eps = 1e-6 eps = 1e-6
for i, resnet in enumerate(self.resnets): for i, resnet in enumerate(self.resnets):
# pop res hidden states # pop res hidden states

View File

@@ -507,7 +507,7 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
return hidden_states, output_states return hidden_states, output_states
def hacked_DownBlock2D_forward(self, hidden_states, temb=None): def hacked_DownBlock2D_forward(self, hidden_states, temb=None, **kwargs):
eps = 1e-6 eps = 1e-6
output_states = () output_states = ()
@@ -603,7 +603,9 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
return hidden_states return hidden_states
def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): def hacked_UpBlock2D_forward(
self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, **kwargs
):
eps = 1e-6 eps = 1e-6
for i, resnet in enumerate(self.resnets): for i, resnet in enumerate(self.resnets):
# pop res hidden states # pop res hidden states

View File

@@ -72,7 +72,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -78,7 +78,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -71,7 +71,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -77,7 +77,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -57,7 +57,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -59,7 +59,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -59,7 +59,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -62,7 +62,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -62,7 +62,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
# Cache compiled models across invocations of this script. # Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache")) cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))

View File

@@ -66,7 +66,7 @@ from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -65,7 +65,7 @@ from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -53,7 +53,7 @@ from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")

View File

@@ -59,7 +59,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")

View File

@@ -52,7 +52,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")

View File

@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")

View File

@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")

View File

@@ -51,7 +51,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")

View File

@@ -59,7 +59,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -54,7 +54,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")

View File

@@ -33,7 +33,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -50,7 +50,7 @@ from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")

View File

@@ -63,7 +63,7 @@ from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -53,7 +53,7 @@ from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -79,7 +79,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -56,7 +56,7 @@ else:
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -77,7 +77,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0") check_min_version("0.26.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")

View File

@@ -50,7 +50,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")

View File

@@ -51,7 +51,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")

View File

@@ -249,7 +249,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
setup( setup(
name="diffusers", name="diffusers",
version="0.26.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) version="0.26.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.", description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(), long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown", long_description_content_type="text/markdown",

View File

@@ -1,4 +1,4 @@
__version__ = "0.26.0.dev0" __version__ = "0.26.2"
from typing import TYPE_CHECKING from typing import TYPE_CHECKING

View File

@@ -158,6 +158,12 @@ class BasicTransformerBlock(nn.Module):
super().__init__() super().__init__()
self.only_cross_attention = only_cross_attention self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
self.use_layer_norm = norm_type == "layer_norm"
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError( raise ValueError(
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"

View File

@@ -5,7 +5,6 @@ from typing import Any, Dict, Iterable, List, Optional, Union
import numpy as np import numpy as np
import torch import torch
from torchvision import transforms
from .models import UNet2DConditionModel from .models import UNet2DConditionModel
from .utils import ( from .utils import (
@@ -13,6 +12,7 @@ from .utils import (
convert_state_dict_to_peft, convert_state_dict_to_peft,
deprecate, deprecate,
is_peft_available, is_peft_available,
is_torchvision_available,
is_transformers_available, is_transformers_available,
) )
@@ -23,6 +23,9 @@ if is_transformers_available():
if is_peft_available(): if is_peft_available():
from peft import set_peft_model_state_dict from peft import set_peft_model_state_dict
if is_torchvision_available():
from torchvision import transforms
def set_seed(seed: int): def set_seed(seed: int):
""" """
@@ -79,6 +82,11 @@ def resolve_interpolation_mode(interpolation_type: str):
`torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize` `torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize`
transform. transform.
""" """
if not is_torchvision_available():
raise ImportError(
"Please make sure to install `torchvision` to be able to use the `resolve_interpolation_mode()` function."
)
if interpolation_type == "bilinear": if interpolation_type == "bilinear":
interpolation_mode = transforms.InterpolationMode.BILINEAR interpolation_mode = transforms.InterpolationMode.BILINEAR
elif interpolation_type == "bicubic": elif interpolation_type == "bicubic":

View File

@@ -75,6 +75,7 @@ from .import_utils import (
is_torch_version, is_torch_version,
is_torch_xla_available, is_torch_xla_available,
is_torchsde_available, is_torchsde_available,
is_torchvision_available,
is_transformers_available, is_transformers_available,
is_transformers_version, is_transformers_version,
is_unidecode_available, is_unidecode_available,

View File

@@ -278,6 +278,13 @@ try:
except importlib_metadata.PackageNotFoundError: except importlib_metadata.PackageNotFoundError:
_peft_available = False _peft_available = False
_torchvision_available = importlib.util.find_spec("torchvision") is not None
try:
_torchvision_version = importlib_metadata.version("torchvision")
logger.debug(f"Successfully imported torchvision version {_torchvision_version}")
except importlib_metadata.PackageNotFoundError:
_torchvision_available = False
def is_torch_available(): def is_torch_available():
return _torch_available return _torch_available
@@ -367,6 +374,10 @@ def is_peft_available():
return _peft_available return _peft_available
def is_torchvision_available():
return _torchvision_available
# docstyle-ignore # docstyle-ignore
FLAX_IMPORT_ERROR = """ FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the