Compare commits

...

12 Commits

Author SHA1 Message Date
sayakpaul
abf24f116e resolve conflicts. 2024-03-06 09:50:39 +05:30
Sayak Paul
da441cefc6 Merge branch 'main' into support-single-file-from-from_pretrained 2024-02-26 15:38:43 +05:30
sayakpaul
4d2ca28e24 ditto for pipelines. 2024-02-26 15:16:19 +05:30
sayakpaul
d7d757a387 make single file loader cleaner models 2024-02-26 15:14:35 +05:30
sayakpaul
e9e41981d7 fix: posix 2024-02-19 13:18:34 +05:30
Sayak Paul
807fa22bfc Merge branch 'main' into support-single-file-from-from_pretrained 2024-02-19 12:19:54 +05:30
Sayak Paul
e3881c3bd9 Merge branch 'main' into support-single-file-from-from_pretrained 2024-02-18 14:47:35 +05:30
sayakpaul
45ab4399cc Empty-Commit 2024-02-15 17:14:58 +05:30
sayakpaul
9c734f78e8 fix: condition for loading_info 2024-02-15 17:02:54 +05:30
sayakpaul
46f4c4399c add proper error handling through loadable classes check. 2024-02-15 16:56:29 +05:30
sayakpaul
64998bca1b support models too 2024-02-15 16:29:28 +05:30
sayakpaul
34efcc2034 feat: support single file checkpoint from from_pretrained() 2024-02-15 15:58:47 +05:30
7 changed files with 577 additions and 501 deletions

View File

@@ -143,4 +143,5 @@ class FromOriginalVAEMixin:
if torch_dtype is not None: if torch_dtype is not None:
vae = vae.to(torch_dtype) vae = vae.to(torch_dtype)
vae.eval()
return vae return vae

View File

@@ -133,4 +133,5 @@ class FromOriginalControlNetMixin:
if torch_dtype is not None: if torch_dtype is not None:
controlnet = controlnet.to(torch_dtype) controlnet = controlnet.to(torch_dtype)
controlnet.eval()
return controlnet return controlnet

View File

@@ -39,6 +39,7 @@ from ..utils import (
_get_model_file, _get_model_file,
deprecate, deprecate,
is_accelerate_available, is_accelerate_available,
is_single_file_checkpoint,
is_torch_version, is_torch_version,
logging, logging,
) )
@@ -48,6 +49,8 @@ from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populat
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
SINGLE_FILE_LOADABLE_CLASSES = {"ControlNetModel", "AutoencoderKL"}
if is_torch_version(">=", "1.9.0"): if is_torch_version(">=", "1.9.0"):
_LOW_CPU_MEM_USAGE_DEFAULT = True _LOW_CPU_MEM_USAGE_DEFAULT = True
else: else:
@@ -497,6 +500,16 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
``` ```
""" """
if is_single_file_checkpoint(pretrained_model_name_or_path):
if cls.__name__ not in SINGLE_FILE_LOADABLE_CLASSES:
raise ValueError(
f"{cls.__name__} is not supported. Supported classes are: {' '.join(list(SINGLE_FILE_LOADABLE_CLASSES))}."
)
logger.info("Single file checkpoint detected...")
model = cls.from_single_file(pretrained_model_name_or_path, **kwargs)
model = model.eval()
return model
else:
cache_dir = kwargs.pop("cache_dir", None) cache_dir = kwargs.pop("cache_dir", None)
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)

View File

@@ -57,6 +57,7 @@ from ..utils import (
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
is_peft_available, is_peft_available,
is_single_file_checkpoint,
is_torch_version, is_torch_version,
is_transformers_available, is_transformers_available,
logging, logging,
@@ -110,6 +111,20 @@ LOADABLE_CLASSES = {
}, },
} }
SINGLE_FILE_LOADABLE_CLASSES = {
"StableDiffusionPipeline",
"StableDiffusionImg2ImgPipeline",
"StableDiffusionInpaintPipeline",
"StableDiffusionUpscalePipeline",
"StableDiffusionControlNetPipeline",
"StableDiffusionControlNetImg2ImgPipeline",
"StableDiffusionControlNetInpaintPipeline",
"StableDiffusionXLPipeline",
"StableDiffusionXLImg2ImgPipeline",
"StableDiffusionXLInpaintPipeline",
"StableDiffusionXLControlNetImg2ImgPipeline",
}
ALL_IMPORTABLE_CLASSES = {} ALL_IMPORTABLE_CLASSES = {}
for library in LOADABLE_CLASSES: for library in LOADABLE_CLASSES:
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
@@ -1056,6 +1071,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
>>> pipeline.scheduler = scheduler >>> pipeline.scheduler = scheduler
``` ```
""" """
if is_single_file_checkpoint(pretrained_model_name_or_path):
if cls.__name__ not in SINGLE_FILE_LOADABLE_CLASSES:
raise ValueError(
f'The provided pretrained_model_name_or_path "{pretrained_model_name_or_path}"'
" is neither a valid local path nor a valid repo id. Please check the parameter."
f"{cls.__name__} is not supported. Supported classes are: {' '.join(list(SINGLE_FILE_LOADABLE_CLASSES))}."
)
logger.info("Single file checkpoint detected...")
model = cls.from_single_file(pretrained_model_name_or_path, **kwargs)
return model
else:
cache_dir = kwargs.pop("cache_dir", None) cache_dir = kwargs.pop("cache_dir", None)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
@@ -1242,7 +1269,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
from diffusers import pipelines from diffusers import pipelines
# 6. Load each module in the pipeline # 6. Load each module in the pipeline
for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."): for name, (library_name, class_name) in logging.tqdm(
init_dict.items(), desc="Loading pipeline components..."
):
# 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names # 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
class_name = class_name[4:] if class_name.startswith("Flax") else class_name class_name = class_name[4:] if class_name.startswith("Flax") else class_name
@@ -1256,7 +1285,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
# if the model is in a pipeline module, then we load it from the pipeline # if the model is in a pipeline module, then we load it from the pipeline
# check that passed_class_obj has correct parent class # check that passed_class_obj has correct parent class
maybe_raise_or_warn( maybe_raise_or_warn(
library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module library_name,
library,
class_name,
importable_classes,
passed_class_obj,
name,
is_pipeline_module,
) )
loaded_sub_model = passed_class_obj[name] loaded_sub_model = passed_class_obj[name]
@@ -1291,7 +1326,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
if pipeline_class._load_connected_pipes and os.path.isfile(os.path.join(cached_folder, "README.md")): if pipeline_class._load_connected_pipes and os.path.isfile(os.path.join(cached_folder, "README.md")):
modelcard = ModelCard.load(os.path.join(cached_folder, "README.md")) modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
connected_pipes = {prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS} connected_pipes = {
prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS
}
load_kwargs = { load_kwargs = {
"cache_dir": cache_dir, "cache_dir": cache_dir,
"resume_download": resume_download, "resume_download": resume_download,
@@ -1316,10 +1353,14 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
def get_connected_passed_kwargs(prefix): def get_connected_passed_kwargs(prefix):
connected_passed_class_obj = { connected_passed_class_obj = {
k.replace(f"{prefix}_", ""): w for k, w in passed_class_obj.items() if k.split("_")[0] == prefix k.replace(f"{prefix}_", ""): w
for k, w in passed_class_obj.items()
if k.split("_")[0] == prefix
} }
connected_passed_pipe_kwargs = { connected_passed_pipe_kwargs = {
k.replace(f"{prefix}_", ""): w for k, w in passed_pipe_kwargs.items() if k.split("_")[0] == prefix k.replace(f"{prefix}_", ""): w
for k, w in passed_pipe_kwargs.items()
if k.split("_")[0] == prefix
} }
connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs} connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs}

View File

@@ -19,6 +19,7 @@ from packaging import version
from .. import __version__ from .. import __version__
from .constants import ( from .constants import (
_ACCEPTED_SINGLE_FILE_FORMATS,
CONFIG_NAME, CONFIG_NAME,
DEPRECATED_REVISION_ARGS, DEPRECATED_REVISION_ARGS,
DIFFUSERS_DYNAMIC_MODULE_NAME, DIFFUSERS_DYNAMIC_MODULE_NAME,
@@ -83,7 +84,7 @@ from .import_utils import (
is_xformers_available, is_xformers_available,
requires_backends, requires_backends,
) )
from .loading_utils import load_image from .loading_utils import is_single_file_checkpoint, load_image
from .logging import get_logger from .logging import get_logger
from .outputs import BaseOutput from .outputs import BaseOutput
from .peft_utils import ( from .peft_utils import (

View File

@@ -37,6 +37,7 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://hugging
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules")) HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
_ACCEPTED_SINGLE_FILE_FORMATS = (".safetensors", ".ckpt", ".bin", ".pth", ".pt")
# Below should be `True` if the current version of `peft` and `transformers` are compatible with # Below should be `True` if the current version of `peft` and `transformers` are compatible with
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are

View File

@@ -1,10 +1,28 @@
import os import os
from typing import Callable, Union from typing import Callable, Union
from urllib.parse import urlparse
import PIL.Image import PIL.Image
import PIL.ImageOps import PIL.ImageOps
import requests import requests
from ..utils.constants import _ACCEPTED_SINGLE_FILE_FORMATS
def is_single_file_checkpoint(filepath):
def is_valid_url(url):
result = urlparse(url)
if result.scheme and result.netloc:
return True
filepath = str(filepath)
if filepath.endswith(_ACCEPTED_SINGLE_FILE_FORMATS):
if is_valid_url(filepath):
return True
elif os.path.isfile(filepath):
return True
return False
def load_image( def load_image(
image: Union[str, PIL.Image.Image], convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] = None image: Union[str, PIL.Image.Image], convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] = None