mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-10 06:24:19 +08:00
Compare commits
8 Commits
transforme
...
v0.22.2-pa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
249c06c12f | ||
|
|
0ac7d39830 | ||
|
|
d190959deb | ||
|
|
d5ff8f81b5 | ||
|
|
b4ca05fc26 | ||
|
|
a1d33fc9a5 | ||
|
|
1a4db89def | ||
|
|
df60b35e47 |
@@ -56,7 +56,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.22.0.dev0")
|
check_min_version("0.22.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.22.0.dev0")
|
check_min_version("0.22.0")
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.22.0.dev0")
|
check_min_version("0.22.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.22.0.dev0")
|
check_min_version("0.22.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.22.0.dev0")
|
check_min_version("0.22.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.22.0.dev0")
|
check_min_version("0.22.0")
|
||||||
|
|
||||||
# Cache compiled models across invocations of this script.
|
# Cache compiled models across invocations of this script.
|
||||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.22.0.dev0")
|
check_min_version("0.22.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.22.0.dev0")
|
check_min_version("0.22.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.22.0.dev0")
|
check_min_version("0.22.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.22.0.dev0")
|
check_min_version("0.22.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ if is_wandb_available():
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.21.0.dev0")
|
check_min_version("0.22.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.21.0.dev0")
|
check_min_version("0.22.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.21.0.dev0")
|
check_min_version("0.22.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ if is_wandb_available():
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.21.0.dev0")
|
check_min_version("0.22.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.22.0.dev0")
|
check_min_version("0.22.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ if is_wandb_available():
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.22.0.dev0")
|
check_min_version("0.22.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ from diffusers.utils import check_min_version
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.22.0.dev0")
|
check_min_version("0.22.0")
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.22.0.dev0")
|
check_min_version("0.22.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.22.0.dev0")
|
check_min_version("0.22.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.22.0.dev0")
|
check_min_version("0.22.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.22.0.dev0")
|
check_min_version("0.22.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ else:
|
|||||||
# ------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.22.0.dev0")
|
check_min_version("0.22.0")
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.22.0.dev0")
|
check_min_version("0.22.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -244,7 +244,7 @@ install_requires = [
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="diffusers",
|
name="diffusers",
|
||||||
version="0.22.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
version="0.22.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||||
description="State-of-the-art diffusion in PyTorch and JAX.",
|
description="State-of-the-art diffusion in PyTorch and JAX.",
|
||||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
__version__ = "0.22.0.dev0"
|
__version__ = "0.22.2"
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
|||||||
@@ -2390,7 +2390,7 @@ class LoraLoaderMixin:
|
|||||||
def set_adapters_for_text_encoder(
|
def set_adapters_for_text_encoder(
|
||||||
self,
|
self,
|
||||||
adapter_names: Union[List[str], str],
|
adapter_names: Union[List[str], str],
|
||||||
text_encoder: Optional[PreTrainedModel] = None,
|
text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
|
||||||
text_encoder_weights: List[float] = None,
|
text_encoder_weights: List[float] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -2429,7 +2429,7 @@ class LoraLoaderMixin:
|
|||||||
)
|
)
|
||||||
set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)
|
set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)
|
||||||
|
|
||||||
def disable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None):
|
def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None):
|
||||||
"""
|
"""
|
||||||
Disables the LoRA layers for the text encoder.
|
Disables the LoRA layers for the text encoder.
|
||||||
|
|
||||||
@@ -2446,7 +2446,7 @@ class LoraLoaderMixin:
|
|||||||
raise ValueError("Text Encoder not found.")
|
raise ValueError("Text Encoder not found.")
|
||||||
set_adapter_layers(text_encoder, enabled=False)
|
set_adapter_layers(text_encoder, enabled=False)
|
||||||
|
|
||||||
def enable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None):
|
def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None):
|
||||||
"""
|
"""
|
||||||
Enables the LoRA layers for the text encoder.
|
Enables the LoRA layers for the text encoder.
|
||||||
|
|
||||||
|
|||||||
@@ -287,7 +287,7 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Incorrect norm")
|
raise ValueError("Incorrect norm")
|
||||||
|
|
||||||
if self.pos_embed is not None and self.use_ada_layer_norm_single is None:
|
if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
|
||||||
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
||||||
|
|
||||||
attn_output = self.attn2(
|
attn_output = self.attn2(
|
||||||
|
|||||||
@@ -339,6 +339,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
|||||||
elif self.is_input_vectorized:
|
elif self.is_input_vectorized:
|
||||||
hidden_states = self.latent_image_embedding(hidden_states)
|
hidden_states = self.latent_image_embedding(hidden_states)
|
||||||
elif self.is_input_patches:
|
elif self.is_input_patches:
|
||||||
|
height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
|
||||||
hidden_states = self.pos_embed(hidden_states)
|
hidden_states = self.pos_embed(hidden_states)
|
||||||
|
|
||||||
if self.adaln_single is not None:
|
if self.adaln_single is not None:
|
||||||
@@ -425,7 +426,8 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
|||||||
hidden_states = hidden_states.squeeze(1)
|
hidden_states = hidden_states.squeeze(1)
|
||||||
|
|
||||||
# unpatchify
|
# unpatchify
|
||||||
height = width = int(hidden_states.shape[1] ** 0.5)
|
if self.adaln_single is None:
|
||||||
|
height = width = int(hidden_states.shape[1] ** 0.5)
|
||||||
hidden_states = hidden_states.reshape(
|
hidden_states = hidden_states.reshape(
|
||||||
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
|
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,19 +1,40 @@
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
|
DIFFUSERS_SLOW_IMPORT,
|
||||||
|
OptionalDependencyNotAvailable,
|
||||||
_LazyModule,
|
_LazyModule,
|
||||||
|
get_objects_from_module,
|
||||||
|
is_torch_available,
|
||||||
|
is_transformers_available,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_dummy_objects = {}
|
||||||
"pipeline_latent_consistency_img2img": ["LatentConsistencyModelImg2ImgPipeline"],
|
_import_structure = {}
|
||||||
"pipeline_latent_consistency_text2img": ["LatentConsistencyModelPipeline"],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
try:
|
||||||
from .pipeline_latent_consistency_img2img import LatentConsistencyModelImg2ImgPipeline
|
if not (is_transformers_available() and is_torch_available()):
|
||||||
from .pipeline_latent_consistency_text2img import LatentConsistencyModelPipeline
|
raise OptionalDependencyNotAvailable()
|
||||||
|
except OptionalDependencyNotAvailable:
|
||||||
|
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||||
|
|
||||||
|
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||||
|
else:
|
||||||
|
_import_structure["pipeline_latent_consistency_img2img"] = ["LatentConsistencyModelImg2ImgPipeline"]
|
||||||
|
_import_structure["pipeline_latent_consistency_text2img"] = ["LatentConsistencyModelPipeline"]
|
||||||
|
|
||||||
|
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||||
|
try:
|
||||||
|
if not (is_transformers_available() and is_torch_available()):
|
||||||
|
raise OptionalDependencyNotAvailable()
|
||||||
|
|
||||||
|
except OptionalDependencyNotAvailable:
|
||||||
|
from ...utils.dummy_torch_and_transformers_objects import *
|
||||||
|
else:
|
||||||
|
from .pipeline_latent_consistency_img2img import LatentConsistencyModelImg2ImgPipeline
|
||||||
|
from .pipeline_latent_consistency_text2img import LatentConsistencyModelPipeline
|
||||||
|
|
||||||
else:
|
else:
|
||||||
import sys
|
import sys
|
||||||
@@ -24,3 +45,6 @@ else:
|
|||||||
_import_structure,
|
_import_structure,
|
||||||
module_spec=__spec__,
|
module_spec=__spec__,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
for name, value in _dummy_objects.items():
|
||||||
|
setattr(sys.modules[__name__], name, value)
|
||||||
|
|||||||
@@ -353,13 +353,18 @@ def _get_pipeline_class(
|
|||||||
else:
|
else:
|
||||||
file_name = CUSTOM_PIPELINE_FILE_NAME
|
file_name = CUSTOM_PIPELINE_FILE_NAME
|
||||||
|
|
||||||
|
if repo_id is not None and hub_revision is not None:
|
||||||
|
# if we load the pipeline code from the Hub
|
||||||
|
# make sure to overwrite the `revison`
|
||||||
|
revision = hub_revision
|
||||||
|
|
||||||
return get_class_from_dynamic_module(
|
return get_class_from_dynamic_module(
|
||||||
custom_pipeline,
|
custom_pipeline,
|
||||||
module_file=file_name,
|
module_file=file_name,
|
||||||
class_name=class_name,
|
class_name=class_name,
|
||||||
repo_id=repo_id,
|
repo_id=repo_id,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
revision=revision if hub_revision is None else hub_revision,
|
revision=revision,
|
||||||
)
|
)
|
||||||
|
|
||||||
if class_obj != DiffusionPipeline:
|
if class_obj != DiffusionPipeline:
|
||||||
|
|||||||
@@ -1 +1,48 @@
|
|||||||
from .pipeline_pixart_alpha import PixArtAlphaPipeline
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from ...utils import (
|
||||||
|
DIFFUSERS_SLOW_IMPORT,
|
||||||
|
OptionalDependencyNotAvailable,
|
||||||
|
_LazyModule,
|
||||||
|
get_objects_from_module,
|
||||||
|
is_torch_available,
|
||||||
|
is_transformers_available,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_dummy_objects = {}
|
||||||
|
_import_structure = {}
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not (is_transformers_available() and is_torch_available()):
|
||||||
|
raise OptionalDependencyNotAvailable()
|
||||||
|
except OptionalDependencyNotAvailable:
|
||||||
|
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||||
|
|
||||||
|
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||||
|
else:
|
||||||
|
_import_structure["pipeline_pixart_alpha"] = ["PixArtAlphaPipeline"]
|
||||||
|
|
||||||
|
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||||
|
try:
|
||||||
|
if not (is_transformers_available() and is_torch_available()):
|
||||||
|
raise OptionalDependencyNotAvailable()
|
||||||
|
|
||||||
|
except OptionalDependencyNotAvailable:
|
||||||
|
from ...utils.dummy_torch_and_transformers_objects import *
|
||||||
|
else:
|
||||||
|
from .pipeline_pixart_alpha import PixArtAlphaPipeline
|
||||||
|
|
||||||
|
else:
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.modules[__name__] = _LazyModule(
|
||||||
|
__name__,
|
||||||
|
globals()["__file__"],
|
||||||
|
_import_structure,
|
||||||
|
module_spec=__spec__,
|
||||||
|
)
|
||||||
|
|
||||||
|
for name, value in _dummy_objects.items():
|
||||||
|
setattr(sys.modules[__name__], name, value)
|
||||||
|
|||||||
@@ -253,7 +253,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
|||||||
negative_prompt_embeds = None
|
negative_prompt_embeds = None
|
||||||
|
|
||||||
# Perform additional masking.
|
# Perform additional masking.
|
||||||
if mask_feature:
|
if mask_feature and prompt_embeds is None and negative_prompt_embeds is None:
|
||||||
prompt_embeds = prompt_embeds.unsqueeze(1)
|
prompt_embeds = prompt_embeds.unsqueeze(1)
|
||||||
masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask)
|
masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask)
|
||||||
masked_prompt_embeds = masked_prompt_embeds.squeeze(1)
|
masked_prompt_embeds = masked_prompt_embeds.squeeze(1)
|
||||||
|
|||||||
@@ -174,18 +174,99 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
|||||||
inputs = self.get_dummy_inputs(device)
|
inputs = self.get_dummy_inputs(device)
|
||||||
image = pipe(**inputs).images
|
image = pipe(**inputs).images
|
||||||
image_slice = image[0, -3:, -3:, -1]
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
print(torch.from_numpy(image_slice.flatten()))
|
|
||||||
|
|
||||||
self.assertEqual(image.shape, (1, 8, 8, 3))
|
self.assertEqual(image.shape, (1, 8, 8, 3))
|
||||||
expected_slice = np.array([0.5303, 0.2658, 0.7979, 0.1182, 0.3304, 0.4608, 0.5195, 0.4261, 0.4675])
|
expected_slice = np.array([0.5303, 0.2658, 0.7979, 0.1182, 0.3304, 0.4608, 0.5195, 0.4261, 0.4675])
|
||||||
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
|
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
|
||||||
self.assertLessEqual(max_diff, 1e-3)
|
self.assertLessEqual(max_diff, 1e-3)
|
||||||
|
|
||||||
|
def test_inference_non_square_images(self):
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
|
components = self.get_dummy_components()
|
||||||
|
pipe = self.pipeline_class(**components)
|
||||||
|
pipe.to(device)
|
||||||
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
inputs = self.get_dummy_inputs(device)
|
||||||
|
image = pipe(**inputs, height=32, width=48).images
|
||||||
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
|
|
||||||
|
self.assertEqual(image.shape, (1, 32, 48, 3))
|
||||||
|
expected_slice = np.array([0.3859, 0.2987, 0.2333, 0.5243, 0.6721, 0.4436, 0.5292, 0.5373, 0.4416])
|
||||||
|
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
|
||||||
|
self.assertLessEqual(max_diff, 1e-3)
|
||||||
|
|
||||||
|
def test_inference_with_embeddings_and_multiple_images(self):
|
||||||
|
components = self.get_dummy_components()
|
||||||
|
pipe = self.pipeline_class(**components)
|
||||||
|
pipe.to(torch_device)
|
||||||
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
inputs = self.get_dummy_inputs(torch_device)
|
||||||
|
|
||||||
|
prompt = inputs["prompt"]
|
||||||
|
generator = inputs["generator"]
|
||||||
|
num_inference_steps = inputs["num_inference_steps"]
|
||||||
|
output_type = inputs["output_type"]
|
||||||
|
|
||||||
|
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(prompt)
|
||||||
|
|
||||||
|
# inputs with prompt converted to embeddings
|
||||||
|
inputs = {
|
||||||
|
"prompt_embeds": prompt_embeds,
|
||||||
|
"negative_prompt": None,
|
||||||
|
"negative_prompt_embeds": negative_prompt_embeds,
|
||||||
|
"generator": generator,
|
||||||
|
"num_inference_steps": num_inference_steps,
|
||||||
|
"output_type": output_type,
|
||||||
|
"num_images_per_prompt": 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
# set all optional components to None
|
||||||
|
for optional_component in pipe._optional_components:
|
||||||
|
setattr(pipe, optional_component, None)
|
||||||
|
|
||||||
|
output = pipe(**inputs)[0]
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
pipe.save_pretrained(tmpdir)
|
||||||
|
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
|
||||||
|
pipe_loaded.to(torch_device)
|
||||||
|
pipe_loaded.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
for optional_component in pipe._optional_components:
|
||||||
|
self.assertTrue(
|
||||||
|
getattr(pipe_loaded, optional_component) is None,
|
||||||
|
f"`{optional_component}` did not stay set to None after loading.",
|
||||||
|
)
|
||||||
|
|
||||||
|
inputs = self.get_dummy_inputs(torch_device)
|
||||||
|
|
||||||
|
generator = inputs["generator"]
|
||||||
|
num_inference_steps = inputs["num_inference_steps"]
|
||||||
|
output_type = inputs["output_type"]
|
||||||
|
|
||||||
|
# inputs with prompt converted to embeddings
|
||||||
|
inputs = {
|
||||||
|
"prompt_embeds": prompt_embeds,
|
||||||
|
"negative_prompt": None,
|
||||||
|
"negative_prompt_embeds": negative_prompt_embeds,
|
||||||
|
"generator": generator,
|
||||||
|
"num_inference_steps": num_inference_steps,
|
||||||
|
"output_type": output_type,
|
||||||
|
"num_images_per_prompt": 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
output_loaded = pipe_loaded(**inputs)[0]
|
||||||
|
|
||||||
|
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
|
||||||
|
self.assertLess(max_diff, 1e-4)
|
||||||
|
|
||||||
def test_inference_batch_single_identical(self):
|
def test_inference_batch_single_identical(self):
|
||||||
self._test_inference_batch_single_identical(expected_max_diff=1e-3)
|
self._test_inference_batch_single_identical(expected_max_diff=1e-3)
|
||||||
|
|
||||||
|
|
||||||
# TODO: needs to be updated.
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):
|
class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user