mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-16 17:34:44 +08:00
Compare commits
12 Commits
enable-cp-
...
support-si
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
abf24f116e | ||
|
|
da441cefc6 | ||
|
|
4d2ca28e24 | ||
|
|
d7d757a387 | ||
|
|
e9e41981d7 | ||
|
|
807fa22bfc | ||
|
|
e3881c3bd9 | ||
|
|
45ab4399cc | ||
|
|
9c734f78e8 | ||
|
|
46f4c4399c | ||
|
|
64998bca1b | ||
|
|
34efcc2034 |
@@ -143,4 +143,5 @@ class FromOriginalVAEMixin:
|
||||
if torch_dtype is not None:
|
||||
vae = vae.to(torch_dtype)
|
||||
|
||||
vae.eval()
|
||||
return vae
|
||||
|
||||
@@ -133,4 +133,5 @@ class FromOriginalControlNetMixin:
|
||||
if torch_dtype is not None:
|
||||
controlnet = controlnet.to(torch_dtype)
|
||||
|
||||
controlnet.eval()
|
||||
return controlnet
|
||||
|
||||
@@ -39,6 +39,7 @@ from ..utils import (
|
||||
_get_model_file,
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_single_file_checkpoint,
|
||||
is_torch_version,
|
||||
logging,
|
||||
)
|
||||
@@ -48,6 +49,8 @@ from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populat
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
SINGLE_FILE_LOADABLE_CLASSES = {"ControlNetModel", "AutoencoderKL"}
|
||||
|
||||
if is_torch_version(">=", "1.9.0"):
|
||||
_LOW_CPU_MEM_USAGE_DEFAULT = True
|
||||
else:
|
||||
@@ -497,6 +500,16 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
|
||||
```
|
||||
"""
|
||||
if is_single_file_checkpoint(pretrained_model_name_or_path):
|
||||
if cls.__name__ not in SINGLE_FILE_LOADABLE_CLASSES:
|
||||
raise ValueError(
|
||||
f"{cls.__name__} is not supported. Supported classes are: {' '.join(list(SINGLE_FILE_LOADABLE_CLASSES))}."
|
||||
)
|
||||
logger.info("Single file checkpoint detected...")
|
||||
model = cls.from_single_file(pretrained_model_name_or_path, **kwargs)
|
||||
model = model.eval()
|
||||
return model
|
||||
else:
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
|
||||
@@ -57,6 +57,7 @@ from ..utils import (
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_peft_available,
|
||||
is_single_file_checkpoint,
|
||||
is_torch_version,
|
||||
is_transformers_available,
|
||||
logging,
|
||||
@@ -110,6 +111,20 @@ LOADABLE_CLASSES = {
|
||||
},
|
||||
}
|
||||
|
||||
SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"StableDiffusionPipeline",
|
||||
"StableDiffusionImg2ImgPipeline",
|
||||
"StableDiffusionInpaintPipeline",
|
||||
"StableDiffusionUpscalePipeline",
|
||||
"StableDiffusionControlNetPipeline",
|
||||
"StableDiffusionControlNetImg2ImgPipeline",
|
||||
"StableDiffusionControlNetInpaintPipeline",
|
||||
"StableDiffusionXLPipeline",
|
||||
"StableDiffusionXLImg2ImgPipeline",
|
||||
"StableDiffusionXLInpaintPipeline",
|
||||
"StableDiffusionXLControlNetImg2ImgPipeline",
|
||||
}
|
||||
|
||||
ALL_IMPORTABLE_CLASSES = {}
|
||||
for library in LOADABLE_CLASSES:
|
||||
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
|
||||
@@ -1056,6 +1071,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
>>> pipeline.scheduler = scheduler
|
||||
```
|
||||
"""
|
||||
if is_single_file_checkpoint(pretrained_model_name_or_path):
|
||||
if cls.__name__ not in SINGLE_FILE_LOADABLE_CLASSES:
|
||||
raise ValueError(
|
||||
f'The provided pretrained_model_name_or_path "{pretrained_model_name_or_path}"'
|
||||
" is neither a valid local path nor a valid repo id. Please check the parameter."
|
||||
f"{cls.__name__} is not supported. Supported classes are: {' '.join(list(SINGLE_FILE_LOADABLE_CLASSES))}."
|
||||
)
|
||||
logger.info("Single file checkpoint detected...")
|
||||
model = cls.from_single_file(pretrained_model_name_or_path, **kwargs)
|
||||
return model
|
||||
|
||||
else:
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
@@ -1242,7 +1269,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
from diffusers import pipelines
|
||||
|
||||
# 6. Load each module in the pipeline
|
||||
for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
|
||||
for name, (library_name, class_name) in logging.tqdm(
|
||||
init_dict.items(), desc="Loading pipeline components..."
|
||||
):
|
||||
# 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
|
||||
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
|
||||
|
||||
@@ -1256,7 +1285,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
# if the model is in a pipeline module, then we load it from the pipeline
|
||||
# check that passed_class_obj has correct parent class
|
||||
maybe_raise_or_warn(
|
||||
library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
|
||||
library_name,
|
||||
library,
|
||||
class_name,
|
||||
importable_classes,
|
||||
passed_class_obj,
|
||||
name,
|
||||
is_pipeline_module,
|
||||
)
|
||||
|
||||
loaded_sub_model = passed_class_obj[name]
|
||||
@@ -1291,7 +1326,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
if pipeline_class._load_connected_pipes and os.path.isfile(os.path.join(cached_folder, "README.md")):
|
||||
modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
|
||||
connected_pipes = {prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS}
|
||||
connected_pipes = {
|
||||
prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS
|
||||
}
|
||||
load_kwargs = {
|
||||
"cache_dir": cache_dir,
|
||||
"resume_download": resume_download,
|
||||
@@ -1316,10 +1353,14 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
def get_connected_passed_kwargs(prefix):
|
||||
connected_passed_class_obj = {
|
||||
k.replace(f"{prefix}_", ""): w for k, w in passed_class_obj.items() if k.split("_")[0] == prefix
|
||||
k.replace(f"{prefix}_", ""): w
|
||||
for k, w in passed_class_obj.items()
|
||||
if k.split("_")[0] == prefix
|
||||
}
|
||||
connected_passed_pipe_kwargs = {
|
||||
k.replace(f"{prefix}_", ""): w for k, w in passed_pipe_kwargs.items() if k.split("_")[0] == prefix
|
||||
k.replace(f"{prefix}_", ""): w
|
||||
for k, w in passed_pipe_kwargs.items()
|
||||
if k.split("_")[0] == prefix
|
||||
}
|
||||
|
||||
connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs}
|
||||
|
||||
@@ -19,6 +19,7 @@ from packaging import version
|
||||
|
||||
from .. import __version__
|
||||
from .constants import (
|
||||
_ACCEPTED_SINGLE_FILE_FORMATS,
|
||||
CONFIG_NAME,
|
||||
DEPRECATED_REVISION_ARGS,
|
||||
DIFFUSERS_DYNAMIC_MODULE_NAME,
|
||||
@@ -83,7 +84,7 @@ from .import_utils import (
|
||||
is_xformers_available,
|
||||
requires_backends,
|
||||
)
|
||||
from .loading_utils import load_image
|
||||
from .loading_utils import is_single_file_checkpoint, load_image
|
||||
from .logging import get_logger
|
||||
from .outputs import BaseOutput
|
||||
from .peft_utils import (
|
||||
|
||||
@@ -37,6 +37,7 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://hugging
|
||||
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
|
||||
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"))
|
||||
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
|
||||
_ACCEPTED_SINGLE_FILE_FORMATS = (".safetensors", ".ckpt", ".bin", ".pth", ".pt")
|
||||
|
||||
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
|
||||
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
|
||||
|
||||
@@ -1,10 +1,28 @@
|
||||
import os
|
||||
from typing import Callable, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import PIL.Image
|
||||
import PIL.ImageOps
|
||||
import requests
|
||||
|
||||
from ..utils.constants import _ACCEPTED_SINGLE_FILE_FORMATS
|
||||
|
||||
|
||||
def is_single_file_checkpoint(filepath):
|
||||
def is_valid_url(url):
|
||||
result = urlparse(url)
|
||||
if result.scheme and result.netloc:
|
||||
return True
|
||||
|
||||
filepath = str(filepath)
|
||||
if filepath.endswith(_ACCEPTED_SINGLE_FILE_FORMATS):
|
||||
if is_valid_url(filepath):
|
||||
return True
|
||||
elif os.path.isfile(filepath):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def load_image(
|
||||
image: Union[str, PIL.Image.Image], convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] = None
|
||||
|
||||
Reference in New Issue
Block a user