mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-07 21:14:44 +08:00
Compare commits
6 Commits
test-disab
...
v0.26.2-pa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7d52558c15 | ||
|
|
3efe355d52 | ||
|
|
08e6558ab8 | ||
|
|
1547720209 | ||
|
|
674d43fd68 | ||
|
|
e7a16666ea |
@@ -70,7 +70,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.25.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -40,8 +40,7 @@ from diffusers.utils import BaseOutput, check_min_version
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
|
|
||||||
class MarigoldDepthOutput(BaseOutput):
|
class MarigoldDepthOutput(BaseOutput):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -538,7 +538,7 @@ class StableDiffusionReferencePipeline(StableDiffusionPipeline):
|
|||||||
|
|
||||||
return hidden_states, output_states
|
return hidden_states, output_states
|
||||||
|
|
||||||
def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
|
def hacked_DownBlock2D_forward(self, hidden_states, temb=None, **kwargs):
|
||||||
eps = 1e-6
|
eps = 1e-6
|
||||||
|
|
||||||
output_states = ()
|
output_states = ()
|
||||||
@@ -634,7 +634,9 @@ class StableDiffusionReferencePipeline(StableDiffusionPipeline):
|
|||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
|
def hacked_UpBlock2D_forward(
|
||||||
|
self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, **kwargs
|
||||||
|
):
|
||||||
eps = 1e-6
|
eps = 1e-6
|
||||||
for i, resnet in enumerate(self.resnets):
|
for i, resnet in enumerate(self.resnets):
|
||||||
# pop res hidden states
|
# pop res hidden states
|
||||||
|
|||||||
@@ -507,7 +507,7 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
|
|||||||
|
|
||||||
return hidden_states, output_states
|
return hidden_states, output_states
|
||||||
|
|
||||||
def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
|
def hacked_DownBlock2D_forward(self, hidden_states, temb=None, **kwargs):
|
||||||
eps = 1e-6
|
eps = 1e-6
|
||||||
|
|
||||||
output_states = ()
|
output_states = ()
|
||||||
@@ -603,7 +603,9 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
|
|||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
|
def hacked_UpBlock2D_forward(
|
||||||
|
self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, **kwargs
|
||||||
|
):
|
||||||
eps = 1e-6
|
eps = 1e-6
|
||||||
for i, resnet in enumerate(self.resnets):
|
for i, resnet in enumerate(self.resnets):
|
||||||
# pop res hidden states
|
# pop res hidden states
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
# Cache compiled models across invocations of this script.
|
# Cache compiled models across invocations of this script.
|
||||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ if is_wandb_available():
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ if is_wandb_available():
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ if is_wandb_available():
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ from diffusers.utils import check_min_version
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ else:
|
|||||||
# ------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.25.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ if is_wandb_available():
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ if is_wandb_available():
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.26.0.dev0")
|
check_min_version("0.26.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -249,7 +249,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="diffusers",
|
name="diffusers",
|
||||||
version="0.26.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
version="0.26.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||||
description="State-of-the-art diffusion in PyTorch and JAX.",
|
description="State-of-the-art diffusion in PyTorch and JAX.",
|
||||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
__version__ = "0.26.0.dev0"
|
__version__ = "0.26.2"
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
|||||||
@@ -158,6 +158,12 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.only_cross_attention = only_cross_attention
|
self.only_cross_attention = only_cross_attention
|
||||||
|
|
||||||
|
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
|
||||||
|
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
||||||
|
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
|
||||||
|
self.use_layer_norm = norm_type == "layer_norm"
|
||||||
|
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
|
||||||
|
|
||||||
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from typing import Any, Dict, Iterable, List, Optional, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torchvision import transforms
|
|
||||||
|
|
||||||
from .models import UNet2DConditionModel
|
from .models import UNet2DConditionModel
|
||||||
from .utils import (
|
from .utils import (
|
||||||
@@ -13,6 +12,7 @@ from .utils import (
|
|||||||
convert_state_dict_to_peft,
|
convert_state_dict_to_peft,
|
||||||
deprecate,
|
deprecate,
|
||||||
is_peft_available,
|
is_peft_available,
|
||||||
|
is_torchvision_available,
|
||||||
is_transformers_available,
|
is_transformers_available,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -23,6 +23,9 @@ if is_transformers_available():
|
|||||||
if is_peft_available():
|
if is_peft_available():
|
||||||
from peft import set_peft_model_state_dict
|
from peft import set_peft_model_state_dict
|
||||||
|
|
||||||
|
if is_torchvision_available():
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
|
||||||
def set_seed(seed: int):
|
def set_seed(seed: int):
|
||||||
"""
|
"""
|
||||||
@@ -79,6 +82,11 @@ def resolve_interpolation_mode(interpolation_type: str):
|
|||||||
`torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize`
|
`torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize`
|
||||||
transform.
|
transform.
|
||||||
"""
|
"""
|
||||||
|
if not is_torchvision_available():
|
||||||
|
raise ImportError(
|
||||||
|
"Please make sure to install `torchvision` to be able to use the `resolve_interpolation_mode()` function."
|
||||||
|
)
|
||||||
|
|
||||||
if interpolation_type == "bilinear":
|
if interpolation_type == "bilinear":
|
||||||
interpolation_mode = transforms.InterpolationMode.BILINEAR
|
interpolation_mode = transforms.InterpolationMode.BILINEAR
|
||||||
elif interpolation_type == "bicubic":
|
elif interpolation_type == "bicubic":
|
||||||
|
|||||||
@@ -75,6 +75,7 @@ from .import_utils import (
|
|||||||
is_torch_version,
|
is_torch_version,
|
||||||
is_torch_xla_available,
|
is_torch_xla_available,
|
||||||
is_torchsde_available,
|
is_torchsde_available,
|
||||||
|
is_torchvision_available,
|
||||||
is_transformers_available,
|
is_transformers_available,
|
||||||
is_transformers_version,
|
is_transformers_version,
|
||||||
is_unidecode_available,
|
is_unidecode_available,
|
||||||
|
|||||||
@@ -278,6 +278,13 @@ try:
|
|||||||
except importlib_metadata.PackageNotFoundError:
|
except importlib_metadata.PackageNotFoundError:
|
||||||
_peft_available = False
|
_peft_available = False
|
||||||
|
|
||||||
|
_torchvision_available = importlib.util.find_spec("torchvision") is not None
|
||||||
|
try:
|
||||||
|
_torchvision_version = importlib_metadata.version("torchvision")
|
||||||
|
logger.debug(f"Successfully imported torchvision version {_torchvision_version}")
|
||||||
|
except importlib_metadata.PackageNotFoundError:
|
||||||
|
_torchvision_available = False
|
||||||
|
|
||||||
|
|
||||||
def is_torch_available():
|
def is_torch_available():
|
||||||
return _torch_available
|
return _torch_available
|
||||||
@@ -367,6 +374,10 @@ def is_peft_available():
|
|||||||
return _peft_available
|
return _peft_available
|
||||||
|
|
||||||
|
|
||||||
|
def is_torchvision_available():
|
||||||
|
return _torchvision_available
|
||||||
|
|
||||||
|
|
||||||
# docstyle-ignore
|
# docstyle-ignore
|
||||||
FLAX_IMPORT_ERROR = """
|
FLAX_IMPORT_ERROR = """
|
||||||
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
|
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
|
||||||
|
|||||||
Reference in New Issue
Block a user