Compare commits

...

3 Commits

Author SHA1 Message Date
Patrick von Platen
6bbde99c4a Release: v0.17.0 2023-06-08 16:55:06 +02:00
Patrick von Platen
5916743b22 Fix loading if unexpected keys are present (#3720)
* Fix loading

* make style
2023-06-08 16:52:45 +02:00
Patrick von Platen
7ddc4a1a9f Fix custom releases (#3708)
* Fix custom releases

* make style
2023-06-07 18:35:41 +02:00
18 changed files with 36 additions and 17 deletions

View File

@@ -55,7 +55,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.17.0.dev0")
check_min_version("0.17.0")
logger = get_logger(__name__)

View File

@@ -59,7 +59,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.17.0.dev0")
check_min_version("0.17.0")
logger = logging.getLogger(__name__)

View File

@@ -56,7 +56,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.17.0.dev0")
check_min_version("0.17.0")
logger = get_logger(__name__)

View File

@@ -58,7 +58,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.17.0.dev0")
check_min_version("0.17.0")
logger = get_logger(__name__)

View File

@@ -36,7 +36,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.17.0.dev0")
check_min_version("0.17.0")
# Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))

View File

@@ -64,7 +64,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.17.0.dev0")
check_min_version("0.17.0")
logger = get_logger(__name__)

View File

@@ -51,7 +51,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.17.0.dev0")
check_min_version("0.17.0")
logger = get_logger(__name__, log_level="INFO")

View File

@@ -52,7 +52,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.17.0.dev0")
check_min_version("0.17.0")
logger = get_logger(__name__, log_level="INFO")

View File

@@ -33,7 +33,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.17.0.dev0")
check_min_version("0.17.0")
logger = logging.getLogger(__name__)

View File

@@ -47,7 +47,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.17.0.dev0")
check_min_version("0.17.0")
logger = get_logger(__name__, log_level="INFO")

View File

@@ -77,7 +77,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.17.0.dev0")
check_min_version("0.17.0")
logger = get_logger(__name__)

View File

@@ -56,7 +56,7 @@ else:
# ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.17.0.dev0")
check_min_version("0.17.0")
logger = logging.getLogger(__name__)

View File

@@ -28,7 +28,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.17.0.dev0")
check_min_version("0.17.0")
logger = get_logger(__name__, log_level="INFO")

View File

@@ -227,7 +227,7 @@ install_requires = [
setup(
name="diffusers",
version="0.17.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.17.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="Diffusers",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",

View File

@@ -1,4 +1,4 @@
__version__ = "0.17.0.dev0"
__version__ = "0.17.0"
from .configuration_utils import ConfigMixin
from .utils import (

View File

@@ -17,6 +17,7 @@
import inspect
import itertools
import os
import re
from functools import partial
from typing import Any, Callable, List, Optional, Tuple, Union
@@ -162,6 +163,7 @@ class ModelMixin(torch.nn.Module):
config_name = CONFIG_NAME
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
_supports_gradient_checkpointing = False
_keys_to_ignore_on_load_unexpected = None
def __init__(self):
super().__init__()
@@ -608,6 +610,7 @@ class ModelMixin(torch.nn.Module):
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
" those weights or else make sure your checkpoint file is correct."
)
unexpected_keys = []
empty_state_dict = model.state_dict()
for param_name, param in state_dict.items():
@@ -615,6 +618,10 @@ class ModelMixin(torch.nn.Module):
inspect.signature(set_module_tensor_to_device).parameters.keys()
)
if param_name not in empty_state_dict:
unexpected_keys.append(param_name)
continue
if empty_state_dict[param_name].shape != param.shape:
raise ValueError(
f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
@@ -626,6 +633,16 @@ class ModelMixin(torch.nn.Module):
)
else:
set_module_tensor_to_device(model, param_name, param_device, value=param)
if cls._keys_to_ignore_on_load_unexpected is not None:
for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warn(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else: # else let accelerate handle loading and dispatching.
# Load weights and dispatch according to the device_map
# by default the device_map is None and the weights are loaded on the CPU

View File

@@ -61,6 +61,8 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
dot-product/softmax to float() when training with mixed precision.
"""
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]
@register_to_config
def __init__(
self,

View File

@@ -21,12 +21,12 @@ import os
import re
import shutil
import sys
from distutils.version import StrictVersion
from pathlib import Path
from typing import Dict, Optional, Union
from urllib import request
from huggingface_hub import HfFolder, cached_download, hf_hub_download, model_info
from packaging import version
from .. import __version__
from . import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
@@ -43,7 +43,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def get_diffusers_versions():
url = "https://pypi.org/pypi/diffusers/json"
releases = json.loads(request.urlopen(url).read())["releases"].keys()
return sorted(releases, key=StrictVersion)
return sorted(releases, key=lambda x: version.Version(x))
def init_hf_modules():