Compare commits


12 Commits

Author      SHA1        Message                                                             Date
sayakpaul   abf24f116e  resolve conflicts.                                                  2024-03-06 09:50:39 +05:30
Sayak Paul  da441cefc6  Merge branch 'main' into support-single-file-from-from_pretrained  2024-02-26 15:38:43 +05:30
sayakpaul   4d2ca28e24  ditto for pipelines.                                                2024-02-26 15:16:19 +05:30
sayakpaul   d7d757a387  make single file loader cleaner models                              2024-02-26 15:14:35 +05:30
sayakpaul   e9e41981d7  fix: posix                                                          2024-02-19 13:18:34 +05:30
Sayak Paul  807fa22bfc  Merge branch 'main' into support-single-file-from-from_pretrained  2024-02-19 12:19:54 +05:30
Sayak Paul  e3881c3bd9  Merge branch 'main' into support-single-file-from-from_pretrained  2024-02-18 14:47:35 +05:30
sayakpaul   45ab4399cc  Empty-Commit                                                        2024-02-15 17:14:58 +05:30
sayakpaul   9c734f78e8  fix: condition for loading_info                                     2024-02-15 17:02:54 +05:30
sayakpaul   46f4c4399c  add proper error handling through loadable classes check.          2024-02-15 16:56:29 +05:30
sayakpaul   64998bca1b  support models too                                                  2024-02-15 16:29:28 +05:30
sayakpaul   34efcc2034  feat: support single file checkpoint from from_pretrained()        2024-02-15 15:58:47 +05:30
7 changed files with 577 additions and 501 deletions

View File

@@ -143,4 +143,5 @@ class FromOriginalVAEMixin:
         if torch_dtype is not None:
             vae = vae.to(torch_dtype)
+        vae.eval()
         return vae

View File

@@ -133,4 +133,5 @@ class FromOriginalControlNetMixin:
         if torch_dtype is not None:
             controlnet = controlnet.to(torch_dtype)
+        controlnet.eval()
         return controlnet
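
Both hunks make the single-file loaders return the model in evaluation mode. A minimal sketch of the effect, assuming this branch is installed and using a placeholder checkpoint path (not from the PR):

    from diffusers import AutoencoderKL

    # "vae.safetensors" is an illustrative local path, not a file referenced by this PR.
    vae = AutoencoderKL.from_single_file("vae.safetensors")

    # After this change the mixin calls vae.eval() before returning, so dropout and
    # other training-only behavior are disabled by default.
    assert not vae.training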

View File

@@ -39,6 +39,7 @@ from ..utils import (
     _get_model_file,
     deprecate,
     is_accelerate_available,
+    is_single_file_checkpoint,
     is_torch_version,
     logging,
 )
@@ -48,6 +49,8 @@ from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populat
 logger = logging.get_logger(__name__)

+SINGLE_FILE_LOADABLE_CLASSES = {"ControlNetModel", "AutoencoderKL"}
+
 if is_torch_version(">=", "1.9.0"):
     _LOW_CPU_MEM_USAGE_DEFAULT = True
 else:
@@ -497,102 +500,90 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
         ```
         """
+        if is_single_file_checkpoint(pretrained_model_name_or_path):
+            if cls.__name__ not in SINGLE_FILE_LOADABLE_CLASSES:
+                raise ValueError(
+                    f"{cls.__name__} is not supported. Supported classes are: {' '.join(list(SINGLE_FILE_LOADABLE_CLASSES))}."
+                )
+
+            logger.info("Single file checkpoint detected...")
+            model = cls.from_single_file(pretrained_model_name_or_path, **kwargs)
+            model = model.eval()
+
+            return model
+        else:
+            cache_dir = kwargs.pop("cache_dir", None)
+            ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
+            force_download = kwargs.pop("force_download", False)
+            ...

(The remaining hunks in this file — @@ -600,40 +591,20 @@, @@ -645,76 +616,90 @@, and @@ -724,52 +709,80 @@ — re-indent the rest of the existing `from_pretrained` body one level so it sits under the new `else:` branch. The kwargs parsing, config loading, safetensors/Flax weight-file resolution, `accelerate`-based dispatch, `torch_dtype` casting, `loading_info` handling, `model.eval()`, and the final `return model` are carried over otherwise intact.)

     @classmethod
     def _load_pretrained_model(
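
For context, a minimal sketch of what the model-level change enables, assuming this branch is installed; the checkpoint URL is a placeholder, not something referenced by the PR:

    import torch
    from diffusers import AutoencoderKL

    # Any local path or URL ending in an accepted single-file extension is detected.
    ckpt = "https://example.com/vae.safetensors"  # illustrative URL

    # Previously this required calling AutoencoderKL.from_single_file(ckpt) explicitly;
    # with this PR, from_pretrained() detects the single-file checkpoint, delegates to
    # from_single_file(), and returns the model in eval mode.
    vae = AutoencoderKL.from_pretrained(ckpt, torch_dtype=torch.float16)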

View File

@@ -57,6 +57,7 @@ from ..utils import (
     is_accelerate_available,
     is_accelerate_version,
     is_peft_available,
+    is_single_file_checkpoint,
     is_torch_version,
     is_transformers_available,
     logging,
@@ -110,6 +111,20 @@ LOADABLE_CLASSES = {
     },
 }

+SINGLE_FILE_LOADABLE_CLASSES = {
+    "StableDiffusionPipeline",
+    "StableDiffusionImg2ImgPipeline",
+    "StableDiffusionInpaintPipeline",
+    "StableDiffusionUpscalePipeline",
+    "StableDiffusionControlNetPipeline",
+    "StableDiffusionControlNetImg2ImgPipeline",
+    "StableDiffusionControlNetInpaintPipeline",
+    "StableDiffusionXLPipeline",
+    "StableDiffusionXLImg2ImgPipeline",
+    "StableDiffusionXLInpaintPipeline",
+    "StableDiffusionXLControlNetImg2ImgPipeline",
+}
+
 ALL_IMPORTABLE_CLASSES = {}
 for library in LOADABLE_CLASSES:
     ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
@@ -1056,308 +1071,334 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         >>> pipeline.scheduler = scheduler
         ```
         """
+        if is_single_file_checkpoint(pretrained_model_name_or_path):
+            if cls.__name__ not in SINGLE_FILE_LOADABLE_CLASSES:
+                raise ValueError(
+                    f"{cls.__name__} is not supported. Supported classes are: {' '.join(list(SINGLE_FILE_LOADABLE_CLASSES))}."
+                )
+
+            logger.info("Single file checkpoint detected...")
+            model = cls.from_single_file(pretrained_model_name_or_path, **kwargs)
+
+            return model
+        else:
+            cache_dir = kwargs.pop("cache_dir", None)
+            resume_download = kwargs.pop("resume_download", False)
+            force_download = kwargs.pop("force_download", False)
+            ...

(The remainder of this hunk re-indents the existing `DiffusionPipeline.from_pretrained` body one level so it sits under the new `else:` branch. The kwargs parsing, `accelerate`/torch-version checks, checkpoint download via `cls.download`, config loading, variant detection, custom and legacy pipeline class resolution, per-component loading with `load_sub_model`, connected-pipeline handling, missing-module checks, pipeline instantiation, and the final `register_to_config` + `return model` are carried over otherwise intact.)

     @property
     def name_or_path(self) -> str:
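
For context, a minimal sketch of the pipeline-level behavior, assuming this branch is installed; the checkpoint URL is a placeholder, not something referenced by the PR:

    from diffusers import StableDiffusionXLPipeline

    # Illustrative single-file checkpoint URL; a regular repo id without a single-file
    # extension still goes through the usual folder-based loading path.
    ckpt = "https://example.com/sd_xl_base_1.0.safetensors"

    # from_pretrained() now detects the ".safetensors" suffix via is_single_file_checkpoint()
    # and delegates to StableDiffusionXLPipeline.from_single_file(ckpt, **kwargs).
    pipe = StableDiffusionXLPipeline.from_pretrained(ckpt)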

View File

@@ -19,6 +19,7 @@ from packaging import version
 from .. import __version__
 from .constants import (
+    _ACCEPTED_SINGLE_FILE_FORMATS,
     CONFIG_NAME,
     DEPRECATED_REVISION_ARGS,
     DIFFUSERS_DYNAMIC_MODULE_NAME,

@@ -83,7 +84,7 @@ from .import_utils import (
     is_xformers_available,
     requires_backends,
 )
-from .loading_utils import load_image
+from .loading_utils import is_single_file_checkpoint, load_image
 from .logging import get_logger
 from .outputs import BaseOutput
 from .peft_utils import (

View File

@@ -37,6 +37,7 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://hugging
 DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
 HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"))
 DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
+_ACCEPTED_SINGLE_FILE_FORMATS = (".safetensors", ".ckpt", ".bin", ".pth", ".pt")

 # Below should be `True` if the current version of `peft` and `transformers` are compatible with
 # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are

View File

@@ -1,10 +1,28 @@
 import os
 from typing import Callable, Union
+from urllib.parse import urlparse

 import PIL.Image
 import PIL.ImageOps
 import requests

+from ..utils.constants import _ACCEPTED_SINGLE_FILE_FORMATS
+
+
+def is_single_file_checkpoint(filepath):
+    def is_valid_url(url):
+        result = urlparse(url)
+        if result.scheme and result.netloc:
+            return True
+
+    filepath = str(filepath)
+    if filepath.endswith(_ACCEPTED_SINGLE_FILE_FORMATS):
+        if is_valid_url(filepath):
+            return True
+        elif os.path.isfile(filepath):
+            return True
+
+    return False
+
+
 def load_image(
     image: Union[str, PIL.Image.Image], convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] = None
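
A quick sketch of how the new helper classifies inputs, assuming this branch is installed; all paths below are illustrative:

    from diffusers.utils import is_single_file_checkpoint

    # Local file with an accepted extension -> True when the file actually exists on disk.
    print(is_single_file_checkpoint("./v1-5-pruned-emaonly.safetensors"))

    # URL ending in an accepted extension -> True; only the string is inspected, nothing is downloaded.
    print(is_single_file_checkpoint("https://example.com/model.ckpt"))

    # Hub repo id without a single-file suffix -> False, so the regular folder-based loading runs.
    print(is_single_file_checkpoint("runwayml/stable-diffusion-v1-5"))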