Compare commits

...

7 Commits

Author SHA1 Message Date
Patrick von Platen
60ab8fad16 Patch release: v0.21.1 2023-09-14 13:06:57 +02:00
Patrick von Platen
d17240457f [Import] Add missing settings / Correct some dummy imports (#5036)
* [Import] Add missing settings

* up

* up

* up
2023-09-14 12:47:55 +02:00
Vladimir Mandic
7512fc4df5 allow loading of sd models from safetensors without online lookups using local config files (#5019)
finish config_files implementation
2023-09-14 12:47:41 +02:00
Patrick von Platen
0c2f1ccc97 [Import] Don't force transformers to be installed (#5035)
* [Import] Don't force transformers to be installed

* make style
2023-09-14 12:47:34 +02:00
Dhruv Nair
47f2d2c7be Fix model offload bug when key isn't present (#5030)
* fix model offload bug when key isn't present

* make style
2023-09-14 12:47:25 +02:00
Patrick von Platen
af85591593 Patch release: v0.21.1 2023-09-14 12:46:39 +02:00
Patrick von Platen
29f15673ed Release: v0.21.0 2023-09-13 15:58:24 +02:00
29 changed files with 86 additions and 60 deletions

View File

@@ -56,7 +56,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0") check_min_version("0.21.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -59,7 +59,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0") check_min_version("0.21.0")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -58,7 +58,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0") check_min_version("0.21.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -58,7 +58,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0") check_min_version("0.21.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0") check_min_version("0.21.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -36,7 +36,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0") check_min_version("0.21.0")
# Cache compiled models across invocations of this script. # Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache")) cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))

View File

@@ -70,7 +70,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0") check_min_version("0.21.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -58,7 +58,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0") check_min_version("0.21.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -52,7 +52,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0") check_min_version("0.21.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")

View File

@@ -55,7 +55,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0") check_min_version("0.21.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")

View File

@@ -58,7 +58,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0") check_min_version("0.21.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -53,7 +53,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0") check_min_version("0.21.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")

View File

@@ -33,7 +33,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0") check_min_version("0.21.0")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -48,7 +48,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0") check_min_version("0.21.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")

View File

@@ -57,7 +57,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0") check_min_version("0.21.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -57,7 +57,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0") check_min_version("0.21.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -79,7 +79,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0") check_min_version("0.21.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -56,7 +56,7 @@ else:
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0") check_min_version("0.21.0")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -30,7 +30,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0") check_min_version("0.21.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")

View File

@@ -154,6 +154,7 @@ if __name__ == "__main__":
pipe = download_from_original_stable_diffusion_ckpt( pipe = download_from_original_stable_diffusion_ckpt(
checkpoint_path_or_dict=args.checkpoint_path, checkpoint_path_or_dict=args.checkpoint_path,
original_config_file=args.original_config_file, original_config_file=args.original_config_file,
config_files=args.config_files,
image_size=args.image_size, image_size=args.image_size,
prediction_type=args.prediction_type, prediction_type=args.prediction_type,
model_type=args.pipeline_type, model_type=args.pipeline_type,

View File

@@ -244,7 +244,7 @@ install_requires = [
setup( setup(
name="diffusers", name="diffusers",
version="0.21.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) version="0.21.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.", description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(), long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown", long_description_content_type="text/markdown",

View File

@@ -1,4 +1,4 @@
__version__ = "0.21.0.dev0" __version__ = "0.21.1"
from typing import TYPE_CHECKING from typing import TYPE_CHECKING

View File

@@ -41,7 +41,7 @@ from .utils.import_utils import BACKENDS_MAPPING
if is_transformers_available(): if is_transformers_available():
from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel, PreTrainedTokenizer from transformers import CLIPTextModel, CLIPTextModelWithProjection
if is_accelerate_available(): if is_accelerate_available():
from accelerate import init_empty_weights from accelerate import init_empty_weights
@@ -627,7 +627,7 @@ class TextualInversionLoaderMixin:
Load textual inversion tokens and embeddings to the tokenizer and text encoder. Load textual inversion tokens and embeddings to the tokenizer and text encoder.
""" """
def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): # noqa: F821
r""" r"""
Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding to Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding to
be replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual be replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
@@ -654,7 +654,7 @@ class TextualInversionLoaderMixin:
return prompts return prompts
def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"): def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"): # noqa: F821
r""" r"""
Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
to a multi-vector textual inversion embedding, this function will process the prompt so that the special token to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
@@ -688,8 +688,8 @@ class TextualInversionLoaderMixin:
self, self,
pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]], pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
token: Optional[Union[str, List[str]]] = None, token: Optional[Union[str, List[str]]] = None,
tokenizer: Optional[PreTrainedTokenizer] = None, tokenizer: Optional["PreTrainedTokenizer"] = None, # noqa: F821
text_encoder: Optional[PreTrainedModel] = None, text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
**kwargs, **kwargs,
): ):
r""" r"""
@@ -2098,6 +2098,7 @@ class FromSingleFileMixin:
from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
original_config_file = kwargs.pop("original_config_file", None) original_config_file = kwargs.pop("original_config_file", None)
config_files = kwargs.pop("config_files", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
@@ -2215,6 +2216,7 @@ class FromSingleFileMixin:
vae=vae, vae=vae,
tokenizer=tokenizer, tokenizer=tokenizer,
original_config_file=original_config_file, original_config_file=original_config_file,
config_files=config_files,
) )
if torch_dtype is not None: if torch_dtype is not None:

View File

@@ -1255,7 +1255,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
self._all_hooks = [] self._all_hooks = []
hook = None hook = None
for model_str in self.model_cpu_offload_seq.split("->"): for model_str in self.model_cpu_offload_seq.split("->"):
model = all_model_components.pop(model_str) model = all_model_components.pop(model_str, None)
if not isinstance(model, torch.nn.Module): if not isinstance(model, torch.nn.Module):
continue continue

View File

@@ -1256,25 +1256,37 @@ def download_from_original_stable_diffusion_ckpt(
key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias" key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias" key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
config_url = None
# model_type = "v1" # model_type = "v1"
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" if config_files is not None and "v1" in config_files:
original_config_file = config_files["v1"]
else:
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024: if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024:
# model_type = "v2" # model_type = "v2"
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml" if config_files is not None and "v2" in config_files:
original_config_file = config_files["v2"]
else:
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
if global_step == 110000: if global_step == 110000:
# v2.1 needs to upcast attention # v2.1 needs to upcast attention
upcast_attention = True upcast_attention = True
elif key_name_sd_xl_base in checkpoint: elif key_name_sd_xl_base in checkpoint:
# only base xl has two text embedders # only base xl has two text embedders
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml" if config_files is not None and "xl" in config_files:
original_config_file = config_files["xl"]
else:
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
elif key_name_sd_xl_refiner in checkpoint: elif key_name_sd_xl_refiner in checkpoint:
# only refiner xl has embedder and one text embedders # only refiner xl has embedder and one text embedders
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml" if config_files is not None and "xl_refiner" in config_files:
original_config_file = config_files["xl_refiner"]
original_config_file = BytesIO(requests.get(config_url).content) else:
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
if config_url is not None:
original_config_file = BytesIO(requests.get(config_url).content)
original_config = OmegaConf.load(original_config_file) original_config = OmegaConf.load(original_config_file)

View File

@@ -50,13 +50,26 @@ class SafetyConfig(object):
_dummy_objects = {} _dummy_objects = {}
_additional_imports = {} _additional_imports = {}
_import_structure = { _import_structure = {}
"pipeline_output": ["StableDiffusionSafePipelineOutput"],
"pipeline_stable_diffusion_safe": ["StableDiffusionPipelineSafe"],
"safety_checker": ["StableDiffusionSafetyChecker"],
}
_additional_imports.update({"SafetyConfig": SafetyConfig}) _additional_imports.update({"SafetyConfig": SafetyConfig})
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure.update(
{
"pipeline_output": ["StableDiffusionSafePipelineOutput"],
"pipeline_stable_diffusion_safe": ["StableDiffusionPipelineSafe"],
"safety_checker": ["StableDiffusionSafetyChecker"],
}
)
if TYPE_CHECKING: if TYPE_CHECKING:
try: try:
@@ -70,25 +83,16 @@ if TYPE_CHECKING:
from .safety_checker import SafeStableDiffusionSafetyChecker from .safety_checker import SafeStableDiffusionSafetyChecker
else: else:
try: import sys
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
else: for name, value in _dummy_objects.items():
import sys setattr(sys.modules[__name__], name, value)
for name, value in _additional_imports.items():
sys.modules[__name__] = _LazyModule( setattr(sys.modules[__name__], name, value)
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
for name, value in _additional_imports.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -47,3 +47,5 @@ else:
_import_structure, _import_structure,
module_spec=__spec__, module_spec=__spec__,
) )
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -51,3 +51,6 @@ else:
_import_structure, _import_structure,
module_spec=__spec__, module_spec=__spec__,
) )
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -41,7 +41,6 @@ if TYPE_CHECKING:
from .pipeline_wuerstchen import WuerstchenDecoderPipeline from .pipeline_wuerstchen import WuerstchenDecoderPipeline
from .pipeline_wuerstchen_combined import WuerstchenCombinedPipeline from .pipeline_wuerstchen_combined import WuerstchenCombinedPipeline
from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline
else: else:
import sys import sys
@@ -51,3 +50,6 @@ else:
_import_structure, _import_structure,
module_spec=__spec__, module_spec=__spec__,
) )
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)