Compare commits

...

15 Commits

Author SHA1 Message Date
DN6
3aabef5de4 update 2025-07-24 22:18:15 +05:30
DN6
39be374591 update 2025-07-21 09:03:32 +05:30
DN6
54e17f3084 update 2025-07-21 08:56:47 +05:30
DN6
80702d222d update 2025-07-17 13:05:43 +05:30
DN6
625cc8ede8 update 2025-07-17 07:14:35 +05:30
yiyixuxu
a2a9e4eadb Merge branch 'modular-test' of github.com:huggingface/diffusers into modular-test 2025-07-16 12:03:09 +02:00
yiyixuxu
0998bd75ad up 2025-07-16 12:02:58 +02:00
yiyixuxu
5f560d05a2 up 2025-07-16 11:58:23 +02:00
yiyixuxu
4b7a9e9fa9 prepare_latents_inpaint always return noise and image_latents 2025-07-16 11:57:29 +02:00
yiyixuxu
d8fa2de36f remove more unused func 2025-07-16 04:29:27 +02:00
YiYi Xu
4df2739a5e Merge branch 'main' into modular-test 2025-07-15 16:27:33 -10:00
yiyixuxu
d92855ddf0 style 2025-07-16 04:26:27 +02:00
yiyixuxu
0a5c90ed47 add names property to pipeline blocks 2025-07-16 04:25:26 +02:00
yiyixuxu
0fa58127f8 make style 2025-07-15 03:05:36 +02:00
yiyixuxu
b165cf3742 rearrage the params to groups: default params /image params /batch params / callback params 2025-07-15 03:03:29 +02:00
8 changed files with 1037 additions and 39 deletions

141
.github/workflows/pr_modular_tests.yml vendored Normal file
View File

@@ -0,0 +1,141 @@
name: Fast PR tests for Modular
on:
pull_request:
branches: [main]
paths:
- "src/diffusers/modular_pipelines/**.py"
- "src/diffusers/models/modeling_utils.py"
- "src/diffusers/models/model_loading_utils.py"
- "src/diffusers/pipelines/pipeline_utils.py"
- "src/diffusers/pipeline_loading_utils.py"
- "src/diffusers/loaders/lora_base.py"
- "src/diffusers/loaders/lora_pipeline.py"
- "src/diffusers/loaders/peft.py"
- "tests/modular_pipelines/**.py"
- ".github/**.yml"
- "utils/**.py"
- "setup.py"
push:
branches:
- ci-*
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
env:
DIFFUSERS_IS_CI: yes
HF_HUB_ENABLE_HF_TRANSFER: 1
OMP_NUM_THREADS: 4
MKL_NUM_THREADS: 4
PYTEST_TIMEOUT: 60
jobs:
check_code_quality:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[quality]
- name: Check quality
run: make quality
- name: Check if failure
if: ${{ failure() }}
run: |
echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY
check_repository_consistency:
needs: check_code_quality
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[quality]
- name: Check repo consistency
run: |
python utils/check_copies.py
python utils/check_dummies.py
python utils/check_support_list.py
make deps_table_check_updated
- name: Check if failure
if: ${{ failure() }}
run: |
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
run_fast_tests:
needs: [check_code_quality, check_repository_consistency]
strategy:
fail-fast: false
matrix:
config:
- name: Fast PyTorch Modular Pipeline CPU tests
framework: pytorch_pipelines
runner: aws-highmemory-32-plus
image: diffusers/diffusers-pytorch-cpu
report: torch_cpu_modular_pipelines
name: ${{ matrix.config.name }}
runs-on:
group: ${{ matrix.config.runner }}
container:
image: ${{ matrix.config.image }}
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
defaults:
run:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
- name: Environment
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python utils/print_env.py
- name: Run fast PyTorch Pipeline CPU tests
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m pytest -n 8 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/modular_pipelines
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
path: reports

View File

@@ -479,6 +479,22 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
return list(combined_dict.values())
@property
def input_names(self) -> List[str]:
return [input_param.name for input_param in self.inputs]
@property
def intermediate_input_names(self) -> List[str]:
return [input_param.name for input_param in self.intermediate_inputs]
@property
def intermediate_output_names(self) -> List[str]:
return [output_param.name for output_param in self.intermediate_outputs]
@property
def output_names(self) -> List[str]:
return [output_param.name for output_param in self.outputs]
class PipelineBlock(ModularPipelineBlocks):
"""
@@ -2825,3 +2841,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
type_hint=type_hint,
**spec_dict,
)
def set_progress_bar_config(self, **kwargs):
for sub_block_name, sub_block in self.blocks.sub_blocks.items():
if hasattr(sub_block, "set_progress_bar_config"):
sub_block.set_progress_bar_config(**kwargs)

View File

@@ -744,8 +744,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
timestep=None,
is_strength_max=True,
add_noise=True,
return_noise=False,
return_image_latents=False,
):
shape = (
batch_size,
@@ -768,7 +766,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
if image.shape[1] == 4:
image_latents = image.to(device=device, dtype=dtype)
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
elif return_image_latents or (latents is None and not is_strength_max):
elif latents is None and not is_strength_max:
image = image.to(device=device, dtype=dtype)
image_latents = self._encode_vae_image(components, image=image, generator=generator)
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
@@ -786,13 +784,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = image_latents.to(device)
outputs = (latents,)
if return_noise:
outputs += (noise,)
if return_image_latents:
outputs += (image_latents,)
outputs = (latents, noise, image_latents)
return outputs
@@ -864,7 +856,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor
block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor
block_state.latents, block_state.noise = self.prepare_latents_inpaint(
block_state.latents, block_state.noise, block_state.image_latents = self.prepare_latents_inpaint(
components,
block_state.batch_size * block_state.num_images_per_prompt,
components.num_channels_latents,
@@ -878,8 +870,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
timestep=block_state.latent_timestep,
is_strength_max=block_state.is_strength_max,
add_noise=block_state.add_noise,
return_noise=True,
return_image_latents=False,
)
# 7. Prepare mask latent variables

View File

View File

@@ -0,0 +1,511 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import tempfile
import unittest
from typing import Any, Dict
import numpy as np
import torch
from PIL import Image
from diffusers import (
ClassifierFreeGuidance,
ComponentsManager,
ModularPipeline,
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
)
from diffusers.loaders import ModularIPAdapterMixin
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
require_torch_accelerator,
torch_device,
)
from ...models.unets.test_models_unet_2d_condition import (
create_ip_adapter_state_dict,
)
from ..test_modular_pipelines_common import (
ModularPipelineTesterMixin,
)
enable_full_determinism()
class SDXLModularTests:
"""
This mixin defines method to create pipeline, base input and base test across all SDXL modular tests.
"""
pipeline_class = StableDiffusionXLModularPipeline
pipeline_blocks_class = StableDiffusionXLAutoBlocks
repo = "hf-internal-testing/tiny-sdxl-modular"
params = frozenset(
[
"prompt",
"height",
"width",
"negative_prompt",
"cross_attention_kwargs",
"image",
"mask_image",
]
)
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
pipeline.load_default_components(torch_dtype=torch_dtype)
return pipeline
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"output_type": "np",
}
return inputs
def _test_stable_diffusion_xl_euler(self, expected_image_shape, expected_slice, expected_max_diff=1e-2):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
sd_pipe = self.get_pipeline()
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs, output="images")
image_slice = image[0, -3:, -3:, -1]
assert image.shape == expected_image_shape
assert np.abs(image_slice.flatten() - expected_slice).max() < expected_max_diff, (
"Image Slice does not match expected slice"
)
class SDXLModularIPAdapterTests:
"""
This mixin is designed to test IP Adapter.
"""
def test_pipeline_inputs_and_blocks(self):
blocks = self.pipeline_blocks_class()
parameters = blocks.input_names
assert issubclass(self.pipeline_class, ModularIPAdapterMixin)
assert "ip_adapter_image" in parameters, (
"`ip_adapter_image` argument must be supported by the `__call__` method"
)
assert "ip_adapter" in blocks.sub_blocks, "pipeline must contain an IPAdapter block"
_ = blocks.sub_blocks.pop("ip_adapter")
parameters = blocks.input_names
intermediate_parameters = blocks.intermediate_input_names
assert "ip_adapter_image" not in parameters, (
"`ip_adapter_image` argument must be removed from the `__call__` method"
)
assert "ip_adapter_image_embeds" not in intermediate_parameters, (
"`ip_adapter_image_embeds` argument must be supported by the `__call__` method"
)
def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
return torch.randn((1, 1, cross_attention_dim), device=torch_device)
def _get_dummy_faceid_image_embeds(self, cross_attention_dim: int = 32):
return torch.randn((1, 1, 1, cross_attention_dim), device=torch_device)
def _get_dummy_masks(self, input_size: int = 64):
_masks = torch.zeros((1, 1, input_size, input_size), device=torch_device)
_masks[0, :, :, : int(input_size / 2)] = 1
return _masks
def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
blocks = self.pipeline_blocks_class()
_ = blocks.sub_blocks.pop("ip_adapter")
parameters = blocks.input_names
if "image" in parameters and "strength" in parameters:
inputs["num_inference_steps"] = 4
inputs["output_type"] = "np"
return inputs
def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
r"""Tests for IP-Adapter.
The following scenarios are tested:
- Single IP-Adapter with scale=0 should produce same output as no IP-Adapter.
- Multi IP-Adapter with scale=0 should produce same output as no IP-Adapter.
- Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
- Multi IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
"""
# Raising the tolerance for this test when it's run on a CPU because we
# compare against static slices and that can be shaky (with a VVVV low probability).
expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
blocks = self.pipeline_blocks_class()
_ = blocks.sub_blocks.pop("ip_adapter")
pipe = blocks.init_pipeline(self.repo)
pipe.load_default_components(torch_dtype=torch.float32)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
cross_attention_dim = pipe.unet.config.get("cross_attention_dim")
# forward pass without ip adapter
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
if expected_pipe_slice is None:
output_without_adapter = pipe(**inputs, output="images")
else:
output_without_adapter = expected_pipe_slice
# 1. Single IP-Adapter test cases
adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
pipe.unet._load_ip_adapter_weights(adapter_state_dict)
# forward pass with single ip adapter, but scale=0 which should have no effect
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
pipe.set_ip_adapter_scale(0.0)
output_without_adapter_scale = pipe(**inputs, output="images")
if expected_pipe_slice is not None:
output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
# forward pass with single ip adapter, but with scale of adapter weights
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
pipe.set_ip_adapter_scale(42.0)
output_with_adapter_scale = pipe(**inputs, output="images")
if expected_pipe_slice is not None:
output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
assert max_diff_without_adapter_scale < expected_max_diff, (
"Output without ip-adapter must be same as normal inference"
)
assert max_diff_with_adapter_scale > 1e-2, "Output with ip-adapter must be different from normal inference"
# 2. Multi IP-Adapter test cases
adapter_state_dict_1 = create_ip_adapter_state_dict(pipe.unet)
adapter_state_dict_2 = create_ip_adapter_state_dict(pipe.unet)
pipe.unet._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2])
# forward pass with multi ip adapter, but scale=0 which should have no effect
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
pipe.set_ip_adapter_scale([0.0, 0.0])
output_without_multi_adapter_scale = pipe(**inputs, output="images")
if expected_pipe_slice is not None:
output_without_multi_adapter_scale = output_without_multi_adapter_scale[0, -3:, -3:, -1].flatten()
# forward pass with multi ip adapter, but with scale of adapter weights
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
pipe.set_ip_adapter_scale([42.0, 42.0])
output_with_multi_adapter_scale = pipe(**inputs, output="images")
if expected_pipe_slice is not None:
output_with_multi_adapter_scale = output_with_multi_adapter_scale[0, -3:, -3:, -1].flatten()
max_diff_without_multi_adapter_scale = np.abs(
output_without_multi_adapter_scale - output_without_adapter
).max()
max_diff_with_multi_adapter_scale = np.abs(output_with_multi_adapter_scale - output_without_adapter).max()
assert max_diff_without_multi_adapter_scale < expected_max_diff, (
"Output without multi-ip-adapter must be same as normal inference"
)
assert max_diff_with_multi_adapter_scale > 1e-2, (
"Output with multi-ip-adapter scale must be different from normal inference"
)
class SDXLModularControlNetTests:
"""
This mixin is designed to test ControlNet.
"""
def test_pipeline_inputs(self):
blocks = self.pipeline_blocks_class()
parameters = blocks.input_names
assert "control_image" in parameters, "`control_image` argument must be supported by the `__call__` method"
assert "controlnet_conditioning_scale" in parameters, (
"`controlnet_conditioning_scale` argument must be supported by the `__call__` method"
)
def _modify_inputs_for_controlnet_test(self, inputs: Dict[str, Any]):
controlnet_embedder_scale_factor = 2
image = torch.randn(
(1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
device=torch_device,
)
inputs["control_image"] = image
return inputs
def test_controlnet(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
r"""Tests for ControlNet.
The following scenarios are tested:
- Single ControlNet with scale=0 should produce same output as no ControlNet.
- Single ControlNet with scale!=0 should produce different output compared to no ControlNet.
"""
# Raising the tolerance for this test when it's run on a CPU because we
# compare against static slices and that can be shaky (with a VVVV low probability).
expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
pipe = self.get_pipeline()
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
# forward pass without controlnet
inputs = self.get_dummy_inputs(torch_device)
output_without_controlnet = pipe(**inputs, output="images")
output_without_controlnet = output_without_controlnet[0, -3:, -3:, -1].flatten()
# forward pass with single controlnet, but scale=0 which should have no effect
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
inputs["controlnet_conditioning_scale"] = 0.0
output_without_controlnet_scale = pipe(**inputs, output="images")
output_without_controlnet_scale = output_without_controlnet_scale[0, -3:, -3:, -1].flatten()
# forward pass with single controlnet, but with scale of adapter weights
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
inputs["controlnet_conditioning_scale"] = 42.0
output_with_controlnet_scale = pipe(**inputs, output="images")
output_with_controlnet_scale = output_with_controlnet_scale[0, -3:, -3:, -1].flatten()
max_diff_without_controlnet_scale = np.abs(output_without_controlnet_scale - output_without_controlnet).max()
max_diff_with_controlnet_scale = np.abs(output_with_controlnet_scale - output_without_controlnet).max()
assert max_diff_without_controlnet_scale < expected_max_diff, (
"Output without controlnet must be same as normal inference"
)
assert max_diff_with_controlnet_scale > 1e-2, "Output with controlnet must be different from normal inference"
def test_controlnet_cfg(self):
pipe = self.get_pipeline()
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
# forward pass with CFG not applied
guider = ClassifierFreeGuidance(guidance_scale=1.0)
pipe.update_components(guider=guider)
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
out_no_cfg = pipe(**inputs, output="images")
# forward pass with CFG applied
guider = ClassifierFreeGuidance(guidance_scale=7.5)
pipe.update_components(guider=guider)
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
out_cfg = pipe(**inputs, output="images")
assert out_cfg.shape == out_no_cfg.shape
max_diff = np.abs(out_cfg - out_no_cfg).max()
assert max_diff > 1e-2, "Output with CFG must be different from normal inference"
class SDXLModularGuiderTests:
def test_guider_cfg(self):
pipe = self.get_pipeline()
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
# forward pass with CFG not applied
guider = ClassifierFreeGuidance(guidance_scale=1.0)
pipe.update_components(guider=guider)
inputs = self.get_dummy_inputs(torch_device)
out_no_cfg = pipe(**inputs, output="images")
# forward pass with CFG applied
guider = ClassifierFreeGuidance(guidance_scale=7.5)
pipe.update_components(guider=guider)
inputs = self.get_dummy_inputs(torch_device)
out_cfg = pipe(**inputs, output="images")
assert out_cfg.shape == out_no_cfg.shape
max_diff = np.abs(out_cfg - out_no_cfg).max()
assert max_diff > 1e-2, "Output with CFG must be different from normal inference"
class SDXLModularPipelineFastTests(
SDXLModularTests,
SDXLModularIPAdapterTests,
SDXLModularControlNetTests,
SDXLModularGuiderTests,
ModularPipelineTesterMixin,
unittest.TestCase,
):
"""Test cases for Stable Diffusion XL modular pipeline fast tests."""
def test_stable_diffusion_xl_euler(self):
self._test_stable_diffusion_xl_euler(
expected_image_shape=(1, 64, 64, 3),
expected_slice=[
0.5966781,
0.62939394,
0.48465094,
0.51573336,
0.57593524,
0.47035995,
0.53410417,
0.51436996,
0.47313565,
],
expected_max_diff=1e-2,
)
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
@require_torch_accelerator
def test_stable_diffusion_xl_offloads(self):
pipes = []
sd_pipe = self.get_pipeline().to(torch_device)
pipes.append(sd_pipe)
cm = ComponentsManager()
cm.enable_auto_cpu_offload(device=torch_device)
sd_pipe = self.get_pipeline(components_manager=cm)
pipes.append(sd_pipe)
image_slices = []
for pipe in pipes:
inputs = self.get_dummy_inputs(torch_device)
image = pipe(**inputs, output="images")
image_slices.append(image[0, -3:, -3:, -1].flatten())
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
def test_stable_diffusion_xl_save_from_pretrained(self):
pipes = []
sd_pipe = self.get_pipeline().to(torch_device)
pipes.append(sd_pipe)
with tempfile.TemporaryDirectory() as tmpdirname:
sd_pipe.save_pretrained(tmpdirname)
sd_pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
sd_pipe.load_default_components(torch_dtype=torch.float32)
sd_pipe.to(torch_device)
pipes.append(sd_pipe)
image_slices = []
for pipe in pipes:
inputs = self.get_dummy_inputs(torch_device)
image = pipe(**inputs, output="images")
image_slices.append(image[0, -3:, -3:, -1].flatten())
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
class SDXLImg2ImgModularPipelineFastTests(
SDXLModularTests,
SDXLModularIPAdapterTests,
SDXLModularControlNetTests,
SDXLModularGuiderTests,
ModularPipelineTesterMixin,
unittest.TestCase,
):
"""Test cases for Stable Diffusion XL image-to-image modular pipeline fast tests."""
def get_dummy_inputs(self, device, seed=0):
inputs = super().get_dummy_inputs(device, seed)
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
image = image / 2 + 0.5
inputs["image"] = image
inputs["strength"] = 0.8
return inputs
def test_stable_diffusion_xl_euler(self):
self._test_stable_diffusion_xl_euler(
expected_image_shape=(1, 64, 64, 3),
expected_slice=[
0.56943184,
0.4702148,
0.48048905,
0.6235963,
0.551138,
0.49629188,
0.60031277,
0.5688907,
0.43996853,
],
expected_max_diff=1e-2,
)
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
class SDXLInpaintingModularPipelineFastTests(
SDXLModularTests,
SDXLModularIPAdapterTests,
SDXLModularControlNetTests,
SDXLModularGuiderTests,
ModularPipelineTesterMixin,
unittest.TestCase,
):
"""Test cases for Stable Diffusion XL inpainting modular pipeline fast tests."""
def get_dummy_inputs(self, device, seed=0):
inputs = super().get_dummy_inputs(device, seed)
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
image = image.cpu().permute(0, 2, 3, 1)[0]
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
# create mask
image[8:, 8:, :] = 255
mask_image = Image.fromarray(np.uint8(image)).convert("L").resize((64, 64))
inputs["image"] = init_image
inputs["mask_image"] = mask_image
inputs["strength"] = 1.0
return inputs
def test_stable_diffusion_xl_euler(self):
self._test_stable_diffusion_xl_euler(
expected_image_shape=(1, 64, 64, 3),
expected_slice=[
0.40872607,
0.38842705,
0.34893104,
0.47837183,
0.43792963,
0.5332134,
0.3716843,
0.47274873,
0.45000193,
],
expected_max_diff=1e-2,
)
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)

View File

@@ -0,0 +1,330 @@
import gc
import unittest
from typing import Callable, Union
import numpy as np
import torch
import diffusers
from diffusers.utils import logging
from diffusers.utils.dummy_pt_objects import ModularPipeline, ModularPipelineBlocks
from diffusers.utils.testing_utils import (
backend_empty_cache,
numpy_cosine_similarity_distance,
require_accelerator,
require_torch,
torch_device,
)
def to_np(tensor):
if isinstance(tensor, torch.Tensor):
tensor = tensor.detach().cpu().numpy()
return tensor
@require_torch
class ModularPipelineTesterMixin:
"""
This mixin is designed to be used with unittest.TestCase classes.
It provides a set of common tests for each modular pipeline,
including:
- test_pipeline_call_signature: check if the pipeline's __call__ method has all required parameters
- test_inference_batch_consistent: check if the pipeline's __call__ method can handle batch inputs
- test_inference_batch_single_identical: check if the pipeline's __call__ method can handle single input
- test_float16_inference: check if the pipeline's __call__ method can handle float16 inputs
- test_to_device: check if the pipeline's __call__ method can handle different devices
"""
# Canonical parameters that are passed to `__call__` regardless
# of the type of pipeline. They are always optional and have common
# sense default values.
optional_params = frozenset(
[
"num_inference_steps",
"num_images_per_prompt",
"latents",
"output_type",
]
)
# this is modular specific: generator needs to be a intermediate input because it's mutable
intermediate_params = frozenset(
[
"generator",
]
)
def get_generator(self, seed):
device = torch_device if torch_device != "mps" else "cpu"
generator = torch.Generator(device).manual_seed(seed)
return generator
@property
def pipeline_class(self) -> Union[Callable, ModularPipeline]:
raise NotImplementedError(
"You need to set the attribute `pipeline_class = ClassNameOfPipeline` in the child test class. "
"See existing pipeline tests for reference."
)
@property
def repo(self) -> str:
raise NotImplementedError(
"You need to set the attribute `repo` in the child test class. See existing pipeline tests for reference."
)
@property
def pipeline_blocks_class(self) -> Union[Callable, ModularPipelineBlocks]:
raise NotImplementedError(
"You need to set the attribute `pipeline_blocks_class = ClassNameOfPipelineBlocks` in the child test class. "
"See existing pipeline tests for reference."
)
def get_pipeline(self):
raise NotImplementedError(
"You need to implement `get_pipeline(self)` in the child test class. "
"See existing pipeline tests for reference."
)
def get_dummy_inputs(self, device, seed=0):
raise NotImplementedError(
"You need to implement `get_dummy_inputs(self, device, seed)` in the child test class. "
"See existing pipeline tests for reference."
)
@property
def params(self) -> frozenset:
raise NotImplementedError(
"You need to set the attribute `params` in the child test class. "
"`params` are checked for if all values are present in `__call__`'s signature."
" You can set `params` using one of the common set of parameters defined in `pipeline_params.py`"
" e.g., `TEXT_TO_IMAGE_PARAMS` defines the common parameters used in text to "
"image pipelines, including prompts and prompt embedding overrides."
"If your pipeline's set of arguments has minor changes from one of the common sets of arguments, "
"do not make modifications to the existing common sets of arguments. I.e. a text to image pipeline "
"with non-configurable height and width arguments should set the attribute as "
"`params = TEXT_TO_IMAGE_PARAMS - {'height', 'width'}`. "
"See existing pipeline tests for reference."
)
@property
def batch_params(self) -> frozenset:
raise NotImplementedError(
"You need to set the attribute `batch_params` in the child test class. "
"`batch_params` are the parameters required to be batched when passed to the pipeline's "
"`__call__` method. `pipeline_params.py` provides some common sets of parameters such as "
"`TEXT_TO_IMAGE_BATCH_PARAMS`, `IMAGE_VARIATION_BATCH_PARAMS`, etc... If your pipeline's "
"set of batch arguments has minor changes from one of the common sets of batch arguments, "
"do not make modifications to the existing common sets of batch arguments. I.e. a text to "
"image pipeline `negative_prompt` is not batched should set the attribute as "
"`batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {'negative_prompt'}`. "
"See existing pipeline tests for reference."
)
def setUp(self):
# clean up the VRAM before each test
super().setUp()
torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test in case of CUDA runtime errors
super().tearDown()
torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)
def test_pipeline_call_signature(self):
pipe = self.get_pipeline()
input_parameters = pipe.blocks.input_names
intermediate_parameters = pipe.blocks.intermediate_input_names
optional_parameters = pipe.default_call_parameters
def _check_for_parameters(parameters, expected_parameters, param_type):
remaining_parameters = {param for param in parameters if param not in expected_parameters}
assert (
len(remaining_parameters) == 0
), f"Required {param_type} parameters not present: {remaining_parameters}"
_check_for_parameters(self.params, input_parameters, "input")
_check_for_parameters(self.intermediate_params, intermediate_parameters, "intermediate")
_check_for_parameters(self.optional_params, optional_parameters, "optional")
def test_inference_batch_consistent(self, batch_sizes=[2], batch_generator=True):
pipe = self.get_pipeline()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
inputs["generator"] = self.get_generator(0)
logger = logging.get_logger(pipe.__module__)
logger.setLevel(level=diffusers.logging.FATAL)
# prepare batched inputs
batched_inputs = []
for batch_size in batch_sizes:
batched_input = {}
batched_input.update(inputs)
for name in self.batch_params:
if name not in inputs:
continue
value = inputs[name]
batched_input[name] = batch_size * [value]
if batch_generator and "generator" in inputs:
batched_input["generator"] = [self.get_generator(i) for i in range(batch_size)]
if "batch_size" in inputs:
batched_input["batch_size"] = batch_size
batched_inputs.append(batched_input)
logger.setLevel(level=diffusers.logging.WARNING)
for batch_size, batched_input in zip(batch_sizes, batched_inputs):
output = pipe(**batched_input, output="images")
assert len(output) == batch_size, "Output is different from expected batch size"
def test_inference_batch_single_identical(
self,
batch_size=2,
expected_max_diff=1e-4,
):
pipe = self.get_pipeline()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
# Reset generator in case it is has been used in self.get_dummy_inputs
inputs["generator"] = self.get_generator(0)
logger = logging.get_logger(pipe.__module__)
logger.setLevel(level=diffusers.logging.FATAL)
# batchify inputs
batched_inputs = {}
batched_inputs.update(inputs)
for name in self.batch_params:
if name not in inputs:
continue
value = inputs[name]
batched_inputs[name] = batch_size * [value]
if "generator" in inputs:
batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
if "batch_size" in inputs:
batched_inputs["batch_size"] = batch_size
output = pipe(**inputs, output="images")
output_batch = pipe(**batched_inputs, output="images")
assert output_batch.shape[0] == batch_size
max_diff = np.abs(to_np(output_batch[0]) - to_np(output[0])).max()
assert max_diff < expected_max_diff, "Batch inference results different from single inference results"
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
@require_accelerator
def test_float16_inference(self, expected_max_diff=5e-2):
pipe = self.get_pipeline()
pipe.to(torch_device, torch.float32)
pipe.set_progress_bar_config(disable=None)
pipe_fp16 = self.get_pipeline()
pipe_fp16.to(torch_device, torch.float16)
pipe_fp16.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
# Reset generator in case it is used inside dummy inputs
if "generator" in inputs:
inputs["generator"] = self.get_generator(0)
output = pipe(**inputs, output="images")
fp16_inputs = self.get_dummy_inputs(torch_device)
# Reset generator in case it is used inside dummy inputs
if "generator" in fp16_inputs:
fp16_inputs["generator"] = self.get_generator(0)
output_fp16 = pipe_fp16(**fp16_inputs, output="images")
if isinstance(output, torch.Tensor):
output = output.cpu()
output_fp16 = output_fp16.cpu()
max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
assert max_diff < expected_max_diff, "FP16 inference is different from FP32 inference"
@require_accelerator
def test_to_device(self):
pipe = self.get_pipeline()
pipe.set_progress_bar_config(disable=None)
pipe.to("cpu")
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
assert all(device == "cpu" for device in model_devices), "All pipeline components are not on CPU"
pipe.to(torch_device)
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
assert all(
device == torch_device for device in model_devices
), "All pipeline components are not on accelerator device"
def test_inference_is_not_nan_cpu(self):
pipe = self.get_pipeline()
pipe.set_progress_bar_config(disable=None)
pipe.to("cpu")
output = pipe(**self.get_dummy_inputs("cpu"), output="images")
assert np.isnan(to_np(output)).sum() == 0, "CPU Inference returns NaN"
@require_accelerator
def test_inference_is_not_nan(self):
pipe = self.get_pipeline()
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
output = pipe(**self.get_dummy_inputs(torch_device), output="images")
assert np.isnan(to_np(output)).sum() == 0, "Accelerator Inference returns NaN"
def test_num_images_per_prompt(self):
pipe = self.get_pipeline()
if "num_images_per_prompt" not in pipe.blocks.input_names:
return
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
batch_sizes = [1, 2]
num_images_per_prompts = [1, 2]
for batch_size in batch_sizes:
for num_images_per_prompt in num_images_per_prompts:
inputs = self.get_dummy_inputs(torch_device)
for key in inputs.keys():
if key in self.batch_params:
inputs[key] = batch_size * [inputs[key]]
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output="images")
assert images.shape[0] == batch_size * num_images_per_prompt
@require_accelerator
def test_components_auto_cpu_offload(self):
base_pipe = self.get_pipeline().to(torch_device)
for component in base_pipe.components:
assert component.device == torch_device
cm = ComponentsManager()
cm.enable_auto_cpu_offload(device=torch_device)
offload_pipe = self.get_pipeline(components_manager=cm)

View File

@@ -20,12 +20,6 @@ TEXT_TO_IMAGE_PARAMS = frozenset(
]
)
TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
IMAGE_VARIATION_PARAMS = frozenset(
[
"image",
@@ -35,8 +29,6 @@ IMAGE_VARIATION_PARAMS = frozenset(
]
)
IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])
TEXT_GUIDED_IMAGE_VARIATION_PARAMS = frozenset(
[
"prompt",
@@ -50,8 +42,6 @@ TEXT_GUIDED_IMAGE_VARIATION_PARAMS = frozenset(
]
)
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
[
# Text guided image variation with an image mask
@@ -67,8 +57,6 @@ TEXT_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
]
)
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])
IMAGE_INPAINTING_PARAMS = frozenset(
[
# image variation with an image mask
@@ -80,8 +68,6 @@ IMAGE_INPAINTING_PARAMS = frozenset(
]
)
IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])
IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
[
"example_image",
@@ -93,20 +79,12 @@ IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
]
)
IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])
UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"])
CLASS_CONDITIONED_IMAGE_GENERATION_PARAMS = frozenset(["class_labels"])
CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS = frozenset(["class_labels"])
UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"])
UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])
UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])
TEXT_TO_AUDIO_PARAMS = frozenset(
[
"prompt",
@@ -119,11 +97,38 @@ TEXT_TO_AUDIO_PARAMS = frozenset(
]
)
TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
TOKENS_TO_AUDIO_GENERATION_PARAMS = frozenset(["input_tokens"])
UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
# image params
TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
# batch params
TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])
IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])
IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])
UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])
UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])
TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS = frozenset(["input_tokens"])
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])
VIDEO_TO_VIDEO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt", "video"])
# callback params
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])