mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-08 13:34:27 +08:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
60ab8fad16 | ||
|
|
d17240457f | ||
|
|
7512fc4df5 | ||
|
|
0c2f1ccc97 | ||
|
|
47f2d2c7be | ||
|
|
af85591593 | ||
|
|
29f15673ed |
@@ -56,7 +56,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.21.0.dev0")
|
check_min_version("0.21.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.21.0.dev0")
|
check_min_version("0.21.0")
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.21.0.dev0")
|
check_min_version("0.21.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.21.0.dev0")
|
check_min_version("0.21.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.21.0.dev0")
|
check_min_version("0.21.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ from diffusers.utils import check_min_version
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.21.0.dev0")
|
check_min_version("0.21.0")
|
||||||
|
|
||||||
# Cache compiled models across invocations of this script.
|
# Cache compiled models across invocations of this script.
|
||||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.21.0.dev0")
|
check_min_version("0.21.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.21.0.dev0")
|
check_min_version("0.21.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.21.0.dev0")
|
check_min_version("0.21.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.21.0.dev0")
|
check_min_version("0.21.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.21.0.dev0")
|
check_min_version("0.21.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ if is_wandb_available():
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.21.0.dev0")
|
check_min_version("0.21.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ from diffusers.utils import check_min_version
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.21.0.dev0")
|
check_min_version("0.21.0")
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.21.0.dev0")
|
check_min_version("0.21.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.21.0.dev0")
|
check_min_version("0.21.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.21.0.dev0")
|
check_min_version("0.21.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.21.0.dev0")
|
check_min_version("0.21.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ else:
|
|||||||
# ------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.21.0.dev0")
|
check_min_version("0.21.0")
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.21.0.dev0")
|
check_min_version("0.21.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -154,6 +154,7 @@ if __name__ == "__main__":
|
|||||||
pipe = download_from_original_stable_diffusion_ckpt(
|
pipe = download_from_original_stable_diffusion_ckpt(
|
||||||
checkpoint_path_or_dict=args.checkpoint_path,
|
checkpoint_path_or_dict=args.checkpoint_path,
|
||||||
original_config_file=args.original_config_file,
|
original_config_file=args.original_config_file,
|
||||||
|
config_files=args.config_files,
|
||||||
image_size=args.image_size,
|
image_size=args.image_size,
|
||||||
prediction_type=args.prediction_type,
|
prediction_type=args.prediction_type,
|
||||||
model_type=args.pipeline_type,
|
model_type=args.pipeline_type,
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -244,7 +244,7 @@ install_requires = [
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="diffusers",
|
name="diffusers",
|
||||||
version="0.21.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
version="0.21.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||||
description="State-of-the-art diffusion in PyTorch and JAX.",
|
description="State-of-the-art diffusion in PyTorch and JAX.",
|
||||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
__version__ = "0.21.0.dev0"
|
__version__ = "0.21.1"
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ from .utils.import_utils import BACKENDS_MAPPING
|
|||||||
|
|
||||||
|
|
||||||
if is_transformers_available():
|
if is_transformers_available():
|
||||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel, PreTrainedTokenizer
|
from transformers import CLIPTextModel, CLIPTextModelWithProjection
|
||||||
|
|
||||||
if is_accelerate_available():
|
if is_accelerate_available():
|
||||||
from accelerate import init_empty_weights
|
from accelerate import init_empty_weights
|
||||||
@@ -627,7 +627,7 @@ class TextualInversionLoaderMixin:
|
|||||||
Load textual inversion tokens and embeddings to the tokenizer and text encoder.
|
Load textual inversion tokens and embeddings to the tokenizer and text encoder.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"):
|
def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): # noqa: F821
|
||||||
r"""
|
r"""
|
||||||
Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding to
|
Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding to
|
||||||
be replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
|
be replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
|
||||||
@@ -654,7 +654,7 @@ class TextualInversionLoaderMixin:
|
|||||||
|
|
||||||
return prompts
|
return prompts
|
||||||
|
|
||||||
def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
|
def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"): # noqa: F821
|
||||||
r"""
|
r"""
|
||||||
Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
|
Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
|
||||||
to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
|
to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
|
||||||
@@ -688,8 +688,8 @@ class TextualInversionLoaderMixin:
|
|||||||
self,
|
self,
|
||||||
pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
|
pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
|
||||||
token: Optional[Union[str, List[str]]] = None,
|
token: Optional[Union[str, List[str]]] = None,
|
||||||
tokenizer: Optional[PreTrainedTokenizer] = None,
|
tokenizer: Optional["PreTrainedTokenizer"] = None, # noqa: F821
|
||||||
text_encoder: Optional[PreTrainedModel] = None,
|
text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -2098,6 +2098,7 @@ class FromSingleFileMixin:
|
|||||||
from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
|
from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
|
||||||
|
|
||||||
original_config_file = kwargs.pop("original_config_file", None)
|
original_config_file = kwargs.pop("original_config_file", None)
|
||||||
|
config_files = kwargs.pop("config_files", None)
|
||||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||||
resume_download = kwargs.pop("resume_download", False)
|
resume_download = kwargs.pop("resume_download", False)
|
||||||
force_download = kwargs.pop("force_download", False)
|
force_download = kwargs.pop("force_download", False)
|
||||||
@@ -2215,6 +2216,7 @@ class FromSingleFileMixin:
|
|||||||
vae=vae,
|
vae=vae,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
original_config_file=original_config_file,
|
original_config_file=original_config_file,
|
||||||
|
config_files=config_files,
|
||||||
)
|
)
|
||||||
|
|
||||||
if torch_dtype is not None:
|
if torch_dtype is not None:
|
||||||
|
|||||||
@@ -1255,7 +1255,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|||||||
self._all_hooks = []
|
self._all_hooks = []
|
||||||
hook = None
|
hook = None
|
||||||
for model_str in self.model_cpu_offload_seq.split("->"):
|
for model_str in self.model_cpu_offload_seq.split("->"):
|
||||||
model = all_model_components.pop(model_str)
|
model = all_model_components.pop(model_str, None)
|
||||||
if not isinstance(model, torch.nn.Module):
|
if not isinstance(model, torch.nn.Module):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@@ -1256,25 +1256,37 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||||
key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
|
key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
|
||||||
key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
|
key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
|
||||||
|
config_url = None
|
||||||
|
|
||||||
# model_type = "v1"
|
# model_type = "v1"
|
||||||
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
if config_files is not None and "v1" in config_files:
|
||||||
|
original_config_file = config_files["v1"]
|
||||||
|
else:
|
||||||
|
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
||||||
|
|
||||||
if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024:
|
if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024:
|
||||||
# model_type = "v2"
|
# model_type = "v2"
|
||||||
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
|
if config_files is not None and "v2" in config_files:
|
||||||
|
original_config_file = config_files["v2"]
|
||||||
|
else:
|
||||||
|
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
|
||||||
if global_step == 110000:
|
if global_step == 110000:
|
||||||
# v2.1 needs to upcast attention
|
# v2.1 needs to upcast attention
|
||||||
upcast_attention = True
|
upcast_attention = True
|
||||||
elif key_name_sd_xl_base in checkpoint:
|
elif key_name_sd_xl_base in checkpoint:
|
||||||
# only base xl has two text embedders
|
# only base xl has two text embedders
|
||||||
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
|
if config_files is not None and "xl" in config_files:
|
||||||
|
original_config_file = config_files["xl"]
|
||||||
|
else:
|
||||||
|
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
|
||||||
elif key_name_sd_xl_refiner in checkpoint:
|
elif key_name_sd_xl_refiner in checkpoint:
|
||||||
# only refiner xl has embedder and one text embedders
|
# only refiner xl has embedder and one text embedders
|
||||||
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
|
if config_files is not None and "xl_refiner" in config_files:
|
||||||
|
original_config_file = config_files["xl_refiner"]
|
||||||
original_config_file = BytesIO(requests.get(config_url).content)
|
else:
|
||||||
|
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
|
||||||
|
if config_url is not None:
|
||||||
|
original_config_file = BytesIO(requests.get(config_url).content)
|
||||||
|
|
||||||
original_config = OmegaConf.load(original_config_file)
|
original_config = OmegaConf.load(original_config_file)
|
||||||
|
|
||||||
|
|||||||
@@ -50,13 +50,26 @@ class SafetyConfig(object):
|
|||||||
|
|
||||||
_dummy_objects = {}
|
_dummy_objects = {}
|
||||||
_additional_imports = {}
|
_additional_imports = {}
|
||||||
_import_structure = {
|
_import_structure = {}
|
||||||
"pipeline_output": ["StableDiffusionSafePipelineOutput"],
|
|
||||||
"pipeline_stable_diffusion_safe": ["StableDiffusionPipelineSafe"],
|
|
||||||
"safety_checker": ["StableDiffusionSafetyChecker"],
|
|
||||||
}
|
|
||||||
_additional_imports.update({"SafetyConfig": SafetyConfig})
|
_additional_imports.update({"SafetyConfig": SafetyConfig})
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not (is_transformers_available() and is_torch_available()):
|
||||||
|
raise OptionalDependencyNotAvailable()
|
||||||
|
except OptionalDependencyNotAvailable:
|
||||||
|
from ...utils import dummy_torch_and_transformers_objects
|
||||||
|
|
||||||
|
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||||
|
else:
|
||||||
|
_import_structure.update(
|
||||||
|
{
|
||||||
|
"pipeline_output": ["StableDiffusionSafePipelineOutput"],
|
||||||
|
"pipeline_stable_diffusion_safe": ["StableDiffusionPipelineSafe"],
|
||||||
|
"safety_checker": ["StableDiffusionSafetyChecker"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
try:
|
try:
|
||||||
@@ -70,25 +83,16 @@ if TYPE_CHECKING:
|
|||||||
from .safety_checker import SafeStableDiffusionSafetyChecker
|
from .safety_checker import SafeStableDiffusionSafetyChecker
|
||||||
|
|
||||||
else:
|
else:
|
||||||
try:
|
import sys
|
||||||
if not (is_transformers_available() and is_torch_available()):
|
|
||||||
raise OptionalDependencyNotAvailable()
|
|
||||||
except OptionalDependencyNotAvailable:
|
|
||||||
from ...utils import dummy_torch_and_transformers_objects
|
|
||||||
|
|
||||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
sys.modules[__name__] = _LazyModule(
|
||||||
|
__name__,
|
||||||
|
globals()["__file__"],
|
||||||
|
_import_structure,
|
||||||
|
module_spec=__spec__,
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
for name, value in _dummy_objects.items():
|
||||||
import sys
|
setattr(sys.modules[__name__], name, value)
|
||||||
|
for name, value in _additional_imports.items():
|
||||||
sys.modules[__name__] = _LazyModule(
|
setattr(sys.modules[__name__], name, value)
|
||||||
__name__,
|
|
||||||
globals()["__file__"],
|
|
||||||
_import_structure,
|
|
||||||
module_spec=__spec__,
|
|
||||||
)
|
|
||||||
|
|
||||||
for name, value in _dummy_objects.items():
|
|
||||||
setattr(sys.modules[__name__], name, value)
|
|
||||||
for name, value in _additional_imports.items():
|
|
||||||
setattr(sys.modules[__name__], name, value)
|
|
||||||
|
|||||||
@@ -47,3 +47,5 @@ else:
|
|||||||
_import_structure,
|
_import_structure,
|
||||||
module_spec=__spec__,
|
module_spec=__spec__,
|
||||||
)
|
)
|
||||||
|
for name, value in _dummy_objects.items():
|
||||||
|
setattr(sys.modules[__name__], name, value)
|
||||||
|
|||||||
@@ -51,3 +51,6 @@ else:
|
|||||||
_import_structure,
|
_import_structure,
|
||||||
module_spec=__spec__,
|
module_spec=__spec__,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
for name, value in _dummy_objects.items():
|
||||||
|
setattr(sys.modules[__name__], name, value)
|
||||||
|
|||||||
@@ -41,7 +41,6 @@ if TYPE_CHECKING:
|
|||||||
from .pipeline_wuerstchen import WuerstchenDecoderPipeline
|
from .pipeline_wuerstchen import WuerstchenDecoderPipeline
|
||||||
from .pipeline_wuerstchen_combined import WuerstchenCombinedPipeline
|
from .pipeline_wuerstchen_combined import WuerstchenCombinedPipeline
|
||||||
from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline
|
from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline
|
||||||
|
|
||||||
else:
|
else:
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
@@ -51,3 +50,6 @@ else:
|
|||||||
_import_structure,
|
_import_structure,
|
||||||
module_spec=__spec__,
|
module_spec=__spec__,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
for name, value in _dummy_objects.items():
|
||||||
|
setattr(sys.modules[__name__], name, value)
|
||||||
|
|||||||
Reference in New Issue
Block a user