Compare commits

...

4 Commits

Author SHA1 Message Date
sayakpaul
cfe6b99a4e disable getattr 2023-12-20 13:55:37 +05:30
sayakpaul
58d85bd8a8 let 2023-12-20 13:50:59 +05:30
sayakpaul
ea90d03146 use logger. 2023-12-20 13:46:27 +05:30
sayakpaul
27304fd8ae start debugging 2023-12-20 13:40:57 +05:30

View File

@@ -202,23 +202,24 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
def __init__(self):
super().__init__()
def __getattr__(self, name: str) -> Any:
"""The only reason we overwrite `getattr` here is to gracefully deprecate accessing
config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
__getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__':
https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
"""
# def __getattr__(self, name: str) -> Any:
# """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
# config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
# __getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__':
# https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
# """
is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
is_attribute = name in self.__dict__
# internal_keys = self.__dict__.keys()
# is_in_config = "_internal_dict" in internal_keys and hasattr(self.__dict__["_internal_dict"], name)
# is_attribute = name in internal_keys
if is_in_config and not is_attribute:
deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
return self._internal_dict[name]
# if is_in_config and not is_attribute:
# deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
# deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
# return self._internal_dict[name]
# call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
return super().__getattr__(name)
# # call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
# return super().__getattr__(name)
@property
def is_gradient_checkpointing(self) -> bool: