Compare commits

...

6 Commits

Author              SHA1        Message     Date
Patrick von Platen  36fe2c274f  more tests  2022-12-08 18:25:42 +00:00
Patrick von Platen  69bf3a4c3e  up          2022-12-08 17:58:00 +00:00
Patrick von Platen  908533cb47  signature   2022-12-08 17:22:32 +00:00
Patrick von Platen  03f819b051  uP          2022-12-08 16:49:42 +00:00
Patrick von Platen  4c19f4f346  uP          2022-12-08 16:32:13 +00:00
Patrick von Platen  cca76a6e22  up          2022-12-08 14:55:42 +00:00
2 changed files with 140 additions and 6 deletions

src/diffusers/models/attention.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-import warnings
 from dataclasses import dataclass
 from typing import Optional
@@ -23,10 +22,13 @@ from torch import nn
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
 from ..models.embeddings import ImagePositionalEmbeddings
-from ..utils import BaseOutput
+from ..utils import BaseOutput, logging
 from ..utils.import_utils import is_xformers_available
 
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
 
 @dataclass
 class Transformer2DModelOutput(BaseOutput):
     """
@@ -450,9 +452,10 @@ class BasicTransformerBlock(nn.Module):
         # if xformers is installed try to use memory_efficient_attention by default
         if is_xformers_available():
             try:
+                logger.info("xformers detected. Memory efficient attention is automatically enabled.")
                 self.set_use_memory_efficient_attention_xformers(True)
             except Exception as e:
-                warnings.warn(
+                logger.warning(
                     "Could not enable memory efficient attention. Make sure xformers is installed"
                     f" correctly and a GPU is available: {e}"
                 )
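Note on the change above: the construction-time xformers check now reports through diffusers' logger instead of the `warnings` module, so verbosity is controlled like any other diffusers log. A minimal usage sketch (the constructor dimensions below are hypothetical, not taken from this diff):

    from diffusers.models.attention import BasicTransformerBlock
    from diffusers.utils import logging

    logging.set_verbosity_info()  # surfaces the new "xformers detected..." info log

    # hypothetical dimensions, for illustration only
    block = BasicTransformerBlock(dim=32, num_attention_heads=2, attention_head_dim=16)
    # if xformers was auto-enabled during construction, it can be switched off again:
    block.set_use_memory_efficient_attention_xformers(False)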

tests/test_pipelines_common.py

@@ -19,6 +19,9 @@ from diffusers.utils.testing_utils import require_torch, torch_device
 torch.backends.cuda.matmul.allow_tf32 = False
 
+ALLOWED_REQUIRED_ARGS = ["source_prompt", "prompt", "image", "mask_image", "example_image"]
+
+
 @require_torch
 class PipelineTesterMixin:
     """
@@ -115,10 +118,138 @@ class PipelineTesterMixin:
         self.assertLess(max_diff, 1e-5)
 
     def test_pipeline_call_implements_required_args(self):
-        required_args = ["num_inference_steps", "generator", "return_dict"]
-        for arg in required_args:
-            self.assertTrue(arg in inspect.signature(self.pipeline_class.__call__).parameters)
+        assert hasattr(self.pipeline_class, "__call__"), f"{self.pipeline_class} should have a `__call__` method"
+        parameters = inspect.signature(self.pipeline_class.__call__).parameters
+
+        required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
+        required_parameters.pop("self")
+        required_parameters = set(required_parameters)
+        for param in required_parameters:
+            assert param in ALLOWED_REQUIRED_ARGS
+
+        optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
+        required_optional_params = ["generator", "num_inference_steps", "return_dict"]
+        for param in required_optional_params:
+            assert param in optional_parameters
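The required/optional split above relies on `inspect.signature`. A self-contained sketch of the same technique on a toy callable:

    import inspect

    def call(self, prompt, num_inference_steps=50, generator=None, return_dict=True):
        pass

    parameters = inspect.signature(call).parameters
    required = {k for k, v in parameters.items() if v.default is inspect.Parameter.empty} - {"self"}
    optional = {k for k, v in parameters.items() if v.default is not inspect.Parameter.empty}

    assert required == {"prompt"}  # would have to appear in ALLOWED_REQUIRED_ARGS
    assert {"generator", "num_inference_steps", "return_dict"} <= optional  # the required optional params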
+
+    def test_inference_batch_image_pil_torch(self):
+        inputs = self.get_dummy_inputs(torch_device)
+
+        allowed_image_args = [v for v in ALLOWED_REQUIRED_ARGS if v != "prompt"]
+        if set(allowed_image_args) - set(inputs.keys()) == set(allowed_image_args):
+            # pipeline has no allowed required image args, so there is no need to test
+            return
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        # batchify inputs
+        for batch_size in [2, 4, 13]:
+            batched_inputs = {}
+            for name, value in inputs.items():
+                if name in allowed_image_args:
+                    batched_inputs[name] = batch_size * [value]
+                else:
+                    batched_inputs[name] = value
+
+            batched_inputs["num_inference_steps"] = 2
+            batched_inputs["output_type"] = "np"
+            batched_inputs["generator"] = torch.Generator(torch_device).manual_seed(33)
+            output = pipe(**batched_inputs)[0]
+
+            # convert PIL image inputs to torch tensors and rerun
+            for name in allowed_image_args:
+                if name in batched_inputs:
+                    batched_inputs[name] = torch.tensor(
+                        pipe.pil_to_numpy(batched_inputs[name]), dtype=torch.float32, device=torch_device
+                    )
+
+            batched_inputs["num_inference_steps"] = 2
+            batched_inputs["output_type"] = "np"
+            batched_inputs["generator"] = torch.Generator(torch_device).manual_seed(33)
+            output_torch_image = pipe(**batched_inputs)[0]
+
+            max_diff = np.abs(output - output_torch_image).max()
+            self.assertLess(max_diff, 1e-4)
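A standalone sketch of the PIL-to-tensor equivalence this test relies on (`pipe.pil_to_numpy` comes from the diff above; here it is replicated with plain numpy under the assumption that it scales pixel values to [0, 1]):

    import numpy as np
    import torch
    from PIL import Image

    pil_image = Image.new("RGB", (32, 32), color=(128, 64, 255))
    np_image = np.array(pil_image).astype(np.float32) / 255.0  # HWC, scaled to [0, 1]
    torch_image = torch.from_numpy(np_image)

    # both representations describe the same pixels
    assert np.abs(np_image - torch_image.numpy()).max() < 1e-6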
+
+    def test_inference_batch_consistent(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+
+        # batchify inputs
+        for batch_size in [2, 4, 13]:
+            batched_inputs = {}
+            for name, value in inputs.items():
+                if name in ALLOWED_REQUIRED_ARGS:
+                    # prompt is a string
+                    if name == "prompt":
+                        len_prompt = len(value)
+                        # make unequal batch sizes
+                        batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
+                    # otherwise we have images
+                    else:
+                        batched_inputs[name] = batch_size * [value]
+                else:
+                    batched_inputs[name] = value
+
+            batched_inputs["num_inference_steps"] = 2
+            batched_inputs["output_type"] = None
+            output = pipe(**batched_inputs)
+            assert len(output[0]) == batch_size
+
+            batched_inputs["output_type"] = "np"
+            output = pipe(**batched_inputs)[0]
+            assert output.shape[0] == batch_size
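How the unequal batch sizes above are built, as a standalone snippet:

    prompt = "A painting of a squirrel eating a burger"
    batch_size = 4
    len_prompt = len(prompt)

    # element i gets the first len(prompt) // (i + 1) characters:
    # full prompt, first half, first third, first quarter
    batched_prompts = [prompt[: len_prompt // i] for i in range(1, batch_size + 1)]
    assert len(batched_prompts) == batch_size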
+
+    def test_inference_generator_equality(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+
+        # batchify inputs
+        batched_inputs = {}
+        batch_size = 2
+        for name, value in inputs.items():
+            if name in ALLOWED_REQUIRED_ARGS:
+                # prompt is a string
+                if name == "prompt":
+                    len_prompt = len(value)
+                    # make unequal batch sizes
+                    batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
+                # otherwise we have images
+                else:
+                    batched_inputs[name] = batch_size * [value]
+            else:
+                batched_inputs[name] = value
+
+        batched_inputs["num_inference_steps"] = 2
+        batched_inputs["output_type"] = "np"
+
+        seeds = [0, 44]  # length must equal batch_size
+        generators = [torch.Generator(device=torch_device).manual_seed(s) for s in seeds]
+        batched_inputs["generator"] = generators
+        output_batch = pipe(**batched_inputs)[0]
+
+        generators = [torch.Generator(device=torch_device).manual_seed(s) for s in seeds]
+        for i, seed in enumerate(seeds):
+            inputs = {k: v[i] if isinstance(v, list) else v for k, v in batched_inputs.items()}
+            inputs["generator"] = generators[i]
+            output_single = pipe(**inputs)[0]
+
+            assert np.abs(output_single - output_batch[i]).sum() < 1e-4
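The per-sample generator logic this test checks can be sketched without a pipeline: seeding one generator per batch element makes each sample reproducible in isolation (plain `torch.randn` stands in for the pipeline's latent sampling):

    import torch

    seeds = [0, 44]
    generators = [torch.Generator().manual_seed(s) for s in seeds]
    batched = torch.stack([torch.randn(4, generator=g) for g in generators])

    # re-seeding reproduces each batch element independently
    for i, s in enumerate(seeds):
        g = torch.Generator().manual_seed(s)
        assert torch.equal(torch.randn(4, generator=g), batched[i])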
 
     def test_num_inference_steps_consistent(self):
         components = self.get_dummy_components()