mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-06 20:44:33 +08:00)

Compare commits: auto-pipel ... add_more_t (6 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 36fe2c274f | |
| | 69bf3a4c3e | |
| | 908533cb47 | |
| | 03f819b051 | |
| | 4c19f4f346 | |
| | cca76a6e22 | |
```diff
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-import warnings
 from dataclasses import dataclass
 from typing import Optional

@@ -23,10 +22,13 @@ from torch import nn
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
 from ..models.embeddings import ImagePositionalEmbeddings
-from ..utils import BaseOutput
+from ..utils import BaseOutput, logging
 from ..utils.import_utils import is_xformers_available


+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
 @dataclass
 class Transformer2DModelOutput(BaseOutput):
     """
@@ -450,9 +452,10 @@ class BasicTransformerBlock(nn.Module):
         # if xformers is installed try to use memory_efficient_attention by default
         if is_xformers_available():
             try:
+                logger.info("xformers detected. Memory efficient attention is automatically enabled.")
                 self.set_use_memory_efficient_attention_xformers(True)
             except Exception as e:
-                warnings.warn(
+                logger.warning(
                     "Could not enable memory efficient attention. Make sure xformers is installed"
                     f" correctly and a GPU is available: {e}"
                 )
```
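The attention-module hunks above swap `warnings.warn` for the library logger and announce when memory-efficient attention is enabled automatically. As a rough illustration of that try-and-log pattern for an optional dependency, here is a minimal, self-contained sketch; the `try_enable_fast_attention` helper is hypothetical (not part of diffusers) and uses stdlib `logging` rather than diffusers' own logging wrapper:

```python
import importlib.util
import logging

logger = logging.getLogger(__name__)  # stdlib logging; diffusers wraps this in ..utils.logging


def try_enable_fast_attention(module) -> bool:
    """Try to switch `module` to a memory-efficient attention backend.

    Mirrors the pattern in the diff: only attempt the switch when the optional
    dependency is importable, log success at INFO, and downgrade any failure
    (e.g. no GPU available) to a WARNING instead of raising.
    """
    if importlib.util.find_spec("xformers") is None:
        return False
    try:
        logger.info("xformers detected. Memory efficient attention is automatically enabled.")
        # assumed to exist on the module, as on BasicTransformerBlock in the diff
        module.set_use_memory_efficient_attention_xformers(True)
        return True
    except Exception as e:
        logger.warning(
            "Could not enable memory efficient attention. Make sure xformers is installed"
            f" correctly and a GPU is available: {e}"
        )
        return False
```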
```diff
@@ -19,6 +19,9 @@ from diffusers.utils.testing_utils import require_torch, torch_device
 torch.backends.cuda.matmul.allow_tf32 = False


+ALLOWED_REQUIRED_ARGS = ["source_prompt", "prompt", "image", "mask_image", "example_image"]
+
+
 @require_torch
 class PipelineTesterMixin:
     """
```
```diff
@@ -115,10 +118,138 @@ class PipelineTesterMixin:
         self.assertLess(max_diff, 1e-5)

     def test_pipeline_call_implements_required_args(self):
-        required_args = ["num_inference_steps", "generator", "return_dict"]
+        assert hasattr(self.pipeline_class, "__call__"), f"{self.pipeline_class} should have a `__call__` method"
+        parameters = inspect.signature(self.pipeline_class.__call__).parameters
+        required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
+        required_parameters.pop("self")
+        required_parameters = set(required_parameters)
+        optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})

-        for arg in required_args:
-            self.assertTrue(arg in inspect.signature(self.pipeline_class.__call__).parameters)
+        for param in required_parameters:
+            assert param in ALLOWED_REQUIRED_ARGS
+
+        optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
+
+        required_optional_params = ["generator", "num_inference_steps", "return_dict"]
+        for param in required_optional_params:
+            assert param in optional_parameters
+
+    def test_inference_batch_image_pil_torch(self):
+        inputs = self.get_dummy_inputs(torch_device)
+
+        allowed_image_args = [v for v in ALLOWED_REQUIRED_ARGS if v != "prompt"]
+
+        if set(allowed_image_args) - set(inputs.keys()) == set(allowed_image_args):
+            # pipeline has no allowed required image args, so no need to test
+            return
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        # batchify inputs
+        for batch_size in [2, 4, 13]:
+            batched_inputs = {}
+            for name, value in inputs.items():
+                if name in allowed_image_args:
+                    batched_inputs[name] = batch_size * [value]
+                else:
+                    batched_inputs[name] = value
+
+            batched_inputs["num_inference_steps"] = 2
+            batched_inputs["output_type"] = "np"
+            batched_inputs["generator"] = torch.Generator(torch_device).manual_seed(33)
+            output = pipe(**batched_inputs)
+
+            for name in allowed_image_args:
+                # convert pil to torch
+                if name in batched_inputs:
+                    batched_inputs = torch.tensor(pipe.pil_to_numpy(value), dtype=torch.float32, device=torch_device)
+            batched_inputs["num_inference_steps"] = 2
+            batched_inputs["output_type"] = "np"
+
+            batched_inputs["generator"] = torch.Generator(torch_device).manual_seed(33)
+            output_torch_image = pipe(**batched_inputs)
+
+            max_diff = np.abs(output - output_torch_image).max()
+            self.assertLess(max_diff, 1e-4)
+
+    def test_inference_batch_consistent(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+
+        # batchify inputs
+        for batch_size in [2, 4, 13]:
+            batched_inputs = {}
+            for name, value in inputs.items():
+                if name in ALLOWED_REQUIRED_ARGS:
+                    # prompt is string
+                    if name == "prompt":
+                        len_prompt = len(value)
+                        # make unequal batch sizes
+                        batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
+                    # or else we have images
+                    else:
+                        batched_inputs[name] = batch_size * [value]
+                else:
+                    batched_inputs[name] = value
+
+            batched_inputs["num_inference_steps"] = 2
+            batched_inputs["output_type"] = None
+            output = pipe(**batched_inputs)
+
+            assert len(output[0]) == batch_size
+
+            batched_inputs["output_type"] = "np"
+            output = pipe(**batched_inputs)[0]
+
+            assert output.shape[0] == batch_size
+
+    def test_inference_generator_equality(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+
+        # batchify inputs
+        batched_inputs = {}
+        batch_size = 2
+        for name, value in inputs.items():
+            if name in ALLOWED_REQUIRED_ARGS:
+                # prompt is string
+                if name == "prompt":
+                    len_prompt = len(value)
+                    # make unequal batch sizes
+                    batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
+                # or else we have images
+                else:
+                    batched_inputs[name] = batch_size * [value]
+            else:
+                batched_inputs[name] = value
+
+        batched_inputs["num_inference_steps"] = 2
+        batched_inputs["output_type"] = "np"
+        seeds = [0, 44]  # make sure length is equal to batch_size
+        generators = [torch.Generator(device=torch_device).manual_seed(s) for s in seeds]
+        batched_inputs["generator"] = generators
+
+        output_batch = pipe(**batched_inputs)[0]
+
+        generators = [torch.Generator(device=torch_device).manual_seed(s) for s in seeds]
+        for i, seed in enumerate(seeds):
+            inputs = {k: v[i] if isinstance(v, list) else v for k, v in batched_inputs.items()}
+            inputs["generator"] = generators[i]
+
+            output_single = pipe(**inputs)[0]
+
+            assert np.abs(output_single - output_batch[i]).sum() < 1e-4

     def test_num_inference_steps_consistent(self):
         components = self.get_dummy_components()
```
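For context, the mixin above drives every common test through `self.pipeline_class`, `get_dummy_components()` and `get_dummy_inputs()`. Below is a hypothetical sketch of how a concrete test case might plug into it, assuming `PipelineTesterMixin` is importable from the test module in this diff; the pipeline choice and the tiny component configuration are illustrative, not taken from these commits:

```python
import unittest

import torch

from diffusers import DDIMPipeline, DDIMScheduler, UNet2DModel


class FakePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    # the class whose __call__ signature and batching behaviour the mixin checks
    pipeline_class = DDIMPipeline

    def get_dummy_components(self):
        # deliberately tiny components so the common tests run quickly on CPU
        unet = UNet2DModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=3,
            out_channels=3,
            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
        )
        scheduler = DDIMScheduler()
        return {"unet": unet, "scheduler": scheduler}

    def get_dummy_inputs(self, device, seed=0):
        # only arguments listed in ALLOWED_REQUIRED_ARGS may be required by __call__;
        # everything else (generator, num_inference_steps, ...) must have a default
        generator = torch.Generator(device=device).manual_seed(seed)
        return {"generator": generator, "num_inference_steps": 2, "output_type": "np"}
```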