Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-12 07:24:32 +08:00)

Compare commits: disable-mm...add_more_t (6 commits)
| SHA1 |
|---|
| 36fe2c274f |
| 69bf3a4c3e |
| 908533cb47 |
| 03f819b051 |
| 4c19f4f346 |
| cca76a6e22 |
```diff
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-import warnings
 from dataclasses import dataclass
 from typing import Optional
 
```
```diff
@@ -23,10 +22,13 @@ from torch import nn
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
 from ..models.embeddings import ImagePositionalEmbeddings
-from ..utils import BaseOutput
+from ..utils import BaseOutput, logging
 from ..utils.import_utils import is_xformers_available
 
 
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
 @dataclass
 class Transformer2DModelOutput(BaseOutput):
     """
```
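The two hunks above replace `warnings.warn` with the library's own logger. As a rough sketch of the module-level pattern this switches to (assuming `diffusers.utils.logging` exposes the usual `get_logger` and `set_verbosity_*` helpers):

```python
from diffusers.utils import logging

# Module-level logger, mirroring the `logger = logging.get_logger(__name__)` line in the diff.
logger = logging.get_logger(__name__)

# Verbosity is controlled globally, so downstream users can surface or silence these messages.
logging.set_verbosity_info()
logger.info("xformers detected. Memory efficient attention is automatically enabled.")

logging.set_verbosity_warning()
logger.warning("Could not enable memory efficient attention.")
```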
```diff
@@ -450,9 +452,10 @@ class BasicTransformerBlock(nn.Module):
         # if xformers is installed try to use memory_efficient_attention by default
         if is_xformers_available():
             try:
+                logger.info("xformers detected. Memory efficient attention is automatically enabled.")
                 self.set_use_memory_efficient_attention_xformers(True)
             except Exception as e:
-                warnings.warn(
+                logger.warning(
                     "Could not enable memory efficient attention. Make sure xformers is installed"
                     f" correctly and a GPU is available: {e}"
                 )
```
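With this change, a failed xformers enablement is logged instead of emitting a Python warning. Callers who prefer to opt in explicitly rather than rely on the automatic path can mirror the same guarded pattern; a minimal sketch (the checkpoint name is only a placeholder, and a CUDA device is assumed):

```python
import torch
from diffusers import DiffusionPipeline
from diffusers.utils.import_utils import is_xformers_available

# Placeholder checkpoint; any diffusers pipeline with attention blocks behaves the same way.
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.to("cuda")

# Same guard as the diff: only attempt xformers when it is importable,
# and fall back gracefully when the kernel cannot run on this machine.
if is_xformers_available():
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception as e:
        print(f"Could not enable memory efficient attention: {e}")
```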
```diff
@@ -19,6 +19,9 @@ from diffusers.utils.testing_utils import require_torch, torch_device
 torch.backends.cuda.matmul.allow_tf32 = False
 
 
+ALLOWED_REQUIRED_ARGS = ["source_prompt", "prompt", "image", "mask_image", "example_image"]
+
+
 @require_torch
 class PipelineTesterMixin:
     """
```
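`ALLOWED_REQUIRED_ARGS` becomes the single list of positional arguments a pipeline's `__call__` is allowed to require. A hypothetical concrete test case would plug into the mixin roughly as below; the pipeline choice and the bodies of the two helpers are placeholders, only the `pipeline_class`, `get_dummy_components`, and `get_dummy_inputs` hooks come from the diff, and `PipelineTesterMixin` is assumed importable from the test module shown above:

```python
import unittest

import torch
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel


class DummyPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    # The mixin reads this attribute to introspect `__call__` and to instantiate pipelines.
    pipeline_class = DDPMPipeline

    def get_dummy_components(self):
        # Tiny components keep the shared batch/generator tests fast.
        unet = UNet2DModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=3,
            out_channels=3,
            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
        )
        scheduler = DDPMScheduler(num_train_timesteps=10)
        return {"unet": unet, "scheduler": scheduler}

    def get_dummy_inputs(self, device, seed=0):
        # Only keys listed in ALLOWED_REQUIRED_ARGS get batched by the mixin's tests.
        return {"generator": torch.Generator(device=device).manual_seed(seed), "num_inference_steps": 2}
```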
```diff
@@ -115,10 +118,138 @@ class PipelineTesterMixin:
         self.assertLess(max_diff, 1e-5)
 
     def test_pipeline_call_implements_required_args(self):
-        required_args = ["num_inference_steps", "generator", "return_dict"]
+        assert hasattr(self.pipeline_class, "__call__"), f"{self.pipeline_class} should have a `__call__` method"
+        parameters = inspect.signature(self.pipeline_class.__call__).parameters
+        required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
+        required_parameters.pop("self")
+        required_parameters = set(required_parameters)
+        optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
 
-        for arg in required_args:
-            self.assertTrue(arg in inspect.signature(self.pipeline_class.__call__).parameters)
+        for param in required_parameters:
+            assert param in ALLOWED_REQUIRED_ARGS
+
+        optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
+
+        required_optional_params = ["generator", "num_inference_steps", "return_dict"]
+        for param in required_optional_params:
+            assert param in optional_parameters
+
+    def test_inference_batch_image_pil_torch(self):
+        inputs = self.get_dummy_inputs(torch_device)
+
+        allowed_image_args = [v for v in ALLOWED_REQUIRED_ARGS if v != "prompt"]
+
+        if set(allowed_image_args) - set(inputs.keys()) == set(allowed_image_args):
+            # pipeline has no allowed required image args, so no need to test
+            return
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        # batchify inputs
+        for batch_size in [2, 4, 13]:
+            batched_inputs = {}
+            for name, value in inputs.items():
+                if name in allowed_image_args:
+                    batched_inputs[name] = batch_size * [value]
+                else:
+                    batched_inputs[name] = value
+
+            batched_inputs["num_inference_steps"] = 2
+            batched_inputs["output_type"] = "np"
+            batched_inputs["generator"] = torch.Generator(torch_device).manual_seed(33)
+            output = pipe(**batched_inputs)
+
+            for name in allowed_image_args:
+                # convert pil to torch
+                if name in batched_inputs:
+                    batched_inputs = torch.tensor(pipe.pil_to_numpy(value), dtype=torch.float32, device=torch_device)
+            batched_inputs["num_inference_steps"] = 2
+            batched_inputs["output_type"] = "np"
+
+            batched_inputs["generator"] = torch.Generator(torch_device).manual_seed(33)
+            output_torch_image = pipe(**batched_inputs)
+
+            max_diff = np.abs(output - output_torch_image).max()
+            self.assertLess(max_diff, 1e-4)
+
+    def test_inference_batch_consistent(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+
+        # batchify inputs
+        for batch_size in [2, 4, 13]:
+            batched_inputs = {}
+            for name, value in inputs.items():
+                if name in ALLOWED_REQUIRED_ARGS:
+                    # prompt is string
+                    if name == "prompt":
+                        len_prompt = len(value)
+                        # make unequal batch sizes
+                        batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
+                    # or else we have images
+                    else:
+                        batched_inputs[name] = batch_size * [value]
+                else:
+                    batched_inputs[name] = value
+
+            batched_inputs["num_inference_steps"] = 2
+            batched_inputs["output_type"] = None
+            output = pipe(**batched_inputs)
+
+            assert len(output[0]) == batch_size
+
+            batched_inputs["output_type"] = "np"
+            output = pipe(**batched_inputs)[0]
+
+            assert output.shape[0] == batch_size
+
+    def test_inference_generator_equality(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+
+        # batchify inputs
+        batched_inputs = {}
+        batch_size = 2
+        for name, value in inputs.items():
+            if name in ALLOWED_REQUIRED_ARGS:
+                # prompt is string
+                if name == "prompt":
+                    len_prompt = len(value)
+                    # make unequal batch sizes
+                    batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
+                # or else we have images
+                else:
+                    batched_inputs[name] = batch_size * [value]
+            else:
+                batched_inputs[name] = value
+
+        batched_inputs["num_inference_steps"] = 2
+        batched_inputs["output_type"] = "np"
+        seeds = [0, 44]  # make sure length is equal to batch_size
+        generators = [torch.Generator(device=torch_device).manual_seed(s) for s in seeds]
+        batched_inputs["generator"] = generators
+
+        output_batch = pipe(**batched_inputs)[0]
+
+        generators = [torch.Generator(device=torch_device).manual_seed(s) for s in seeds]
+        for i, seed in enumerate(seeds):
+            inputs = {k: v[i] if isinstance(v, list) else v for k, v in batched_inputs.items()}
+            inputs["generator"] = generators[i]
+
+            output_single = pipe(**inputs)[0]
+
+            assert np.abs(output_single - output_batch[i]).sum() < 1e-4
 
     def test_num_inference_steps_consistent(self):
         components = self.get_dummy_components()
```
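`test_inference_generator_equality` relies on a batched call fed a list of per-sample generators reproducing separate per-sample calls seeded the same way. A minimal, pipeline-free sketch of that property using plain `torch.randn` (CPU is assumed so the snippet runs without a GPU):

```python
import torch

device = "cpu"  # assumption: CPU keeps the sketch runnable anywhere
seeds = [0, 44]
shape = (4,)

# Batched path: one fresh generator per sample.
batched_generators = [torch.Generator(device=device).manual_seed(s) for s in seeds]
batched_noise = [torch.randn(shape, generator=g, device=device) for g in batched_generators]

# Per-sample path: re-seed and draw each sample on its own.
for i, seed in enumerate(seeds):
    g = torch.Generator(device=device).manual_seed(seed)
    single_noise = torch.randn(shape, generator=g, device=device)
    assert torch.allclose(single_noise, batched_noise[i])
```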