mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-18 18:34:37 +08:00
Compare commits
12 Commits
custom-cod
...
support-si
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
abf24f116e | ||
|
|
da441cefc6 | ||
|
|
4d2ca28e24 | ||
|
|
d7d757a387 | ||
|
|
e9e41981d7 | ||
|
|
807fa22bfc | ||
|
|
e3881c3bd9 | ||
|
|
45ab4399cc | ||
|
|
9c734f78e8 | ||
|
|
46f4c4399c | ||
|
|
64998bca1b | ||
|
|
34efcc2034 |
@@ -143,4 +143,5 @@ class FromOriginalVAEMixin:
|
|||||||
if torch_dtype is not None:
|
if torch_dtype is not None:
|
||||||
vae = vae.to(torch_dtype)
|
vae = vae.to(torch_dtype)
|
||||||
|
|
||||||
|
vae.eval()
|
||||||
return vae
|
return vae
|
||||||
|
|||||||
@@ -133,4 +133,5 @@ class FromOriginalControlNetMixin:
|
|||||||
if torch_dtype is not None:
|
if torch_dtype is not None:
|
||||||
controlnet = controlnet.to(torch_dtype)
|
controlnet = controlnet.to(torch_dtype)
|
||||||
|
|
||||||
|
controlnet.eval()
|
||||||
return controlnet
|
return controlnet
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ from ..utils import (
|
|||||||
_get_model_file,
|
_get_model_file,
|
||||||
deprecate,
|
deprecate,
|
||||||
is_accelerate_available,
|
is_accelerate_available,
|
||||||
|
is_single_file_checkpoint,
|
||||||
is_torch_version,
|
is_torch_version,
|
||||||
logging,
|
logging,
|
||||||
)
|
)
|
||||||
@@ -48,6 +49,8 @@ from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populat
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
SINGLE_FILE_LOADABLE_CLASSES = {"ControlNetModel", "AutoencoderKL"}
|
||||||
|
|
||||||
if is_torch_version(">=", "1.9.0"):
|
if is_torch_version(">=", "1.9.0"):
|
||||||
_LOW_CPU_MEM_USAGE_DEFAULT = True
|
_LOW_CPU_MEM_USAGE_DEFAULT = True
|
||||||
else:
|
else:
|
||||||
@@ -497,6 +500,16 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|||||||
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
|
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
if is_single_file_checkpoint(pretrained_model_name_or_path):
|
||||||
|
if cls.__name__ not in SINGLE_FILE_LOADABLE_CLASSES:
|
||||||
|
raise ValueError(
|
||||||
|
f"{cls.__name__} is not supported. Supported classes are: {' '.join(list(SINGLE_FILE_LOADABLE_CLASSES))}."
|
||||||
|
)
|
||||||
|
logger.info("Single file checkpoint detected...")
|
||||||
|
model = cls.from_single_file(pretrained_model_name_or_path, **kwargs)
|
||||||
|
model = model.eval()
|
||||||
|
return model
|
||||||
|
else:
|
||||||
cache_dir = kwargs.pop("cache_dir", None)
|
cache_dir = kwargs.pop("cache_dir", None)
|
||||||
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
||||||
force_download = kwargs.pop("force_download", False)
|
force_download = kwargs.pop("force_download", False)
|
||||||
|
|||||||
@@ -57,6 +57,7 @@ from ..utils import (
|
|||||||
is_accelerate_available,
|
is_accelerate_available,
|
||||||
is_accelerate_version,
|
is_accelerate_version,
|
||||||
is_peft_available,
|
is_peft_available,
|
||||||
|
is_single_file_checkpoint,
|
||||||
is_torch_version,
|
is_torch_version,
|
||||||
is_transformers_available,
|
is_transformers_available,
|
||||||
logging,
|
logging,
|
||||||
@@ -110,6 +111,20 @@ LOADABLE_CLASSES = {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SINGLE_FILE_LOADABLE_CLASSES = {
|
||||||
|
"StableDiffusionPipeline",
|
||||||
|
"StableDiffusionImg2ImgPipeline",
|
||||||
|
"StableDiffusionInpaintPipeline",
|
||||||
|
"StableDiffusionUpscalePipeline",
|
||||||
|
"StableDiffusionControlNetPipeline",
|
||||||
|
"StableDiffusionControlNetImg2ImgPipeline",
|
||||||
|
"StableDiffusionControlNetInpaintPipeline",
|
||||||
|
"StableDiffusionXLPipeline",
|
||||||
|
"StableDiffusionXLImg2ImgPipeline",
|
||||||
|
"StableDiffusionXLInpaintPipeline",
|
||||||
|
"StableDiffusionXLControlNetImg2ImgPipeline",
|
||||||
|
}
|
||||||
|
|
||||||
ALL_IMPORTABLE_CLASSES = {}
|
ALL_IMPORTABLE_CLASSES = {}
|
||||||
for library in LOADABLE_CLASSES:
|
for library in LOADABLE_CLASSES:
|
||||||
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
|
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
|
||||||
@@ -1056,6 +1071,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|||||||
>>> pipeline.scheduler = scheduler
|
>>> pipeline.scheduler = scheduler
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
if is_single_file_checkpoint(pretrained_model_name_or_path):
|
||||||
|
if cls.__name__ not in SINGLE_FILE_LOADABLE_CLASSES:
|
||||||
|
raise ValueError(
|
||||||
|
f'The provided pretrained_model_name_or_path "{pretrained_model_name_or_path}"'
|
||||||
|
" is neither a valid local path nor a valid repo id. Please check the parameter."
|
||||||
|
f"{cls.__name__} is not supported. Supported classes are: {' '.join(list(SINGLE_FILE_LOADABLE_CLASSES))}."
|
||||||
|
)
|
||||||
|
logger.info("Single file checkpoint detected...")
|
||||||
|
model = cls.from_single_file(pretrained_model_name_or_path, **kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
else:
|
||||||
cache_dir = kwargs.pop("cache_dir", None)
|
cache_dir = kwargs.pop("cache_dir", None)
|
||||||
resume_download = kwargs.pop("resume_download", False)
|
resume_download = kwargs.pop("resume_download", False)
|
||||||
force_download = kwargs.pop("force_download", False)
|
force_download = kwargs.pop("force_download", False)
|
||||||
@@ -1242,7 +1269,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|||||||
from diffusers import pipelines
|
from diffusers import pipelines
|
||||||
|
|
||||||
# 6. Load each module in the pipeline
|
# 6. Load each module in the pipeline
|
||||||
for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
|
for name, (library_name, class_name) in logging.tqdm(
|
||||||
|
init_dict.items(), desc="Loading pipeline components..."
|
||||||
|
):
|
||||||
# 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
|
# 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
|
||||||
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
|
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
|
||||||
|
|
||||||
@@ -1256,7 +1285,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|||||||
# if the model is in a pipeline module, then we load it from the pipeline
|
# if the model is in a pipeline module, then we load it from the pipeline
|
||||||
# check that passed_class_obj has correct parent class
|
# check that passed_class_obj has correct parent class
|
||||||
maybe_raise_or_warn(
|
maybe_raise_or_warn(
|
||||||
library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
|
library_name,
|
||||||
|
library,
|
||||||
|
class_name,
|
||||||
|
importable_classes,
|
||||||
|
passed_class_obj,
|
||||||
|
name,
|
||||||
|
is_pipeline_module,
|
||||||
)
|
)
|
||||||
|
|
||||||
loaded_sub_model = passed_class_obj[name]
|
loaded_sub_model = passed_class_obj[name]
|
||||||
@@ -1291,7 +1326,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|||||||
|
|
||||||
if pipeline_class._load_connected_pipes and os.path.isfile(os.path.join(cached_folder, "README.md")):
|
if pipeline_class._load_connected_pipes and os.path.isfile(os.path.join(cached_folder, "README.md")):
|
||||||
modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
|
modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
|
||||||
connected_pipes = {prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS}
|
connected_pipes = {
|
||||||
|
prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS
|
||||||
|
}
|
||||||
load_kwargs = {
|
load_kwargs = {
|
||||||
"cache_dir": cache_dir,
|
"cache_dir": cache_dir,
|
||||||
"resume_download": resume_download,
|
"resume_download": resume_download,
|
||||||
@@ -1316,10 +1353,14 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|||||||
|
|
||||||
def get_connected_passed_kwargs(prefix):
|
def get_connected_passed_kwargs(prefix):
|
||||||
connected_passed_class_obj = {
|
connected_passed_class_obj = {
|
||||||
k.replace(f"{prefix}_", ""): w for k, w in passed_class_obj.items() if k.split("_")[0] == prefix
|
k.replace(f"{prefix}_", ""): w
|
||||||
|
for k, w in passed_class_obj.items()
|
||||||
|
if k.split("_")[0] == prefix
|
||||||
}
|
}
|
||||||
connected_passed_pipe_kwargs = {
|
connected_passed_pipe_kwargs = {
|
||||||
k.replace(f"{prefix}_", ""): w for k, w in passed_pipe_kwargs.items() if k.split("_")[0] == prefix
|
k.replace(f"{prefix}_", ""): w
|
||||||
|
for k, w in passed_pipe_kwargs.items()
|
||||||
|
if k.split("_")[0] == prefix
|
||||||
}
|
}
|
||||||
|
|
||||||
connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs}
|
connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs}
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from packaging import version
|
|||||||
|
|
||||||
from .. import __version__
|
from .. import __version__
|
||||||
from .constants import (
|
from .constants import (
|
||||||
|
_ACCEPTED_SINGLE_FILE_FORMATS,
|
||||||
CONFIG_NAME,
|
CONFIG_NAME,
|
||||||
DEPRECATED_REVISION_ARGS,
|
DEPRECATED_REVISION_ARGS,
|
||||||
DIFFUSERS_DYNAMIC_MODULE_NAME,
|
DIFFUSERS_DYNAMIC_MODULE_NAME,
|
||||||
@@ -83,7 +84,7 @@ from .import_utils import (
|
|||||||
is_xformers_available,
|
is_xformers_available,
|
||||||
requires_backends,
|
requires_backends,
|
||||||
)
|
)
|
||||||
from .loading_utils import load_image
|
from .loading_utils import is_single_file_checkpoint, load_image
|
||||||
from .logging import get_logger
|
from .logging import get_logger
|
||||||
from .outputs import BaseOutput
|
from .outputs import BaseOutput
|
||||||
from .peft_utils import (
|
from .peft_utils import (
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://hugging
|
|||||||
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
|
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
|
||||||
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"))
|
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"))
|
||||||
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
|
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
|
||||||
|
_ACCEPTED_SINGLE_FILE_FORMATS = (".safetensors", ".ckpt", ".bin", ".pth", ".pt")
|
||||||
|
|
||||||
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
|
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
|
||||||
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
|
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
|
||||||
|
|||||||
@@ -1,10 +1,28 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Callable, Union
|
from typing import Callable, Union
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
import PIL.ImageOps
|
import PIL.ImageOps
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
from ..utils.constants import _ACCEPTED_SINGLE_FILE_FORMATS
|
||||||
|
|
||||||
|
|
||||||
|
def is_single_file_checkpoint(filepath):
|
||||||
|
def is_valid_url(url):
|
||||||
|
result = urlparse(url)
|
||||||
|
if result.scheme and result.netloc:
|
||||||
|
return True
|
||||||
|
|
||||||
|
filepath = str(filepath)
|
||||||
|
if filepath.endswith(_ACCEPTED_SINGLE_FILE_FORMATS):
|
||||||
|
if is_valid_url(filepath):
|
||||||
|
return True
|
||||||
|
elif os.path.isfile(filepath):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def load_image(
|
def load_image(
|
||||||
image: Union[str, PIL.Image.Image], convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] = None
|
image: Union[str, PIL.Image.Image], convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] = None
|
||||||
|
|||||||
Reference in New Issue
Block a user