Compare commits

...

12 Commits

Author SHA1 Message Date
sayakpaul
abf24f116e resolve conflicts. 2024-03-06 09:50:39 +05:30
Sayak Paul
da441cefc6 Merge branch 'main' into support-single-file-from-from_pretrained 2024-02-26 15:38:43 +05:30
sayakpaul
4d2ca28e24 ditto for pipelines. 2024-02-26 15:16:19 +05:30
sayakpaul
d7d757a387 make single file loader cleaner models 2024-02-26 15:14:35 +05:30
sayakpaul
e9e41981d7 fix: posix 2024-02-19 13:18:34 +05:30
Sayak Paul
807fa22bfc Merge branch 'main' into support-single-file-from-from_pretrained 2024-02-19 12:19:54 +05:30
Sayak Paul
e3881c3bd9 Merge branch 'main' into support-single-file-from-from_pretrained 2024-02-18 14:47:35 +05:30
sayakpaul
45ab4399cc Empty-Commit 2024-02-15 17:14:58 +05:30
sayakpaul
9c734f78e8 fix: condition for loading_info 2024-02-15 17:02:54 +05:30
sayakpaul
46f4c4399c add proper error handling through loadable classes check. 2024-02-15 16:56:29 +05:30
sayakpaul
64998bca1b support models too 2024-02-15 16:29:28 +05:30
sayakpaul
34efcc2034 feat: support single file checkpoint from from_pretrained() 2024-02-15 15:58:47 +05:30
7 changed files with 577 additions and 501 deletions

View File

@@ -143,4 +143,5 @@ class FromOriginalVAEMixin:
if torch_dtype is not None:
vae = vae.to(torch_dtype)
vae.eval()
return vae

View File

@@ -133,4 +133,5 @@ class FromOriginalControlNetMixin:
if torch_dtype is not None:
controlnet = controlnet.to(torch_dtype)
controlnet.eval()
return controlnet

View File

@@ -39,6 +39,7 @@ from ..utils import (
_get_model_file,
deprecate,
is_accelerate_available,
is_single_file_checkpoint,
is_torch_version,
logging,
)
@@ -48,6 +49,8 @@ from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populat
logger = logging.get_logger(__name__)
# Names of the model classes allowed to load a single-file checkpoint: when the path passed to the
# loader is detected as a single-file checkpoint, `cls.__name__` is looked up here and a
# `ValueError` is raised for any class that is not listed.
SINGLE_FILE_LOADABLE_CLASSES = {"ControlNetModel", "AutoencoderKL"}
if is_torch_version(">=", "1.9.0"):
_LOW_CPU_MEM_USAGE_DEFAULT = True
else:
@@ -497,6 +500,16 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
"""
if is_single_file_checkpoint(pretrained_model_name_or_path):
if cls.__name__ not in SINGLE_FILE_LOADABLE_CLASSES:
raise ValueError(
f"{cls.__name__} is not supported. Supported classes are: {' '.join(list(SINGLE_FILE_LOADABLE_CLASSES))}."
)
logger.info("Single file checkpoint detected...")
model = cls.from_single_file(pretrained_model_name_or_path, **kwargs)
model = model.eval()
return model
else:
cache_dir = kwargs.pop("cache_dir", None)
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
force_download = kwargs.pop("force_download", False)

View File

@@ -57,6 +57,7 @@ from ..utils import (
is_accelerate_available,
is_accelerate_version,
is_peft_available,
is_single_file_checkpoint,
is_torch_version,
is_transformers_available,
logging,
@@ -110,6 +111,20 @@ LOADABLE_CLASSES = {
},
}
# Names of the pipeline classes allowed to load a single-file checkpoint. The loader compares
# `cls.__name__` against this set before delegating to `cls.from_single_file(...)` and raises a
# `ValueError` for any class that is not listed.
SINGLE_FILE_LOADABLE_CLASSES = {
"StableDiffusionPipeline",
"StableDiffusionImg2ImgPipeline",
"StableDiffusionInpaintPipeline",
"StableDiffusionUpscalePipeline",
"StableDiffusionControlNetPipeline",
"StableDiffusionControlNetImg2ImgPipeline",
"StableDiffusionControlNetInpaintPipeline",
"StableDiffusionXLPipeline",
"StableDiffusionXLImg2ImgPipeline",
"StableDiffusionXLInpaintPipeline",
"StableDiffusionXLControlNetImg2ImgPipeline",
}
ALL_IMPORTABLE_CLASSES = {}
for library in LOADABLE_CLASSES:
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
@@ -1056,6 +1071,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
>>> pipeline.scheduler = scheduler
```
"""
if is_single_file_checkpoint(pretrained_model_name_or_path):
if cls.__name__ not in SINGLE_FILE_LOADABLE_CLASSES:
raise ValueError(
f'The provided pretrained_model_name_or_path "{pretrained_model_name_or_path}"'
" is neither a valid local path nor a valid repo id. Please check the parameter."
f"{cls.__name__} is not supported. Supported classes are: {' '.join(list(SINGLE_FILE_LOADABLE_CLASSES))}."
)
logger.info("Single file checkpoint detected...")
model = cls.from_single_file(pretrained_model_name_or_path, **kwargs)
return model
else:
cache_dir = kwargs.pop("cache_dir", None)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
@@ -1242,7 +1269,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
from diffusers import pipelines
# 6. Load each module in the pipeline
for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
for name, (library_name, class_name) in logging.tqdm(
init_dict.items(), desc="Loading pipeline components..."
):
# 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
@@ -1256,7 +1285,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
# if the model is in a pipeline module, then we load it from the pipeline
# check that passed_class_obj has correct parent class
maybe_raise_or_warn(
library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
library_name,
library,
class_name,
importable_classes,
passed_class_obj,
name,
is_pipeline_module,
)
loaded_sub_model = passed_class_obj[name]
@@ -1291,7 +1326,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
if pipeline_class._load_connected_pipes and os.path.isfile(os.path.join(cached_folder, "README.md")):
modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
connected_pipes = {prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS}
connected_pipes = {
prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS
}
load_kwargs = {
"cache_dir": cache_dir,
"resume_download": resume_download,
@@ -1316,10 +1353,14 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
def get_connected_passed_kwargs(prefix):
connected_passed_class_obj = {
k.replace(f"{prefix}_", ""): w for k, w in passed_class_obj.items() if k.split("_")[0] == prefix
k.replace(f"{prefix}_", ""): w
for k, w in passed_class_obj.items()
if k.split("_")[0] == prefix
}
connected_passed_pipe_kwargs = {
k.replace(f"{prefix}_", ""): w for k, w in passed_pipe_kwargs.items() if k.split("_")[0] == prefix
k.replace(f"{prefix}_", ""): w
for k, w in passed_pipe_kwargs.items()
if k.split("_")[0] == prefix
}
connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs}

View File

@@ -19,6 +19,7 @@ from packaging import version
from .. import __version__
from .constants import (
_ACCEPTED_SINGLE_FILE_FORMATS,
CONFIG_NAME,
DEPRECATED_REVISION_ARGS,
DIFFUSERS_DYNAMIC_MODULE_NAME,
@@ -83,7 +84,7 @@ from .import_utils import (
is_xformers_available,
requires_backends,
)
from .loading_utils import load_image
from .loading_utils import is_single_file_checkpoint, load_image
from .logging import get_logger
from .outputs import BaseOutput
from .peft_utils import (

View File

@@ -37,6 +37,7 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://hugging
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
# File-name suffixes that `is_single_file_checkpoint` accepts. Kept as a tuple because it is passed
# directly to `str.endswith`.
_ACCEPTED_SINGLE_FILE_FORMATS = (".safetensors", ".ckpt", ".bin", ".pth", ".pt")
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are

View File

@@ -1,10 +1,28 @@
import os
from typing import Callable, Union
from urllib.parse import urlparse
import PIL.Image
import PIL.ImageOps
import requests
from ..utils.constants import _ACCEPTED_SINGLE_FILE_FORMATS
def is_single_file_checkpoint(filepath):
    """Return whether `filepath` refers to a single-file checkpoint.

    A value qualifies when its string form ends with one of the suffixes in
    `_ACCEPTED_SINGLE_FILE_FORMATS` and it is either a URL (has both a scheme and a network
    location) or the path of an existing regular file. URLs are only parsed, never fetched.

    Args:
        filepath: A path or URL. It is converted with `str()` first, so path-like objects are
            accepted as well as strings.

    Returns:
        `True` if `filepath` looks like a single-file checkpoint, `False` otherwise.
    """

    def is_valid_url(url):
        # `urlparse` raises ValueError on malformed network locations (for example unmatched
        # square brackets). Such a string is simply not a URL, so report False rather than
        # letting the predicate crash.
        try:
            result = urlparse(url)
        except ValueError:
            return False
        return bool(result.scheme and result.netloc)

    filepath = str(filepath)

    # The suffix check is the cheapest test and applies to both URLs and local paths.
    if not filepath.endswith(_ACCEPTED_SINGLE_FILE_FORMATS):
        return False

    return is_valid_url(filepath) or os.path.isfile(filepath)
def load_image(
image: Union[str, PIL.Image.Image], convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] = None