mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-08 05:24:20 +08:00
Compare commits
15 Commits
custom-cod
...
modular-te
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3aabef5de4 | ||
|
|
39be374591 | ||
|
|
54e17f3084 | ||
|
|
80702d222d | ||
|
|
625cc8ede8 | ||
|
|
a2a9e4eadb | ||
|
|
0998bd75ad | ||
|
|
5f560d05a2 | ||
|
|
4b7a9e9fa9 | ||
|
|
d8fa2de36f | ||
|
|
4df2739a5e | ||
|
|
d92855ddf0 | ||
|
|
0a5c90ed47 | ||
|
|
0fa58127f8 | ||
|
|
b165cf3742 |
141
.github/workflows/pr_modular_tests.yml
vendored
Normal file
141
.github/workflows/pr_modular_tests.yml
vendored
Normal file
@@ -0,0 +1,141 @@
|
||||
name: Fast PR tests for Modular
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
paths:
|
||||
- "src/diffusers/modular_pipelines/**.py"
|
||||
- "src/diffusers/models/modeling_utils.py"
|
||||
- "src/diffusers/models/model_loading_utils.py"
|
||||
- "src/diffusers/pipelines/pipeline_utils.py"
|
||||
- "src/diffusers/pipeline_loading_utils.py"
|
||||
- "src/diffusers/loaders/lora_base.py"
|
||||
- "src/diffusers/loaders/lora_pipeline.py"
|
||||
- "src/diffusers/loaders/peft.py"
|
||||
- "tests/modular_pipelines/**.py"
|
||||
- ".github/**.yml"
|
||||
- "utils/**.py"
|
||||
- "setup.py"
|
||||
push:
|
||||
branches:
|
||||
- ci-*
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
HF_HUB_ENABLE_HF_TRANSFER: 1
|
||||
OMP_NUM_THREADS: 4
|
||||
MKL_NUM_THREADS: 4
|
||||
PYTEST_TIMEOUT: 60
|
||||
|
||||
jobs:
|
||||
check_code_quality:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[quality]
|
||||
- name: Check quality
|
||||
run: make quality
|
||||
- name: Check if failure
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
check_repository_consistency:
|
||||
needs: check_code_quality
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[quality]
|
||||
- name: Check repo consistency
|
||||
run: |
|
||||
python utils/check_copies.py
|
||||
python utils/check_dummies.py
|
||||
python utils/check_support_list.py
|
||||
make deps_table_check_updated
|
||||
- name: Check if failure
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
run_fast_tests:
|
||||
needs: [check_code_quality, check_repository_consistency]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
config:
|
||||
- name: Fast PyTorch Modular Pipeline CPU tests
|
||||
framework: pytorch_pipelines
|
||||
runner: aws-highmemory-32-plus
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_cpu_modular_pipelines
|
||||
|
||||
name: ${{ matrix.config.name }}
|
||||
|
||||
runs-on:
|
||||
group: ${{ matrix.config.runner }}
|
||||
|
||||
container:
|
||||
image: ${{ matrix.config.image }}
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run fast PyTorch Pipeline CPU tests
|
||||
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m pytest -n 8 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/modular_pipelines
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
|
||||
path: reports
|
||||
|
||||
|
||||
@@ -479,6 +479,22 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
|
||||
return list(combined_dict.values())
|
||||
|
||||
@property
|
||||
def input_names(self) -> List[str]:
|
||||
return [input_param.name for input_param in self.inputs]
|
||||
|
||||
@property
|
||||
def intermediate_input_names(self) -> List[str]:
|
||||
return [input_param.name for input_param in self.intermediate_inputs]
|
||||
|
||||
@property
|
||||
def intermediate_output_names(self) -> List[str]:
|
||||
return [output_param.name for output_param in self.intermediate_outputs]
|
||||
|
||||
@property
|
||||
def output_names(self) -> List[str]:
|
||||
return [output_param.name for output_param in self.outputs]
|
||||
|
||||
|
||||
class PipelineBlock(ModularPipelineBlocks):
|
||||
"""
|
||||
@@ -2825,3 +2841,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
type_hint=type_hint,
|
||||
**spec_dict,
|
||||
)
|
||||
|
||||
def set_progress_bar_config(self, **kwargs):
|
||||
for sub_block_name, sub_block in self.blocks.sub_blocks.items():
|
||||
if hasattr(sub_block, "set_progress_bar_config"):
|
||||
sub_block.set_progress_bar_config(**kwargs)
|
||||
|
||||
@@ -744,8 +744,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
timestep=None,
|
||||
is_strength_max=True,
|
||||
add_noise=True,
|
||||
return_noise=False,
|
||||
return_image_latents=False,
|
||||
):
|
||||
shape = (
|
||||
batch_size,
|
||||
@@ -768,7 +766,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
if image.shape[1] == 4:
|
||||
image_latents = image.to(device=device, dtype=dtype)
|
||||
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
|
||||
elif return_image_latents or (latents is None and not is_strength_max):
|
||||
elif latents is None and not is_strength_max:
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_latents = self._encode_vae_image(components, image=image, generator=generator)
|
||||
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
|
||||
@@ -786,13 +784,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = image_latents.to(device)
|
||||
|
||||
outputs = (latents,)
|
||||
|
||||
if return_noise:
|
||||
outputs += (noise,)
|
||||
|
||||
if return_image_latents:
|
||||
outputs += (image_latents,)
|
||||
outputs = (latents, noise, image_latents)
|
||||
|
||||
return outputs
|
||||
|
||||
@@ -864,7 +856,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor
|
||||
block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor
|
||||
|
||||
block_state.latents, block_state.noise = self.prepare_latents_inpaint(
|
||||
block_state.latents, block_state.noise, block_state.image_latents = self.prepare_latents_inpaint(
|
||||
components,
|
||||
block_state.batch_size * block_state.num_images_per_prompt,
|
||||
components.num_channels_latents,
|
||||
@@ -878,8 +870,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
timestep=block_state.latent_timestep,
|
||||
is_strength_max=block_state.is_strength_max,
|
||||
add_noise=block_state.add_noise,
|
||||
return_noise=True,
|
||||
return_image_latents=False,
|
||||
)
|
||||
|
||||
# 7. Prepare mask latent variables
|
||||
|
||||
0
tests/modular_pipelines/__init__.py
Normal file
0
tests/modular_pipelines/__init__.py
Normal file
@@ -0,0 +1,511 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from diffusers import (
|
||||
ClassifierFreeGuidance,
|
||||
ComponentsManager,
|
||||
ModularPipeline,
|
||||
StableDiffusionXLAutoBlocks,
|
||||
StableDiffusionXLModularPipeline,
|
||||
)
|
||||
from diffusers.loaders import ModularIPAdapterMixin
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
require_torch_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...models.unets.test_models_unet_2d_condition import (
|
||||
create_ip_adapter_state_dict,
|
||||
)
|
||||
from ..test_modular_pipelines_common import (
|
||||
ModularPipelineTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class SDXLModularTests:
|
||||
"""
|
||||
This mixin defines method to create pipeline, base input and base test across all SDXL modular tests.
|
||||
"""
|
||||
|
||||
pipeline_class = StableDiffusionXLModularPipeline
|
||||
pipeline_blocks_class = StableDiffusionXLAutoBlocks
|
||||
repo = "hf-internal-testing/tiny-sdxl-modular"
|
||||
params = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
"height",
|
||||
"width",
|
||||
"negative_prompt",
|
||||
"cross_attention_kwargs",
|
||||
"image",
|
||||
"mask_image",
|
||||
]
|
||||
)
|
||||
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
|
||||
|
||||
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
|
||||
pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
|
||||
pipeline.load_default_components(torch_dtype=torch_dtype)
|
||||
return pipeline
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def _test_stable_diffusion_xl_euler(self, expected_image_shape, expected_slice, expected_max_diff=1e-2):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
sd_pipe = self.get_pipeline()
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = sd_pipe(**inputs, output="images")
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == expected_image_shape
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < expected_max_diff, (
|
||||
"Image Slice does not match expected slice"
|
||||
)
|
||||
|
||||
|
||||
class SDXLModularIPAdapterTests:
|
||||
"""
|
||||
This mixin is designed to test IP Adapter.
|
||||
"""
|
||||
|
||||
def test_pipeline_inputs_and_blocks(self):
|
||||
blocks = self.pipeline_blocks_class()
|
||||
parameters = blocks.input_names
|
||||
|
||||
assert issubclass(self.pipeline_class, ModularIPAdapterMixin)
|
||||
assert "ip_adapter_image" in parameters, (
|
||||
"`ip_adapter_image` argument must be supported by the `__call__` method"
|
||||
)
|
||||
assert "ip_adapter" in blocks.sub_blocks, "pipeline must contain an IPAdapter block"
|
||||
|
||||
_ = blocks.sub_blocks.pop("ip_adapter")
|
||||
parameters = blocks.input_names
|
||||
intermediate_parameters = blocks.intermediate_input_names
|
||||
assert "ip_adapter_image" not in parameters, (
|
||||
"`ip_adapter_image` argument must be removed from the `__call__` method"
|
||||
)
|
||||
assert "ip_adapter_image_embeds" not in intermediate_parameters, (
|
||||
"`ip_adapter_image_embeds` argument must be supported by the `__call__` method"
|
||||
)
|
||||
|
||||
def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
|
||||
return torch.randn((1, 1, cross_attention_dim), device=torch_device)
|
||||
|
||||
def _get_dummy_faceid_image_embeds(self, cross_attention_dim: int = 32):
|
||||
return torch.randn((1, 1, 1, cross_attention_dim), device=torch_device)
|
||||
|
||||
def _get_dummy_masks(self, input_size: int = 64):
|
||||
_masks = torch.zeros((1, 1, input_size, input_size), device=torch_device)
|
||||
_masks[0, :, :, : int(input_size / 2)] = 1
|
||||
return _masks
|
||||
|
||||
def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
|
||||
blocks = self.pipeline_blocks_class()
|
||||
_ = blocks.sub_blocks.pop("ip_adapter")
|
||||
parameters = blocks.input_names
|
||||
if "image" in parameters and "strength" in parameters:
|
||||
inputs["num_inference_steps"] = 4
|
||||
|
||||
inputs["output_type"] = "np"
|
||||
return inputs
|
||||
|
||||
def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
|
||||
r"""Tests for IP-Adapter.
|
||||
|
||||
The following scenarios are tested:
|
||||
- Single IP-Adapter with scale=0 should produce same output as no IP-Adapter.
|
||||
- Multi IP-Adapter with scale=0 should produce same output as no IP-Adapter.
|
||||
- Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
|
||||
- Multi IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
|
||||
"""
|
||||
# Raising the tolerance for this test when it's run on a CPU because we
|
||||
# compare against static slices and that can be shaky (with a VVVV low probability).
|
||||
expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
|
||||
|
||||
blocks = self.pipeline_blocks_class()
|
||||
_ = blocks.sub_blocks.pop("ip_adapter")
|
||||
pipe = blocks.init_pipeline(self.repo)
|
||||
pipe.load_default_components(torch_dtype=torch.float32)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
cross_attention_dim = pipe.unet.config.get("cross_attention_dim")
|
||||
|
||||
# forward pass without ip adapter
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
if expected_pipe_slice is None:
|
||||
output_without_adapter = pipe(**inputs, output="images")
|
||||
else:
|
||||
output_without_adapter = expected_pipe_slice
|
||||
|
||||
# 1. Single IP-Adapter test cases
|
||||
adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
|
||||
pipe.unet._load_ip_adapter_weights(adapter_state_dict)
|
||||
|
||||
# forward pass with single ip adapter, but scale=0 which should have no effect
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
|
||||
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
|
||||
pipe.set_ip_adapter_scale(0.0)
|
||||
output_without_adapter_scale = pipe(**inputs, output="images")
|
||||
if expected_pipe_slice is not None:
|
||||
output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
# forward pass with single ip adapter, but with scale of adapter weights
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
|
||||
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
|
||||
pipe.set_ip_adapter_scale(42.0)
|
||||
output_with_adapter_scale = pipe(**inputs, output="images")
|
||||
if expected_pipe_slice is not None:
|
||||
output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
|
||||
max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
|
||||
|
||||
assert max_diff_without_adapter_scale < expected_max_diff, (
|
||||
"Output without ip-adapter must be same as normal inference"
|
||||
)
|
||||
assert max_diff_with_adapter_scale > 1e-2, "Output with ip-adapter must be different from normal inference"
|
||||
|
||||
# 2. Multi IP-Adapter test cases
|
||||
adapter_state_dict_1 = create_ip_adapter_state_dict(pipe.unet)
|
||||
adapter_state_dict_2 = create_ip_adapter_state_dict(pipe.unet)
|
||||
pipe.unet._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2])
|
||||
|
||||
# forward pass with multi ip adapter, but scale=0 which should have no effect
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
|
||||
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
|
||||
pipe.set_ip_adapter_scale([0.0, 0.0])
|
||||
output_without_multi_adapter_scale = pipe(**inputs, output="images")
|
||||
if expected_pipe_slice is not None:
|
||||
output_without_multi_adapter_scale = output_without_multi_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
# forward pass with multi ip adapter, but with scale of adapter weights
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
|
||||
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
|
||||
pipe.set_ip_adapter_scale([42.0, 42.0])
|
||||
output_with_multi_adapter_scale = pipe(**inputs, output="images")
|
||||
if expected_pipe_slice is not None:
|
||||
output_with_multi_adapter_scale = output_with_multi_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
max_diff_without_multi_adapter_scale = np.abs(
|
||||
output_without_multi_adapter_scale - output_without_adapter
|
||||
).max()
|
||||
max_diff_with_multi_adapter_scale = np.abs(output_with_multi_adapter_scale - output_without_adapter).max()
|
||||
assert max_diff_without_multi_adapter_scale < expected_max_diff, (
|
||||
"Output without multi-ip-adapter must be same as normal inference"
|
||||
)
|
||||
assert max_diff_with_multi_adapter_scale > 1e-2, (
|
||||
"Output with multi-ip-adapter scale must be different from normal inference"
|
||||
)
|
||||
|
||||
|
||||
class SDXLModularControlNetTests:
|
||||
"""
|
||||
This mixin is designed to test ControlNet.
|
||||
"""
|
||||
|
||||
def test_pipeline_inputs(self):
|
||||
blocks = self.pipeline_blocks_class()
|
||||
parameters = blocks.input_names
|
||||
|
||||
assert "control_image" in parameters, "`control_image` argument must be supported by the `__call__` method"
|
||||
assert "controlnet_conditioning_scale" in parameters, (
|
||||
"`controlnet_conditioning_scale` argument must be supported by the `__call__` method"
|
||||
)
|
||||
|
||||
def _modify_inputs_for_controlnet_test(self, inputs: Dict[str, Any]):
|
||||
controlnet_embedder_scale_factor = 2
|
||||
image = torch.randn(
|
||||
(1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
|
||||
device=torch_device,
|
||||
)
|
||||
inputs["control_image"] = image
|
||||
return inputs
|
||||
|
||||
def test_controlnet(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
|
||||
r"""Tests for ControlNet.
|
||||
|
||||
The following scenarios are tested:
|
||||
- Single ControlNet with scale=0 should produce same output as no ControlNet.
|
||||
- Single ControlNet with scale!=0 should produce different output compared to no ControlNet.
|
||||
"""
|
||||
# Raising the tolerance for this test when it's run on a CPU because we
|
||||
# compare against static slices and that can be shaky (with a VVVV low probability).
|
||||
expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
|
||||
|
||||
pipe = self.get_pipeline()
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# forward pass without controlnet
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_without_controlnet = pipe(**inputs, output="images")
|
||||
output_without_controlnet = output_without_controlnet[0, -3:, -3:, -1].flatten()
|
||||
|
||||
# forward pass with single controlnet, but scale=0 which should have no effect
|
||||
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["controlnet_conditioning_scale"] = 0.0
|
||||
output_without_controlnet_scale = pipe(**inputs, output="images")
|
||||
output_without_controlnet_scale = output_without_controlnet_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
# forward pass with single controlnet, but with scale of adapter weights
|
||||
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["controlnet_conditioning_scale"] = 42.0
|
||||
output_with_controlnet_scale = pipe(**inputs, output="images")
|
||||
output_with_controlnet_scale = output_with_controlnet_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
max_diff_without_controlnet_scale = np.abs(output_without_controlnet_scale - output_without_controlnet).max()
|
||||
max_diff_with_controlnet_scale = np.abs(output_with_controlnet_scale - output_without_controlnet).max()
|
||||
|
||||
assert max_diff_without_controlnet_scale < expected_max_diff, (
|
||||
"Output without controlnet must be same as normal inference"
|
||||
)
|
||||
assert max_diff_with_controlnet_scale > 1e-2, "Output with controlnet must be different from normal inference"
|
||||
|
||||
def test_controlnet_cfg(self):
|
||||
pipe = self.get_pipeline()
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# forward pass with CFG not applied
|
||||
guider = ClassifierFreeGuidance(guidance_scale=1.0)
|
||||
pipe.update_components(guider=guider)
|
||||
|
||||
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
|
||||
out_no_cfg = pipe(**inputs, output="images")
|
||||
|
||||
# forward pass with CFG applied
|
||||
guider = ClassifierFreeGuidance(guidance_scale=7.5)
|
||||
pipe.update_components(guider=guider)
|
||||
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
|
||||
out_cfg = pipe(**inputs, output="images")
|
||||
|
||||
assert out_cfg.shape == out_no_cfg.shape
|
||||
max_diff = np.abs(out_cfg - out_no_cfg).max()
|
||||
assert max_diff > 1e-2, "Output with CFG must be different from normal inference"
|
||||
|
||||
|
||||
class SDXLModularGuiderTests:
|
||||
def test_guider_cfg(self):
|
||||
pipe = self.get_pipeline()
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# forward pass with CFG not applied
|
||||
guider = ClassifierFreeGuidance(guidance_scale=1.0)
|
||||
pipe.update_components(guider=guider)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
out_no_cfg = pipe(**inputs, output="images")
|
||||
|
||||
# forward pass with CFG applied
|
||||
guider = ClassifierFreeGuidance(guidance_scale=7.5)
|
||||
pipe.update_components(guider=guider)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
out_cfg = pipe(**inputs, output="images")
|
||||
|
||||
assert out_cfg.shape == out_no_cfg.shape
|
||||
max_diff = np.abs(out_cfg - out_no_cfg).max()
|
||||
assert max_diff > 1e-2, "Output with CFG must be different from normal inference"
|
||||
|
||||
|
||||
class SDXLModularPipelineFastTests(
|
||||
SDXLModularTests,
|
||||
SDXLModularIPAdapterTests,
|
||||
SDXLModularControlNetTests,
|
||||
SDXLModularGuiderTests,
|
||||
ModularPipelineTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
"""Test cases for Stable Diffusion XL modular pipeline fast tests."""
|
||||
|
||||
def test_stable_diffusion_xl_euler(self):
|
||||
self._test_stable_diffusion_xl_euler(
|
||||
expected_image_shape=(1, 64, 64, 3),
|
||||
expected_slice=[
|
||||
0.5966781,
|
||||
0.62939394,
|
||||
0.48465094,
|
||||
0.51573336,
|
||||
0.57593524,
|
||||
0.47035995,
|
||||
0.53410417,
|
||||
0.51436996,
|
||||
0.47313565,
|
||||
],
|
||||
expected_max_diff=1e-2,
|
||||
)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_stable_diffusion_xl_offloads(self):
|
||||
pipes = []
|
||||
sd_pipe = self.get_pipeline().to(torch_device)
|
||||
pipes.append(sd_pipe)
|
||||
|
||||
cm = ComponentsManager()
|
||||
cm.enable_auto_cpu_offload(device=torch_device)
|
||||
sd_pipe = self.get_pipeline(components_manager=cm)
|
||||
pipes.append(sd_pipe)
|
||||
|
||||
image_slices = []
|
||||
for pipe in pipes:
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
image = pipe(**inputs, output="images")
|
||||
|
||||
image_slices.append(image[0, -3:, -3:, -1].flatten())
|
||||
|
||||
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
|
||||
def test_stable_diffusion_xl_save_from_pretrained(self):
|
||||
pipes = []
|
||||
sd_pipe = self.get_pipeline().to(torch_device)
|
||||
pipes.append(sd_pipe)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
sd_pipe.save_pretrained(tmpdirname)
|
||||
sd_pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
|
||||
sd_pipe.load_default_components(torch_dtype=torch.float32)
|
||||
sd_pipe.to(torch_device)
|
||||
pipes.append(sd_pipe)
|
||||
|
||||
image_slices = []
|
||||
for pipe in pipes:
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
image = pipe(**inputs, output="images")
|
||||
|
||||
image_slices.append(image[0, -3:, -3:, -1].flatten())
|
||||
|
||||
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
|
||||
|
||||
class SDXLImg2ImgModularPipelineFastTests(
|
||||
SDXLModularTests,
|
||||
SDXLModularIPAdapterTests,
|
||||
SDXLModularControlNetTests,
|
||||
SDXLModularGuiderTests,
|
||||
ModularPipelineTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
"""Test cases for Stable Diffusion XL image-to-image modular pipeline fast tests."""
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
inputs = super().get_dummy_inputs(device, seed)
|
||||
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
|
||||
image = image / 2 + 0.5
|
||||
inputs["image"] = image
|
||||
inputs["strength"] = 0.8
|
||||
|
||||
return inputs
|
||||
|
||||
def test_stable_diffusion_xl_euler(self):
|
||||
self._test_stable_diffusion_xl_euler(
|
||||
expected_image_shape=(1, 64, 64, 3),
|
||||
expected_slice=[
|
||||
0.56943184,
|
||||
0.4702148,
|
||||
0.48048905,
|
||||
0.6235963,
|
||||
0.551138,
|
||||
0.49629188,
|
||||
0.60031277,
|
||||
0.5688907,
|
||||
0.43996853,
|
||||
],
|
||||
expected_max_diff=1e-2,
|
||||
)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
|
||||
|
||||
|
||||
class SDXLInpaintingModularPipelineFastTests(
|
||||
SDXLModularTests,
|
||||
SDXLModularIPAdapterTests,
|
||||
SDXLModularControlNetTests,
|
||||
SDXLModularGuiderTests,
|
||||
ModularPipelineTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
"""Test cases for Stable Diffusion XL inpainting modular pipeline fast tests."""
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
inputs = super().get_dummy_inputs(device, seed)
|
||||
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
|
||||
image = image.cpu().permute(0, 2, 3, 1)[0]
|
||||
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
|
||||
# create mask
|
||||
image[8:, 8:, :] = 255
|
||||
mask_image = Image.fromarray(np.uint8(image)).convert("L").resize((64, 64))
|
||||
|
||||
inputs["image"] = init_image
|
||||
inputs["mask_image"] = mask_image
|
||||
inputs["strength"] = 1.0
|
||||
|
||||
return inputs
|
||||
|
||||
def test_stable_diffusion_xl_euler(self):
|
||||
self._test_stable_diffusion_xl_euler(
|
||||
expected_image_shape=(1, 64, 64, 3),
|
||||
expected_slice=[
|
||||
0.40872607,
|
||||
0.38842705,
|
||||
0.34893104,
|
||||
0.47837183,
|
||||
0.43792963,
|
||||
0.5332134,
|
||||
0.3716843,
|
||||
0.47274873,
|
||||
0.45000193,
|
||||
],
|
||||
expected_max_diff=1e-2,
|
||||
)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
|
||||
330
tests/modular_pipelines/test_modular_pipelines_common.py
Normal file
330
tests/modular_pipelines/test_modular_pipelines_common.py
Normal file
@@ -0,0 +1,330 @@
|
||||
import gc
|
||||
import unittest
|
||||
from typing import Callable, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import diffusers
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.dummy_pt_objects import ModularPipeline, ModularPipelineBlocks
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_accelerator,
|
||||
require_torch,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
def to_np(tensor):
|
||||
if isinstance(tensor, torch.Tensor):
|
||||
tensor = tensor.detach().cpu().numpy()
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
@require_torch
|
||||
class ModularPipelineTesterMixin:
|
||||
"""
|
||||
This mixin is designed to be used with unittest.TestCase classes.
|
||||
It provides a set of common tests for each modular pipeline,
|
||||
including:
|
||||
- test_pipeline_call_signature: check if the pipeline's __call__ method has all required parameters
|
||||
- test_inference_batch_consistent: check if the pipeline's __call__ method can handle batch inputs
|
||||
- test_inference_batch_single_identical: check if the pipeline's __call__ method can handle single input
|
||||
- test_float16_inference: check if the pipeline's __call__ method can handle float16 inputs
|
||||
- test_to_device: check if the pipeline's __call__ method can handle different devices
|
||||
"""
|
||||
|
||||
# Canonical parameters that are passed to `__call__` regardless
|
||||
# of the type of pipeline. They are always optional and have common
|
||||
# sense default values.
|
||||
optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"num_images_per_prompt",
|
||||
"latents",
|
||||
"output_type",
|
||||
]
|
||||
)
|
||||
# this is modular specific: generator needs to be a intermediate input because it's mutable
|
||||
intermediate_params = frozenset(
|
||||
[
|
||||
"generator",
|
||||
]
|
||||
)
|
||||
|
||||
def get_generator(self, seed):
|
||||
device = torch_device if torch_device != "mps" else "cpu"
|
||||
generator = torch.Generator(device).manual_seed(seed)
|
||||
return generator
|
||||
|
||||
@property
|
||||
def pipeline_class(self) -> Union[Callable, ModularPipeline]:
|
||||
raise NotImplementedError(
|
||||
"You need to set the attribute `pipeline_class = ClassNameOfPipeline` in the child test class. "
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
@property
|
||||
def repo(self) -> str:
|
||||
raise NotImplementedError(
|
||||
"You need to set the attribute `repo` in the child test class. See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
@property
|
||||
def pipeline_blocks_class(self) -> Union[Callable, ModularPipelineBlocks]:
|
||||
raise NotImplementedError(
|
||||
"You need to set the attribute `pipeline_blocks_class = ClassNameOfPipelineBlocks` in the child test class. "
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
def get_pipeline(self):
|
||||
raise NotImplementedError(
|
||||
"You need to implement `get_pipeline(self)` in the child test class. "
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
raise NotImplementedError(
|
||||
"You need to implement `get_dummy_inputs(self, device, seed)` in the child test class. "
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
@property
|
||||
def params(self) -> frozenset:
|
||||
raise NotImplementedError(
|
||||
"You need to set the attribute `params` in the child test class. "
|
||||
"`params` are checked for if all values are present in `__call__`'s signature."
|
||||
" You can set `params` using one of the common set of parameters defined in `pipeline_params.py`"
|
||||
" e.g., `TEXT_TO_IMAGE_PARAMS` defines the common parameters used in text to "
|
||||
"image pipelines, including prompts and prompt embedding overrides."
|
||||
"If your pipeline's set of arguments has minor changes from one of the common sets of arguments, "
|
||||
"do not make modifications to the existing common sets of arguments. I.e. a text to image pipeline "
|
||||
"with non-configurable height and width arguments should set the attribute as "
|
||||
"`params = TEXT_TO_IMAGE_PARAMS - {'height', 'width'}`. "
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
@property
|
||||
def batch_params(self) -> frozenset:
|
||||
raise NotImplementedError(
|
||||
"You need to set the attribute `batch_params` in the child test class. "
|
||||
"`batch_params` are the parameters required to be batched when passed to the pipeline's "
|
||||
"`__call__` method. `pipeline_params.py` provides some common sets of parameters such as "
|
||||
"`TEXT_TO_IMAGE_BATCH_PARAMS`, `IMAGE_VARIATION_BATCH_PARAMS`, etc... If your pipeline's "
|
||||
"set of batch arguments has minor changes from one of the common sets of batch arguments, "
|
||||
"do not make modifications to the existing common sets of batch arguments. I.e. a text to "
|
||||
"image pipeline `negative_prompt` is not batched should set the attribute as "
|
||||
"`batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {'negative_prompt'}`. "
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
torch.compiler.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test in case of CUDA runtime errors
|
||||
super().tearDown()
|
||||
torch.compiler.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_pipeline_call_signature(self):
|
||||
pipe = self.get_pipeline()
|
||||
input_parameters = pipe.blocks.input_names
|
||||
intermediate_parameters = pipe.blocks.intermediate_input_names
|
||||
optional_parameters = pipe.default_call_parameters
|
||||
|
||||
def _check_for_parameters(parameters, expected_parameters, param_type):
|
||||
remaining_parameters = {param for param in parameters if param not in expected_parameters}
|
||||
assert (
|
||||
len(remaining_parameters) == 0
|
||||
), f"Required {param_type} parameters not present: {remaining_parameters}"
|
||||
|
||||
_check_for_parameters(self.params, input_parameters, "input")
|
||||
_check_for_parameters(self.intermediate_params, intermediate_parameters, "intermediate")
|
||||
_check_for_parameters(self.optional_params, optional_parameters, "optional")
|
||||
|
||||
def test_inference_batch_consistent(self, batch_sizes=[2], batch_generator=True):
|
||||
pipe = self.get_pipeline()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["generator"] = self.get_generator(0)
|
||||
|
||||
logger = logging.get_logger(pipe.__module__)
|
||||
logger.setLevel(level=diffusers.logging.FATAL)
|
||||
|
||||
# prepare batched inputs
|
||||
batched_inputs = []
|
||||
for batch_size in batch_sizes:
|
||||
batched_input = {}
|
||||
batched_input.update(inputs)
|
||||
|
||||
for name in self.batch_params:
|
||||
if name not in inputs:
|
||||
continue
|
||||
|
||||
value = inputs[name]
|
||||
batched_input[name] = batch_size * [value]
|
||||
|
||||
if batch_generator and "generator" in inputs:
|
||||
batched_input["generator"] = [self.get_generator(i) for i in range(batch_size)]
|
||||
|
||||
if "batch_size" in inputs:
|
||||
batched_input["batch_size"] = batch_size
|
||||
|
||||
batched_inputs.append(batched_input)
|
||||
|
||||
logger.setLevel(level=diffusers.logging.WARNING)
|
||||
for batch_size, batched_input in zip(batch_sizes, batched_inputs):
|
||||
output = pipe(**batched_input, output="images")
|
||||
assert len(output) == batch_size, "Output is different from expected batch size"
|
||||
|
||||
def test_inference_batch_single_identical(
|
||||
self,
|
||||
batch_size=2,
|
||||
expected_max_diff=1e-4,
|
||||
):
|
||||
pipe = self.get_pipeline()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
# Reset generator in case it is has been used in self.get_dummy_inputs
|
||||
inputs["generator"] = self.get_generator(0)
|
||||
|
||||
logger = logging.get_logger(pipe.__module__)
|
||||
logger.setLevel(level=diffusers.logging.FATAL)
|
||||
|
||||
# batchify inputs
|
||||
batched_inputs = {}
|
||||
batched_inputs.update(inputs)
|
||||
|
||||
for name in self.batch_params:
|
||||
if name not in inputs:
|
||||
continue
|
||||
|
||||
value = inputs[name]
|
||||
batched_inputs[name] = batch_size * [value]
|
||||
|
||||
if "generator" in inputs:
|
||||
batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
|
||||
|
||||
if "batch_size" in inputs:
|
||||
batched_inputs["batch_size"] = batch_size
|
||||
|
||||
output = pipe(**inputs, output="images")
|
||||
output_batch = pipe(**batched_inputs, output="images")
|
||||
|
||||
assert output_batch.shape[0] == batch_size
|
||||
|
||||
max_diff = np.abs(to_np(output_batch[0]) - to_np(output[0])).max()
|
||||
assert max_diff < expected_max_diff, "Batch inference results different from single inference results"
|
||||
|
||||
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
|
||||
@require_accelerator
|
||||
def test_float16_inference(self, expected_max_diff=5e-2):
|
||||
pipe = self.get_pipeline()
|
||||
pipe.to(torch_device, torch.float32)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
pipe_fp16 = self.get_pipeline()
|
||||
pipe_fp16.to(torch_device, torch.float16)
|
||||
pipe_fp16.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
# Reset generator in case it is used inside dummy inputs
|
||||
if "generator" in inputs:
|
||||
inputs["generator"] = self.get_generator(0)
|
||||
output = pipe(**inputs, output="images")
|
||||
|
||||
fp16_inputs = self.get_dummy_inputs(torch_device)
|
||||
# Reset generator in case it is used inside dummy inputs
|
||||
if "generator" in fp16_inputs:
|
||||
fp16_inputs["generator"] = self.get_generator(0)
|
||||
output_fp16 = pipe_fp16(**fp16_inputs, output="images")
|
||||
|
||||
if isinstance(output, torch.Tensor):
|
||||
output = output.cpu()
|
||||
output_fp16 = output_fp16.cpu()
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
|
||||
assert max_diff < expected_max_diff, "FP16 inference is different from FP32 inference"
|
||||
|
||||
@require_accelerator
|
||||
def test_to_device(self):
|
||||
pipe = self.get_pipeline()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
pipe.to("cpu")
|
||||
model_devices = [
|
||||
component.device.type for component in pipe.components.values() if hasattr(component, "device")
|
||||
]
|
||||
assert all(device == "cpu" for device in model_devices), "All pipeline components are not on CPU"
|
||||
|
||||
pipe.to(torch_device)
|
||||
model_devices = [
|
||||
component.device.type for component in pipe.components.values() if hasattr(component, "device")
|
||||
]
|
||||
assert all(
|
||||
device == torch_device for device in model_devices
|
||||
), "All pipeline components are not on accelerator device"
|
||||
|
||||
def test_inference_is_not_nan_cpu(self):
|
||||
pipe = self.get_pipeline()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.to("cpu")
|
||||
|
||||
output = pipe(**self.get_dummy_inputs("cpu"), output="images")
|
||||
assert np.isnan(to_np(output)).sum() == 0, "CPU Inference returns NaN"
|
||||
|
||||
@require_accelerator
|
||||
def test_inference_is_not_nan(self):
|
||||
pipe = self.get_pipeline()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.to(torch_device)
|
||||
|
||||
output = pipe(**self.get_dummy_inputs(torch_device), output="images")
|
||||
assert np.isnan(to_np(output)).sum() == 0, "Accelerator Inference returns NaN"
|
||||
|
||||
def test_num_images_per_prompt(self):
|
||||
pipe = self.get_pipeline()
|
||||
|
||||
if "num_images_per_prompt" not in pipe.blocks.input_names:
|
||||
return
|
||||
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
batch_sizes = [1, 2]
|
||||
num_images_per_prompts = [1, 2]
|
||||
|
||||
for batch_size in batch_sizes:
|
||||
for num_images_per_prompt in num_images_per_prompts:
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
for key in inputs.keys():
|
||||
if key in self.batch_params:
|
||||
inputs[key] = batch_size * [inputs[key]]
|
||||
|
||||
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output="images")
|
||||
|
||||
assert images.shape[0] == batch_size * num_images_per_prompt
|
||||
|
||||
@require_accelerator
|
||||
def test_components_auto_cpu_offload(self):
|
||||
base_pipe = self.get_pipeline().to(torch_device)
|
||||
for component in base_pipe.components:
|
||||
assert component.device == torch_device
|
||||
|
||||
cm = ComponentsManager()
|
||||
cm.enable_auto_cpu_offload(device=torch_device)
|
||||
offload_pipe = self.get_pipeline(components_manager=cm)
|
||||
@@ -20,12 +20,6 @@ TEXT_TO_IMAGE_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
|
||||
|
||||
TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
|
||||
|
||||
IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
|
||||
|
||||
IMAGE_VARIATION_PARAMS = frozenset(
|
||||
[
|
||||
"image",
|
||||
@@ -35,8 +29,6 @@ IMAGE_VARIATION_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])
|
||||
|
||||
TEXT_GUIDED_IMAGE_VARIATION_PARAMS = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
@@ -50,8 +42,6 @@ TEXT_GUIDED_IMAGE_VARIATION_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])
|
||||
|
||||
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
|
||||
[
|
||||
# Text guided image variation with an image mask
|
||||
@@ -67,8 +57,6 @@ TEXT_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])
|
||||
|
||||
IMAGE_INPAINTING_PARAMS = frozenset(
|
||||
[
|
||||
# image variation with an image mask
|
||||
@@ -80,8 +68,6 @@ IMAGE_INPAINTING_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])
|
||||
|
||||
IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
|
||||
[
|
||||
"example_image",
|
||||
@@ -93,20 +79,12 @@ IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])
|
||||
UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"])
|
||||
|
||||
CLASS_CONDITIONED_IMAGE_GENERATION_PARAMS = frozenset(["class_labels"])
|
||||
|
||||
CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS = frozenset(["class_labels"])
|
||||
|
||||
UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"])
|
||||
|
||||
UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])
|
||||
|
||||
UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
|
||||
|
||||
UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])
|
||||
|
||||
TEXT_TO_AUDIO_PARAMS = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
@@ -119,11 +97,38 @@ TEXT_TO_AUDIO_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
|
||||
TOKENS_TO_AUDIO_GENERATION_PARAMS = frozenset(["input_tokens"])
|
||||
|
||||
UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
|
||||
|
||||
# image params
|
||||
TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
|
||||
|
||||
IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
|
||||
|
||||
|
||||
# batch params
|
||||
TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
|
||||
|
||||
IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])
|
||||
|
||||
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])
|
||||
|
||||
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])
|
||||
|
||||
IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])
|
||||
|
||||
IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])
|
||||
|
||||
UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])
|
||||
|
||||
UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])
|
||||
|
||||
TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
|
||||
|
||||
TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS = frozenset(["input_tokens"])
|
||||
|
||||
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])
|
||||
|
||||
VIDEO_TO_VIDEO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt", "video"])
|
||||
|
||||
# callback params
|
||||
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])
|
||||
|
||||
Reference in New Issue
Block a user