Compare commits

...

3 Commits

Author SHA1 Message Date
Patrick von Platen
6bbde99c4a Release: v0.17.0 2023-06-08 16:55:06 +02:00
Patrick von Platen
5916743b22 Fix loading if unexpected keys are present (#3720)
* Fix loading

* make style
2023-06-08 16:52:45 +02:00
Patrick von Platen
7ddc4a1a9f Fix custom releases (#3708)
* Fix custom releases

* make style
2023-06-07 18:35:41 +02:00
18 changed files with 36 additions and 17 deletions

View File

@@ -55,7 +55,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.17.0.dev0") check_min_version("0.17.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -59,7 +59,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.17.0.dev0") check_min_version("0.17.0")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -56,7 +56,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.17.0.dev0") check_min_version("0.17.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -58,7 +58,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.17.0.dev0") check_min_version("0.17.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -36,7 +36,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.17.0.dev0") check_min_version("0.17.0")
# Cache compiled models across invocations of this script. # Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache")) cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))

View File

@@ -64,7 +64,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.17.0.dev0") check_min_version("0.17.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -51,7 +51,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.17.0.dev0") check_min_version("0.17.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")

View File

@@ -52,7 +52,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.17.0.dev0") check_min_version("0.17.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")

View File

@@ -33,7 +33,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.17.0.dev0") check_min_version("0.17.0")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -47,7 +47,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.17.0.dev0") check_min_version("0.17.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")

View File

@@ -77,7 +77,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.17.0.dev0") check_min_version("0.17.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -56,7 +56,7 @@ else:
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.17.0.dev0") check_min_version("0.17.0")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -28,7 +28,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.17.0.dev0") check_min_version("0.17.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")

View File

@@ -227,7 +227,7 @@ install_requires = [
setup( setup(
name="diffusers", name="diffusers",
version="0.17.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) version="0.17.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="Diffusers", description="Diffusers",
long_description=open("README.md", "r", encoding="utf-8").read(), long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown", long_description_content_type="text/markdown",

View File

@@ -1,4 +1,4 @@
__version__ = "0.17.0.dev0" __version__ = "0.17.0"
from .configuration_utils import ConfigMixin from .configuration_utils import ConfigMixin
from .utils import ( from .utils import (

View File

@@ -17,6 +17,7 @@
import inspect import inspect
import itertools import itertools
import os import os
import re
from functools import partial from functools import partial
from typing import Any, Callable, List, Optional, Tuple, Union from typing import Any, Callable, List, Optional, Tuple, Union
@@ -162,6 +163,7 @@ class ModelMixin(torch.nn.Module):
config_name = CONFIG_NAME config_name = CONFIG_NAME
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
_supports_gradient_checkpointing = False _supports_gradient_checkpointing = False
_keys_to_ignore_on_load_unexpected = None
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@@ -608,6 +610,7 @@ class ModelMixin(torch.nn.Module):
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize" " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
" those weights or else make sure your checkpoint file is correct." " those weights or else make sure your checkpoint file is correct."
) )
unexpected_keys = []
empty_state_dict = model.state_dict() empty_state_dict = model.state_dict()
for param_name, param in state_dict.items(): for param_name, param in state_dict.items():
@@ -615,6 +618,10 @@ class ModelMixin(torch.nn.Module):
inspect.signature(set_module_tensor_to_device).parameters.keys() inspect.signature(set_module_tensor_to_device).parameters.keys()
) )
if param_name not in empty_state_dict:
unexpected_keys.append(param_name)
continue
if empty_state_dict[param_name].shape != param.shape: if empty_state_dict[param_name].shape != param.shape:
raise ValueError( raise ValueError(
f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
@@ -626,6 +633,16 @@ class ModelMixin(torch.nn.Module):
) )
else: else:
set_module_tensor_to_device(model, param_name, param_device, value=param) set_module_tensor_to_device(model, param_name, param_device, value=param)
if cls._keys_to_ignore_on_load_unexpected is not None:
for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warn(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else: # else let accelerate handle loading and dispatching. else: # else let accelerate handle loading and dispatching.
# Load weights and dispatch according to the device_map # Load weights and dispatch according to the device_map
# by default the device_map is None and the weights are loaded on the CPU # by default the device_map is None and the weights are loaded on the CPU

View File

@@ -61,6 +61,8 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
dot-product/softmax to float() when training with mixed precision. dot-product/softmax to float() when training with mixed precision.
""" """
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]
@register_to_config @register_to_config
def __init__( def __init__(
self, self,

View File

@@ -21,12 +21,12 @@ import os
import re import re
import shutil import shutil
import sys import sys
from distutils.version import StrictVersion
from pathlib import Path from pathlib import Path
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
from urllib import request from urllib import request
from huggingface_hub import HfFolder, cached_download, hf_hub_download, model_info from huggingface_hub import HfFolder, cached_download, hf_hub_download, model_info
from packaging import version
from .. import __version__ from .. import __version__
from . import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging from . import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
@@ -43,7 +43,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def get_diffusers_versions(): def get_diffusers_versions():
url = "https://pypi.org/pypi/diffusers/json" url = "https://pypi.org/pypi/diffusers/json"
releases = json.loads(request.urlopen(url).read())["releases"].keys() releases = json.loads(request.urlopen(url).read())["releases"].keys()
return sorted(releases, key=StrictVersion) return sorted(releases, key=lambda x: version.Version(x))
def init_hf_modules(): def init_hf_modules():