Mirror of https://github.com/huggingface/diffusers.git, synced 2025-12-11 06:54:32 +08:00.

Compare commits: v0.32.0 ... v0.26.3-pa (9 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 66f94eaa0c | |
| | c1f2609a40 | |
| | 552634d688 | |
| | 7d52558c15 | |
| | 3efe355d52 | |
| | 08e6558ab8 | |
| | 1547720209 | |
| | 674d43fd68 | |
| | e7a16666ea | |
```diff
@@ -70,7 +70,7 @@ from diffusers.utils.import_utils import is_xformers_available


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.25.0.dev0")
+check_min_version("0.26.0")

 logger = get_logger(__name__)

```
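The same one-line floor bump recurs across the example training scripts below. `check_min_version` raises an error at import time when the installed diffusers is older than the given floor, and a `0.26.0.dev0` development build does not satisfy a `0.26.0` floor, which is why the release branch pins the plain release string. A minimal sketch of the guard as a script would use it (the trailing import is illustrative):

```python
from diffusers.utils import check_min_version

# Fails fast at import time if the installed diffusers is older than
# 0.26.0, instead of breaking partway through a training run.
check_min_version("0.26.0")

from diffusers import DiffusionPipeline  # safe to rely on 0.26+ APIs now
```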
```diff
@@ -72,7 +72,7 @@ from diffusers.utils.torch_utils import is_compiled_module


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.26.0")

 logger = get_logger(__name__)

```
```diff
@@ -40,8 +40,7 @@ from diffusers.utils import BaseOutput, check_min_version


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
-
+check_min_version("0.26.0")

 class MarigoldDepthOutput(BaseOutput):
     """
```
```diff
@@ -538,7 +538,7 @@ class StableDiffusionReferencePipeline(StableDiffusionPipeline):
             return hidden_states, output_states

-        def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
+        def hacked_DownBlock2D_forward(self, hidden_states, temb=None, **kwargs):
             eps = 1e-6

             output_states = ()
@@ -634,7 +634,9 @@ class StableDiffusionReferencePipeline(StableDiffusionPipeline):
             return hidden_states

-        def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
+        def hacked_UpBlock2D_forward(
+            self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, **kwargs
+        ):
             eps = 1e-6
             for i, resnet in enumerate(self.resnets):
                 # pop res hidden states
```

```diff
@@ -507,7 +507,7 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
             return hidden_states, output_states

-        def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
+        def hacked_DownBlock2D_forward(self, hidden_states, temb=None, **kwargs):
             eps = 1e-6

             output_states = ()
@@ -603,7 +603,9 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
             return hidden_states

-        def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
+        def hacked_UpBlock2D_forward(
+            self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, **kwargs
+        ):
             eps = 1e-6
             for i, resnet in enumerate(self.resnets):
                 # pop res hidden states
```
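Accepting `**kwargs` keeps these monkey-patched block forwards working when newer diffusers code paths pass extra keyword arguments (such as a LoRA `scale`) into `DownBlock2D.forward` and `UpBlock2D.forward`. A minimal sketch of the failure mode being avoided (names are illustrative):

```python
def patched_forward(hidden_states, temb=None):
    return hidden_states

def patched_forward_kwargs(hidden_states, temb=None, **kwargs):
    # Unknown keyword arguments (e.g. scale=1.0) are accepted and ignored.
    return hidden_states

patched_forward_kwargs("h", scale=1.0)   # fine: the extra kwarg is swallowed
# patched_forward("h", scale=1.0)        # would raise TypeError
```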
```diff
@@ -72,7 +72,7 @@ if is_wandb_available():
     import wandb

 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.26.0")

 logger = get_logger(__name__)

@@ -60,7 +60,7 @@ if is_wandb_available():
     import wandb

 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.26.0")

 logger = get_logger(__name__)

@@ -78,7 +78,7 @@ if is_wandb_available():
     import wandb

 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.26.0")

 logger = get_logger(__name__)

@@ -71,7 +71,7 @@ if is_wandb_available():
     import wandb

 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.26.0")

 logger = get_logger(__name__)

@@ -77,7 +77,7 @@ if is_wandb_available():
     import wandb

 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.26.0")

 logger = get_logger(__name__)

@@ -57,7 +57,7 @@ if is_wandb_available():
     import wandb

 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.26.0")

 logger = get_logger(__name__)

@@ -59,7 +59,7 @@ if is_wandb_available():
     import wandb

 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.26.0")

 logger = logging.getLogger(__name__)

@@ -59,7 +59,7 @@ if is_wandb_available():
     import wandb

 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.26.0")

 logger = get_logger(__name__)

@@ -62,7 +62,7 @@ from diffusers.utils.import_utils import is_xformers_available


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.26.0")

 logger = get_logger(__name__)

@@ -62,7 +62,7 @@ if is_wandb_available():
     import wandb

 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.26.0")

 logger = get_logger(__name__)

@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.26.0")

 # Cache compiled models across invocations of this script.
 cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
@@ -66,7 +66,7 @@ from diffusers.utils.torch_utils import is_compiled_module


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.26.0")

 logger = get_logger(__name__)

@@ -65,7 +65,7 @@ from diffusers.utils.torch_utils import is_compiled_module


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.26.0")

 logger = get_logger(__name__)

@@ -53,7 +53,7 @@ from diffusers.utils.torch_utils import is_compiled_module


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.26.0")

 logger = get_logger(__name__, log_level="INFO")

@@ -59,7 +59,7 @@ if is_wandb_available():
     import wandb

 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.26.0")

 logger = get_logger(__name__, log_level="INFO")

@@ -52,7 +52,7 @@ if is_wandb_available():


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.26.0")

 logger = get_logger(__name__, log_level="INFO")

@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.26.0")

 logger = get_logger(__name__, log_level="INFO")

@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.26.0")

 logger = get_logger(__name__, log_level="INFO")

@@ -51,7 +51,7 @@ if is_wandb_available():


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.26.0")

 logger = get_logger(__name__, log_level="INFO")

@@ -59,7 +59,7 @@ if is_wandb_available():
     import wandb

 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.26.0")

 logger = get_logger(__name__)

@@ -54,7 +54,7 @@ if is_wandb_available():


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.26.0")

 logger = get_logger(__name__, log_level="INFO")

@@ -33,7 +33,7 @@ from diffusers.utils import check_min_version


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.26.0")

 logger = logging.getLogger(__name__)

@@ -50,7 +50,7 @@ from diffusers.utils.torch_utils import is_compiled_module


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.26.0")

 logger = get_logger(__name__, log_level="INFO")

@@ -63,7 +63,7 @@ from diffusers.utils.torch_utils import is_compiled_module


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.26.0")

 logger = get_logger(__name__)

@@ -53,7 +53,7 @@ from diffusers.utils.torch_utils import is_compiled_module


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.26.0")

 logger = get_logger(__name__)

@@ -79,7 +79,7 @@ else:


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.26.0")

 logger = get_logger(__name__)

@@ -56,7 +56,7 @@ else:
 # ------------------------------------------------------------------------------

 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.26.0")

 logger = logging.getLogger(__name__)

@@ -77,7 +77,7 @@ else:


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.25.0.dev0")
+check_min_version("0.26.0")

 logger = get_logger(__name__)

@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.26.0")

 logger = get_logger(__name__, log_level="INFO")

@@ -50,7 +50,7 @@ if is_wandb_available():


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.26.0")

 logger = get_logger(__name__, log_level="INFO")

@@ -51,7 +51,7 @@ if is_wandb_available():


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.26.0")

 logger = get_logger(__name__, log_level="INFO")

```
setup.py (2 changed lines):

```diff
@@ -249,7 +249,7 @@ version_range_max = max(sys.version_info[1], 10) + 1

 setup(
     name="diffusers",
-    version="0.26.0.dev0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+    version="0.26.3",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
     description="State-of-the-art diffusion in PyTorch and JAX.",
     long_description=open("README.md", "r", encoding="utf-8").read(),
     long_description_content_type="text/markdown",
```

```diff
@@ -1,4 +1,4 @@
-__version__ = "0.26.0.dev0"
+__version__ = "0.26.3"

 from typing import TYPE_CHECKING

```
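With `setup.py` and the package `__init__` both moved from the `0.26.0.dev0` development marker to the `0.26.3` patch release, an installed copy of this branch satisfies the `check_min_version("0.26.0")` floor used by the scripts above. A quick illustrative check:

```python
import diffusers
from packaging import version

print(diffusers.__version__)  # "0.26.3" on this branch
assert version.parse(diffusers.__version__) >= version.parse("0.26.0")
```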
```diff
@@ -38,6 +38,9 @@ class FromOriginalVAEMixin:
                     - A link to the `.ckpt` file (for example
                       `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
                     - A path to a *file* containing all pipeline weights.
+            config_file (`str`, *optional*):
+                Filepath to the configuration YAML file associated with the model. If not provided it will default to:
+                https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
             torch_dtype (`str` or `torch.dtype`, *optional*):
                 Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
                 dtype is automatically derived from the model's weights.
@@ -65,6 +68,13 @@ class FromOriginalVAEMixin:
             image_size (`int`, *optional*, defaults to 512):
                 The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
                 Diffusion v2 base model. Use 768 for Stable Diffusion v2.
+            scaling_factor (`float`, *optional*, defaults to 0.18215):
+                The component-wise standard deviation of the trained latent space computed using the first batch of the
+                training set. This is used to scale the latent space to have unit variance when training the diffusion
+                model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+                diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z
+                = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution
+                Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
             use_safetensors (`bool`, *optional*, defaults to `None`):
                 If set to `None`, the safetensors weights are downloaded if they're available **and** if the
                 safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
@@ -92,6 +102,7 @@ class FromOriginalVAEMixin:
         """

         original_config_file = kwargs.pop("original_config_file", None)
+        config_file = kwargs.pop("config_file", None)
         resume_download = kwargs.pop("resume_download", False)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
@@ -103,6 +114,13 @@ class FromOriginalVAEMixin:
         use_safetensors = kwargs.pop("use_safetensors", True)

         class_name = cls.__name__
+
+        if (config_file is not None) and (original_config_file is not None):
+            raise ValueError(
+                "You cannot pass both `config_file` and `original_config_file` to `from_single_file`. Please use only one of these arguments."
+            )
+
+        original_config_file = original_config_file or config_file
         original_config, checkpoint = fetch_ldm_config_and_checkpoint(
             pretrained_model_link_or_path=pretrained_model_link_or_path,
             class_name=class_name,
@@ -118,7 +136,10 @@ class FromOriginalVAEMixin:
         )

         image_size = kwargs.pop("image_size", None)
-        component = create_diffusers_vae_model_from_ldm(class_name, original_config, checkpoint, image_size=image_size)
+        scaling_factor = kwargs.pop("scaling_factor", None)
+        component = create_diffusers_vae_model_from_ldm(
+            class_name, original_config, checkpoint, image_size=image_size, scaling_factor=scaling_factor
+        )
         vae = component["vae"]
         if torch_dtype is not None:
             vae = vae.to(torch_dtype)
```
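Together these changes expose two new keyword arguments on the VAE single-file loader: `config_file` points at an explicit LDM YAML, and `scaling_factor` overrides the latent scaling factor. A hedged usage sketch (the paths are placeholders; 0.13025 is the value SDXL-style VAEs use):

```python
from diffusers import AutoencoderKL

# Both keyword arguments are optional; passing `config_file` together with
# the older `original_config_file` now raises a ValueError.
vae = AutoencoderKL.from_single_file(
    "path/to/vae.safetensors",                # placeholder checkpoint path
    config_file="path/to/v1-inference.yaml",  # placeholder YAML path
    scaling_factor=0.13025,
)
print(vae.config.scaling_factor)  # 0.13025
```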
```diff
@@ -175,6 +175,7 @@ DIFFUSERS_TO_LDM_MAPPING = {
 }

 LDM_VAE_KEY = "first_stage_model."
+LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
 LDM_UNET_KEY = "model.diffusion_model."
 LDM_CONTROLNET_KEY = "control_model."
 LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."]
@@ -518,7 +519,10 @@ def create_vae_diffusers_config(original_config, image_size, scaling_factor=None
     Creates a config for the diffusers based on the config of the LDM model.
     """
     vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
-    scaling_factor = scaling_factor or original_config["model"]["params"]["scale_factor"]
+    if scaling_factor is None and "scale_factor" in original_config["model"]["params"]:
+        scaling_factor = original_config["model"]["params"]["scale_factor"]
+    elif scaling_factor is None:
+        scaling_factor = LDM_VAE_DEFAULT_SCALING_FACTOR

     block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
     down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
```
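The replaced `or` expression raised a `KeyError` whenever the original config had no `scale_factor` entry, and, as the next hunk shows, the old `0.18125` default (note the transposed digits) meant the config value could never win. The new branches give a three-level precedence, sketched standalone here:

```python
LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215

def resolve_scaling_factor(scaling_factor, original_config):
    params = original_config["model"]["params"]
    if scaling_factor is not None:           # 1. explicit user override
        return scaling_factor
    if "scale_factor" in params:             # 2. value stored in the LDM config
        return params["scale_factor"]
    return LDM_VAE_DEFAULT_SCALING_FACTOR    # 3. Stable Diffusion v1 default

assert resolve_scaling_factor(None, {"model": {"params": {}}}) == 0.18215
assert resolve_scaling_factor(0.13025, {"model": {"params": {}}}) == 0.13025
```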
```diff
@@ -1174,7 +1178,7 @@ def create_diffusers_unet_model_from_ldm(


 def create_diffusers_vae_model_from_ldm(
-    pipeline_class_name, original_config, checkpoint, image_size=None, scaling_factor=0.18125
+    pipeline_class_name, original_config, checkpoint, image_size=None, scaling_factor=None
 ):
     # import here to avoid circular imports
     from ..models import AutoencoderKL
```

```diff
@@ -158,6 +158,12 @@ class BasicTransformerBlock(nn.Module):
         super().__init__()
         self.only_cross_attention = only_cross_attention

+        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+        self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
+        self.use_layer_norm = norm_type == "layer_norm"
+        self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
+
         if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
             raise ValueError(
                 f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
```
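These boolean attributes are derived entirely from `norm_type` (plus `num_embeds_ada_norm` for the two `ada_norm*` variants), so at most one of them is true for any configuration. A standalone sketch of the mapping:

```python
def norm_flags(norm_type, num_embeds_ada_norm=None):
    return {
        "ada_norm_zero": (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero",
        "ada_norm": (num_embeds_ada_norm is not None) and norm_type == "ada_norm",
        "ada_norm_single": norm_type == "ada_norm_single",
        "layer_norm": norm_type == "layer_norm",
        "ada_norm_continuous": norm_type == "ada_norm_continuous",
    }

flags = norm_flags("layer_norm")
assert flags["layer_norm"] and sum(flags.values()) == 1
```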
```diff
@@ -151,7 +151,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         sample_max_value: float = 1.0,
         algorithm_type: str = "dpmsolver++",
         solver_type: str = "midpoint",
-        lower_order_final: bool = True,
+        lower_order_final: bool = False,
         use_karras_sigmas: Optional[bool] = False,
         final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
         lambda_min_clipped: float = -float("inf"),
@@ -232,7 +232,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
             orders = [1, 2, 3] * (steps // 3) + [1, 2]
         elif order == 2:
             if steps % 2 == 0:
-                orders = [1, 2] * (steps // 2)
+                orders = [1, 2] * (steps // 2 - 1) + [1, 1]
             else:
                 orders = [1, 2] * (steps // 2) + [1]
         elif order == 1:
```
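For an even number of steps, the singlestep solver now finishes with two first-order updates instead of a final second-order one, presumably to avoid a higher-order step into the zero terminal sigma used by the `final_sigmas_type="zero"` default above. A worked example of the two formulas:

```python
steps = 6
old_orders = [1, 2] * (steps // 2)                 # [1, 2, 1, 2, 1, 2]
new_orders = [1, 2] * (steps // 2 - 1) + [1, 1]    # [1, 2, 1, 2, 1, 1]
assert len(old_orders) == len(new_orders) == steps
```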
```diff
@@ -301,7 +301,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):

         if not self.config.lower_order_final and num_inference_steps % self.config.solver_order != 0:
             logger.warn(
-                "Changing scheduler {self.config} to have `lower_order_final` set to True to handle uneven amount of inference steps. Please make sure to always use an even number of `num_inference steps when using `lower_order_final=True`."
+                "Changing scheduler {self.config} to have `lower_order_final` set to True to handle uneven amount of inference steps. Please make sure to always use an even number of `num_inference steps when using `lower_order_final=False`."
             )
             self.register_to_config(lower_order_final=True)

```
```diff
@@ -5,7 +5,6 @@ from typing import Any, Dict, Iterable, List, Optional, Union

 import numpy as np
 import torch
-from torchvision import transforms

 from .models import UNet2DConditionModel
 from .utils import (
@@ -13,6 +12,7 @@ from .utils import (
     convert_state_dict_to_peft,
     deprecate,
     is_peft_available,
+    is_torchvision_available,
     is_transformers_available,
 )

@@ -23,6 +23,9 @@ if is_transformers_available():
 if is_peft_available():
     from peft import set_peft_model_state_dict

+if is_torchvision_available():
+    from torchvision import transforms
+

 def set_seed(seed: int):
     """
```
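With the import moved behind `is_torchvision_available()`, `diffusers.training_utils` imports cleanly without torchvision; only code paths that actually need `transforms` require it. Downstream code can reuse the same guard, e.g.:

```python
from diffusers.utils import is_torchvision_available

if is_torchvision_available():
    from torchvision import transforms

    resize = transforms.Resize(512)
else:
    resize = None  # degrade gracefully instead of failing at import time
```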
```diff
@@ -79,6 +82,11 @@ def resolve_interpolation_mode(interpolation_type: str):
         `torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize`
         transform.
     """
+    if not is_torchvision_available():
+        raise ImportError(
+            "Please make sure to install `torchvision` to be able to use the `resolve_interpolation_mode()` function."
+        )
+
     if interpolation_type == "bilinear":
         interpolation_mode = transforms.InterpolationMode.BILINEAR
     elif interpolation_type == "bicubic":
```
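The new check turns a missing optional dependency into an explicit, actionable error at call time rather than an import-time crash far from the cause. Illustrative call (assumes torchvision is installed):

```python
from diffusers.training_utils import resolve_interpolation_mode

# Maps a string name to a torchvision InterpolationMode enum member;
# raises ImportError with an install hint if torchvision is missing.
mode = resolve_interpolation_mode("bilinear")
```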
```diff
@@ -75,6 +75,7 @@ from .import_utils import (
     is_torch_version,
     is_torch_xla_available,
     is_torchsde_available,
+    is_torchvision_available,
     is_transformers_available,
     is_transformers_version,
     is_unidecode_available,
```

```diff
@@ -278,6 +278,13 @@ try:
 except importlib_metadata.PackageNotFoundError:
     _peft_available = False

+_torchvision_available = importlib.util.find_spec("torchvision") is not None
+try:
+    _torchvision_version = importlib_metadata.version("torchvision")
+    logger.debug(f"Successfully imported torchvision version {_torchvision_version}")
+except importlib_metadata.PackageNotFoundError:
+    _torchvision_available = False
+

 def is_torch_available():
     return _torch_available
```
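The probe follows the file's existing two-step convention: `importlib.util.find_spec` asks whether the package is importable, then `importlib_metadata.version` confirms an installed distribution and downgrades availability when none is found. A generic sketch of the same pattern (the `probe` helper is hypothetical):

```python
import importlib.util
import importlib.metadata as importlib_metadata

def probe(package: str):
    available = importlib.util.find_spec(package) is not None
    try:
        version = importlib_metadata.version(package)
    except importlib_metadata.PackageNotFoundError:
        available, version = False, None
    return available, version

print(probe("torchvision"))  # e.g. (True, "0.17.0") when installed
```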
```diff
@@ -367,6 +374,10 @@ def is_peft_available():
     return _peft_available


+def is_torchvision_available():
+    return _torchvision_available
+
+
 # docstyle-ignore
 FLAX_IMPORT_ERROR = """
 {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
```