mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 20:44:33 +08:00
Compare commits
2 Commits
enable-tel
...
support-ot
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bfaa0d8ab4 | ||
|
|
f03ea10681 |
@@ -57,7 +57,7 @@ if is_torch_available():
|
||||
_import_structure["autoencoder"] = ["FromOriginalVAEMixin"]
|
||||
|
||||
_import_structure["controlnet"] = ["FromOriginalControlNetMixin"]
|
||||
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
|
||||
_import_structure["unet"] = ["FromOriginalUNetMixin", "UNet2DConditionLoadersMixin"]
|
||||
_import_structure["utils"] = ["AttnProcsLayers"]
|
||||
if is_transformers_available():
|
||||
_import_structure["single_file"] = ["FromSingleFileMixin"]
|
||||
@@ -72,7 +72,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
if is_torch_available():
|
||||
from .autoencoder import FromOriginalVAEMixin
|
||||
from .controlnet import FromOriginalControlNetMixin
|
||||
from .unet import UNet2DConditionLoadersMixin
|
||||
from .unet import FromOriginalUNetMixin, UNet2DConditionLoadersMixin
|
||||
from .utils import AttnProcsLayers
|
||||
|
||||
if is_transformers_available():
|
||||
|
||||
@@ -50,6 +50,8 @@ if is_transformers_available():
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
from ..models.modeling_utils import load_model_dict_into_meta
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
CONFIG_URLS = {
|
||||
@@ -1274,6 +1276,59 @@ def create_text_encoder_from_open_clip_checkpoint(
|
||||
return text_model
|
||||
|
||||
|
||||
def create_diffusers_unet_from_stable_cascade(
|
||||
cls,
|
||||
pretrained_model_link_or_path,
|
||||
config,
|
||||
resume_download,
|
||||
force_download,
|
||||
proxies,
|
||||
token,
|
||||
cache_dir,
|
||||
local_files_only,
|
||||
revision,
|
||||
torch_dtype,
|
||||
**kwargs,
|
||||
):
|
||||
checkpoint = load_single_file_model_checkpoint(
|
||||
pretrained_model_link_or_path,
|
||||
resume_download=resume_download,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
if config is None:
|
||||
config = infer_stable_cascade_single_file_config(checkpoint)
|
||||
model_config = cls.load_config(**config, **kwargs)
|
||||
else:
|
||||
model_config = config
|
||||
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
with ctx():
|
||||
model = cls.from_config(model_config, **kwargs)
|
||||
|
||||
diffusers_format_checkpoint = convert_stable_cascade_unet_single_file_to_diffusers(checkpoint)
|
||||
|
||||
if is_accelerate_available():
|
||||
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warn(
|
||||
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||
)
|
||||
|
||||
else:
|
||||
model.load_state_dict(diffusers_format_checkpoint)
|
||||
|
||||
if torch_dtype is not None:
|
||||
model.to(torch_dtype)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def create_diffusers_unet_model_from_ldm(
|
||||
pipeline_class_name,
|
||||
original_config,
|
||||
|
||||
@@ -43,9 +43,9 @@ from ..utils import (
|
||||
set_weights_and_activate_adapters,
|
||||
)
|
||||
from .single_file_utils import (
|
||||
convert_stable_cascade_unet_single_file_to_diffusers,
|
||||
infer_stable_cascade_single_file_config,
|
||||
load_single_file_model_checkpoint,
|
||||
create_diffusers_unet_from_stable_cascade,
|
||||
create_diffusers_unet_model_from_ldm,
|
||||
fetch_ldm_config_and_checkpoint,
|
||||
)
|
||||
from .utils import AttnProcsLayers
|
||||
|
||||
@@ -66,6 +66,8 @@ LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
|
||||
CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
|
||||
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
|
||||
|
||||
COMPATIBLE_SINGLE_FILE_CLASSES = ["StableCascadeUNet", "UNet2DConditionModel"]
|
||||
|
||||
|
||||
class UNet2DConditionLoadersMixin:
|
||||
"""
|
||||
@@ -912,8 +914,9 @@ class FromOriginalUNetMixin:
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a [`StableCascadeUNet`] from pretrained StableCascadeUNet weights saved in the original `.ckpt` or
|
||||
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
Instantiate a UNet from pretrained weights saved in the original `.ckpt`, `.bin`, or
|
||||
`.safetensors` format. The model is set in evaluation mode (`model.eval()`) by default.
|
||||
Currently supported checkpoints: StableCascade, SDXL, SD, Playground v2.5, etc.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
@@ -952,8 +955,10 @@ class FromOriginalUNetMixin:
|
||||
|
||||
"""
|
||||
class_name = cls.__name__
|
||||
if class_name != "StableCascadeUNet":
|
||||
raise ValueError("FromOriginalUNetMixin is currently only compatible with StableCascadeUNet")
|
||||
if class_name not in COMPATIBLE_SINGLE_FILE_CLASSES:
|
||||
raise ValueError(
|
||||
f"FromOriginalUNetMixin is currently only compatible with {', '.join(COMPATIBLE_SINGLE_FILE_CLASSES)}"
|
||||
)
|
||||
|
||||
config = kwargs.pop("config", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
@@ -965,39 +970,42 @@ class FromOriginalUNetMixin:
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
|
||||
checkpoint = load_single_file_model_checkpoint(
|
||||
pretrained_model_link_or_path,
|
||||
resume_download=resume_download,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
if config is None:
|
||||
config = infer_stable_cascade_single_file_config(checkpoint)
|
||||
model_config = cls.load_config(**config, **kwargs)
|
||||
if class_name == "StableCascadeUNet":
|
||||
return create_diffusers_unet_from_stable_cascade(
|
||||
cls,
|
||||
pretrained_model_link_or_path,
|
||||
config,
|
||||
resume_download,
|
||||
force_download,
|
||||
proxies,
|
||||
token,
|
||||
cache_dir,
|
||||
local_files_only,
|
||||
revision,
|
||||
torch_dtype,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
model_config = config
|
||||
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
with ctx():
|
||||
model = cls.from_config(model_config, **kwargs)
|
||||
|
||||
diffusers_format_checkpoint = convert_stable_cascade_unet_single_file_to_diffusers(checkpoint)
|
||||
if is_accelerate_available():
|
||||
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warn(
|
||||
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||
)
|
||||
|
||||
else:
|
||||
model.load_state_dict(diffusers_format_checkpoint)
|
||||
|
||||
if torch_dtype is not None:
|
||||
model.to(torch_dtype)
|
||||
|
||||
return model
|
||||
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
|
||||
pretrained_model_link_or_path=pretrained_model_link_or_path,
|
||||
class_name=kwargs.get("pipeline_class_name", None),
|
||||
original_config_file=kwargs.get("original_config_file", None),
|
||||
resume_download=resume_download,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
return create_diffusers_unet_model_from_ldm(
|
||||
pipeline_class_name=kwargs.get("pipeline_class_name", None),
|
||||
original_config=original_config,
|
||||
checkpoint=checkpoint,
|
||||
num_in_channels=kwargs.get("num_in_channels", 4),
|
||||
upcast_attention=kwargs.get("upcast_attention", None),
|
||||
extract_ema=kwargs.get("upcast_attention", False),
|
||||
image_size=kwargs.get("image_size", None),
|
||||
torch_dtype=torch_dtype,
|
||||
model_type=kwargs.pop("model_type", None),
|
||||
)["unet"]
|
||||
|
||||
@@ -19,7 +19,7 @@ import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
|
||||
from ...loaders import FromOriginalUNetMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
|
||||
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..activations import get_activation
|
||||
from ..attention_processor import (
|
||||
@@ -66,7 +66,9 @@ class UNet2DConditionOutput(BaseOutput):
|
||||
sample: torch.FloatTensor = None
|
||||
|
||||
|
||||
class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
|
||||
class UNet2DConditionModel(
|
||||
ModelMixin, ConfigMixin, FromOriginalUNetMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin
|
||||
):
|
||||
r"""
|
||||
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
|
||||
shaped output.
|
||||
|
||||
Reference in New Issue
Block a user