Compare commits

...

2 Commits

Author SHA1 Message Date
sayakpaul
bfaa0d8ab4 retrieve the unet from create_diffusers_unet_model_from_ldm 2024-03-14 15:52:44 +05:30
sayakpaul
f03ea10681 refactor unet single file loading a bit. 2024-03-14 15:33:40 +05:30
4 changed files with 111 additions and 46 deletions

View File

@@ -57,7 +57,7 @@ if is_torch_available():
_import_structure["autoencoder"] = ["FromOriginalVAEMixin"]
_import_structure["controlnet"] = ["FromOriginalControlNetMixin"]
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
_import_structure["unet"] = ["FromOriginalUNetMixin", "UNet2DConditionLoadersMixin"]
_import_structure["utils"] = ["AttnProcsLayers"]
if is_transformers_available():
_import_structure["single_file"] = ["FromSingleFileMixin"]
@@ -72,7 +72,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from .autoencoder import FromOriginalVAEMixin
from .controlnet import FromOriginalControlNetMixin
from .unet import UNet2DConditionLoadersMixin
from .unet import FromOriginalUNetMixin, UNet2DConditionLoadersMixin
from .utils import AttnProcsLayers
if is_transformers_available():

View File

@@ -50,6 +50,8 @@ if is_transformers_available():
if is_accelerate_available():
from accelerate import init_empty_weights
from ..models.modeling_utils import load_model_dict_into_meta
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
CONFIG_URLS = {
@@ -1274,6 +1276,59 @@ def create_text_encoder_from_open_clip_checkpoint(
return text_model
def create_diffusers_unet_from_stable_cascade(
    cls,
    pretrained_model_link_or_path,
    config,
    resume_download,
    force_download,
    proxies,
    token,
    cache_dir,
    local_files_only,
    revision,
    torch_dtype,
    **kwargs,
):
    """
    Build a `cls` model from a single-file Stable Cascade UNet checkpoint.

    The checkpoint is fetched with `load_single_file_model_checkpoint`, converted with
    `convert_stable_cascade_unet_single_file_to_diffusers`, and loaded into a model created via
    `cls.from_config`.

    Parameters:
        cls: Model class to instantiate; must provide `load_config` and `from_config`.
        pretrained_model_link_or_path: Location of the checkpoint, forwarded to the checkpoint loader.
        config: Model config to use as-is. When `None`, a config is inferred from the checkpoint with
            `infer_stable_cascade_single_file_config` and resolved through `cls.load_config`.
        resume_download, force_download, proxies, token, cache_dir, local_files_only, revision:
            Forwarded unchanged to `load_single_file_model_checkpoint`.
        torch_dtype: Optional dtype; passed to the meta-loading helper and, when not `None`, applied to
            the model with `model.to(torch_dtype)`.
        **kwargs: Extra keyword arguments forwarded to `cls.load_config` (only when the config is
            inferred) and to `cls.from_config`.

    Returns:
        The instantiated model with the converted checkpoint weights loaded.
    """
    checkpoint = load_single_file_model_checkpoint(
        pretrained_model_link_or_path,
        resume_download=resume_download,
        force_download=force_download,
        proxies=proxies,
        token=token,
        cache_dir=cache_dir,
        local_files_only=local_files_only,
        revision=revision,
    )

    if config is None:
        config = infer_stable_cascade_single_file_config(checkpoint)
        model_config = cls.load_config(**config, **kwargs)
    else:
        # A caller-supplied config is used verbatim.
        model_config = config

    # With accelerate installed the model is created under `init_empty_weights`; the weights are
    # then filled in by `load_model_dict_into_meta` below.
    ctx = init_empty_weights if is_accelerate_available() else nullcontext
    with ctx():
        model = cls.from_config(model_config, **kwargs)

    diffusers_format_checkpoint = convert_stable_cascade_unet_single_file_to_diffusers(checkpoint)
    if is_accelerate_available():
        unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
        if len(unexpected_keys) > 0:
            # `Logger.warn` is a deprecated alias; `warning` is the supported spelling.
            logger.warning(
                f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
            )
    else:
        model.load_state_dict(diffusers_format_checkpoint)

    if torch_dtype is not None:
        model.to(torch_dtype)

    return model
def create_diffusers_unet_model_from_ldm(
pipeline_class_name,
original_config,

View File

@@ -43,9 +43,9 @@ from ..utils import (
set_weights_and_activate_adapters,
)
from .single_file_utils import (
convert_stable_cascade_unet_single_file_to_diffusers,
infer_stable_cascade_single_file_config,
load_single_file_model_checkpoint,
create_diffusers_unet_from_stable_cascade,
create_diffusers_unet_model_from_ldm,
fetch_ldm_config_and_checkpoint,
)
from .utils import AttnProcsLayers
@@ -66,6 +66,8 @@ LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
COMPATIBLE_SINGLE_FILE_CLASSES = ["StableCascadeUNet", "UNet2DConditionModel"]
class UNet2DConditionLoadersMixin:
"""
@@ -912,8 +914,9 @@ class FromOriginalUNetMixin:
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
r"""
Instantiate a [`StableCascadeUNet`] from pretrained StableCascadeUNet weights saved in the original `.ckpt` or
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
Instantiate a UNet from pretrained weights saved in the original `.ckpt`, `.bin`, or
`.safetensors` format. The model is set in evaluation mode (`model.eval()`) by default.
Currently supported checkpoints: StableCascade, SDXL, SD, Playground v2.5, etc.
Parameters:
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
@@ -952,8 +955,10 @@ class FromOriginalUNetMixin:
"""
class_name = cls.__name__
if class_name != "StableCascadeUNet":
raise ValueError("FromOriginalUNetMixin is currently only compatible with StableCascadeUNet")
if class_name not in COMPATIBLE_SINGLE_FILE_CLASSES:
raise ValueError(
f"FromOriginalUNetMixin is currently only compatible with {', '.join(COMPATIBLE_SINGLE_FILE_CLASSES)}"
)
config = kwargs.pop("config", None)
resume_download = kwargs.pop("resume_download", False)
@@ -965,39 +970,42 @@ class FromOriginalUNetMixin:
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
checkpoint = load_single_file_model_checkpoint(
pretrained_model_link_or_path,
resume_download=resume_download,
force_download=force_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
)
if config is None:
config = infer_stable_cascade_single_file_config(checkpoint)
model_config = cls.load_config(**config, **kwargs)
if class_name == "StableCascadeUNet":
return create_diffusers_unet_from_stable_cascade(
cls,
pretrained_model_link_or_path,
config,
resume_download,
force_download,
proxies,
token,
cache_dir,
local_files_only,
revision,
torch_dtype,
**kwargs,
)
else:
model_config = config
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
model = cls.from_config(model_config, **kwargs)
diffusers_format_checkpoint = convert_stable_cascade_unet_single_file_to_diffusers(checkpoint)
if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
if len(unexpected_keys) > 0:
logger.warn(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else:
model.load_state_dict(diffusers_format_checkpoint)
if torch_dtype is not None:
model.to(torch_dtype)
return model
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
pretrained_model_link_or_path=pretrained_model_link_or_path,
class_name=kwargs.get("pipeline_class_name", None),
original_config_file=kwargs.get("original_config_file", None),
resume_download=resume_download,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
cache_dir=cache_dir,
)
return create_diffusers_unet_model_from_ldm(
pipeline_class_name=kwargs.get("pipeline_class_name", None),
original_config=original_config,
checkpoint=checkpoint,
num_in_channels=kwargs.get("num_in_channels", 4),
upcast_attention=kwargs.get("upcast_attention", None),
extract_ema=kwargs.get("upcast_attention", False),
image_size=kwargs.get("image_size", None),
torch_dtype=torch_dtype,
model_type=kwargs.pop("model_type", None),
)["unet"]

View File

@@ -19,7 +19,7 @@ import torch.nn as nn
import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...loaders import FromOriginalUNetMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ..activations import get_activation
from ..attention_processor import (
@@ -66,7 +66,9 @@ class UNet2DConditionOutput(BaseOutput):
sample: torch.FloatTensor = None
class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
class UNet2DConditionModel(
ModelMixin, ConfigMixin, FromOriginalUNetMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin
):
r"""
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
shaped output.