Compare commits

..

3 Commits

Author SHA1 Message Date
Álvaro Somoza
d5389bdd71 Merge branch 'main' into cnet-union-multiple-fixes 2025-07-10 14:19:37 -04:00
Álvaro Somoza
c99f9c1799 Merge branch 'main' into cnet-union-multiple-fixes 2025-07-09 14:32:53 -04:00
Álvaro Somoza
e797122f64 fixes 2025-07-08 14:15:52 -04:00
33 changed files with 339 additions and 1474 deletions

View File

@@ -1,141 +0,0 @@
name: Fast PR tests for Modular
on:
pull_request:
branches: [main]
paths:
- "src/diffusers/modular_pipelines/**.py"
- "src/diffusers/models/modeling_utils.py"
- "src/diffusers/models/model_loading_utils.py"
- "src/diffusers/pipelines/pipeline_utils.py"
- "src/diffusers/pipeline_loading_utils.py"
- "src/diffusers/loaders/lora_base.py"
- "src/diffusers/loaders/lora_pipeline.py"
- "src/diffusers/loaders/peft.py"
- "tests/modular_pipelines/**.py"
- ".github/**.yml"
- "utils/**.py"
- "setup.py"
push:
branches:
- ci-*
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
env:
DIFFUSERS_IS_CI: yes
HF_HUB_ENABLE_HF_TRANSFER: 1
OMP_NUM_THREADS: 4
MKL_NUM_THREADS: 4
PYTEST_TIMEOUT: 60
jobs:
check_code_quality:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[quality]
- name: Check quality
run: make quality
- name: Check if failure
if: ${{ failure() }}
run: |
echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY
check_repository_consistency:
needs: check_code_quality
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[quality]
- name: Check repo consistency
run: |
python utils/check_copies.py
python utils/check_dummies.py
python utils/check_support_list.py
make deps_table_check_updated
- name: Check if failure
if: ${{ failure() }}
run: |
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
run_fast_tests:
needs: [check_code_quality, check_repository_consistency]
strategy:
fail-fast: false
matrix:
config:
- name: Fast PyTorch Modular Pipeline CPU tests
framework: pytorch_pipelines
runner: aws-highmemory-32-plus
image: diffusers/diffusers-pytorch-cpu
report: torch_cpu_modular_pipelines
name: ${{ matrix.config.name }}
runs-on:
group: ${{ matrix.config.runner }}
container:
image: ${{ matrix.config.image }}
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
defaults:
run:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
- name: Environment
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python utils/print_env.py
- name: Run fast PyTorch Pipeline CPU tests
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m pytest -n 8 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/modular_pipelines
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
path: reports

View File

@@ -174,36 +174,39 @@ Feel free to open an issue if dynamic compilation doesn't work as expected for a
### Regional compilation ### Regional compilation
[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by only compiling the *small and frequently-repeated block(s)* of a model - typically a transformer layer - and enables reusing compiled artifacts for every subsequent occurrence.
For many diffusion architectures, this delivers the same runtime speedups as full-graph compilation and reduces compile time by 810x.
Use the [`~ModelMixin.compile_repeated_blocks`] method, a helper that wraps `torch.compile`, on any component such as the transformer model as shown below. [Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by compiling **only the small, frequently-repeated block(s)** of a model, typically a Transformer layer, enabling reuse of compiled artifacts for every subsequent occurrence.
For many diffusion architectures this delivers the *same* runtime speed-ups as full-graph compilation yet cuts compile time by **810 ×**.
To make this effortless, [`ModelMixin`] exposes [`ModelMixin.compile_repeated_blocks`] API, a helper that wraps `torch.compile` around any sub-modules you designate as repeatable:
```py ```py
# pip install -U diffusers # pip install -U diffusers
import torch import torch
from diffusers import StableDiffusionXLPipeline from diffusers import StableDiffusionXLPipeline
pipeline = StableDiffusionXLPipeline.from_pretrained( pipe = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", "stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16, torch_dtype=torch.float16,
).to("cuda") ).to("cuda")
# compile only the repeated transformer layers inside the UNet # Compile only the repeated Transformer layers inside the UNet
pipeline.unet.compile_repeated_blocks(fullgraph=True) pipe.unet.compile_repeated_blocks(fullgraph=True)
``` ```
To enable regional compilation for a new model, add a `_repeated_blocks` attribute to a model class containing the class names (as strings) of the blocks you want to compile. To enable a new model with regional compilation, add a `_repeated_blocks` attribute to your model class containing the class names (as strings) of the blocks you want compiled:
```py ```py
class MyUNet(ModelMixin): class MyUNet(ModelMixin):
_repeated_blocks = ("Transformer2DModel",) # ← compiled by default _repeated_blocks = ("Transformer2DModel",) # ← compiled by default
``` ```
> [!TIP] For more examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).
> For more regional compilation examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).
**Relation to Accelerate compile_regions** There is also a separate API in [accelerate](https://huggingface.co/docs/accelerate/index) - [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78). It takes a fully automatic approach: it walks the module, picks candidate blocks, then compiles the remaining graph separately. That hands-off experience is handy for quick experiments, but it also leaves fewer knobs when you want to fine-tune which blocks are compiled or adjust compilation flags.
There is also a [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78) method in [Accelerate](https://huggingface.co/docs/accelerate/index) that automatically selects candidate blocks in a model to compile. The remaining graph is compiled separately. This is useful for quick experiments because there aren't as many options for you to set which blocks to compile or adjust compilation flags.
```py ```py
# pip install -U accelerate # pip install -U accelerate
@@ -216,8 +219,8 @@ pipeline = StableDiffusionXLPipeline.from_pretrained(
).to("cuda") ).to("cuda")
pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph=True) pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph=True)
``` ```
`compile_repeated_blocks`, by contrast, is intentionally explicit. You list the repeated blocks once (via `_repeated_blocks`) and the helper compiles exactly those, nothing more. In practice this small dose of control hits a sweet spot for diffusion models: predictable behavior, easy reasoning about cache reuse, and still a one-liner for users.
[`~ModelMixin.compile_repeated_blocks`] is intentionally explicit. List the blocks to repeat in `_repeated_blocks` and the helper only compiles those blocks. It offers predictable behavior and easy reasoning about cache reuse in one line of code.
### Graph breaks ### Graph breaks
@@ -293,9 +296,3 @@ An input is projected into three subspaces, represented by the projection matric
```py ```py
pipeline.fuse_qkv_projections() pipeline.fuse_qkv_projections()
``` ```
## Resources
- Read the [Presenting Flux Fast: Making Flux go brrr on H100s](https://pytorch.org/blog/presenting-flux-fast-making-flux-go-brrr-on-h100s/) blog post to learn more about how you can combine all of these optimizations with [TorchInductor](https://docs.pytorch.org/docs/stable/torch.compiler.html) and [AOTInductor](https://docs.pytorch.org/docs/stable/torch.compiler_aot_inductor.html) for a ~2.5x speedup using recipes from [flux-fast](https://github.com/huggingface/flux-fast).
These recipes support AMD hardware and [Flux.1 Kontext Dev](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev).

View File

@@ -14,9 +14,6 @@ specific language governing permissions and limitations under the License.
Optimizing models often involves trade-offs between [inference speed](./fp16) and [memory-usage](./memory). For instance, while [caching](./cache) can boost inference speed, it also increases memory consumption since it needs to store the outputs of intermediate attention layers. A more balanced optimization strategy combines quantizing a model, [torch.compile](./fp16#torchcompile) and various [offloading methods](./memory#offloading). Optimizing models often involves trade-offs between [inference speed](./fp16) and [memory-usage](./memory). For instance, while [caching](./cache) can boost inference speed, it also increases memory consumption since it needs to store the outputs of intermediate attention layers. A more balanced optimization strategy combines quantizing a model, [torch.compile](./fp16#torchcompile) and various [offloading methods](./memory#offloading).
> [!TIP]
> Check the [torch.compile](./fp16#torchcompile) guide to learn more about compilation and how they can be applied here. For example, regional compilation can significantly reduce compilation time without giving up any speedups.
For image generation, combining quantization and [model offloading](./memory#model-offloading) can often give the best trade-off between quality, speed, and memory. Group offloading is not as effective for image generation because it is usually not possible to *fully* overlap data transfer if the compute kernel finishes faster. This results in some communication overhead between the CPU and GPU. For image generation, combining quantization and [model offloading](./memory#model-offloading) can often give the best trade-off between quality, speed, and memory. Group offloading is not as effective for image generation because it is usually not possible to *fully* overlap data transfer if the compute kernel finishes faster. This results in some communication overhead between the CPU and GPU.
For video generation, combining quantization and [group-offloading](./memory#group-offloading) tends to be better because video models are more compute-bound. For video generation, combining quantization and [group-offloading](./memory#group-offloading) tends to be better because video models are more compute-bound.
@@ -28,7 +25,7 @@ The table below provides a comparison of optimization strategy combinations and
| quantization | 32.602 | 14.9453 | | quantization | 32.602 | 14.9453 |
| quantization, torch.compile | 25.847 | 14.9448 | | quantization, torch.compile | 25.847 | 14.9448 |
| quantization, torch.compile, model CPU offloading | 32.312 | 12.2369 | | quantization, torch.compile, model CPU offloading | 32.312 | 12.2369 |
<small>These results are benchmarked on Flux with a RTX 4090. The transformer and text_encoder components are quantized. Refer to the [benchmarking script](https://gist.github.com/sayakpaul/0db9d8eeeb3d2a0e5ed7cf0d9ca19b7d) if you're interested in evaluating your own model.</small> <small>These results are benchmarked on Flux with a RTX 4090. The transformer and text_encoder components are quantized. Refer to the <a href="https://gist.github.com/sayakpaul/0db9d8eeeb3d2a0e5ed7cf0d9ca19b7d" benchmarking script</a> if you're interested in evaluating your own model.</small>
This guide will show you how to compile and offload a quantized model with [bitsandbytes](../quantization/bitsandbytes#torchcompile). Make sure you are using [PyTorch nightly](https://pytorch.org/get-started/locally/) and the latest version of bitsandbytes. This guide will show you how to compile and offload a quantized model with [bitsandbytes](../quantization/bitsandbytes#torchcompile). Make sure you are using [PyTorch nightly](https://pytorch.org/get-started/locally/) and the latest version of bitsandbytes.

View File

@@ -1330,7 +1330,7 @@ def main(args):
# controlnet(s) inference # controlnet(s) inference
controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype) controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
controlnet_image = vae.encode(controlnet_image).latent_dist.sample() controlnet_image = vae.encode(controlnet_image).latent_dist.sample()
controlnet_image = (controlnet_image - vae.config.shift_factor) * vae.config.scaling_factor controlnet_image = controlnet_image * vae.config.scaling_factor
control_block_res_samples = controlnet( control_block_res_samples = controlnet(
hidden_states=noisy_model_input, hidden_states=noisy_model_input,

View File

@@ -1,10 +1,10 @@
# This file was autogenerated by uv via the following command: # This file was autogenerated by uv via the following command:
# uv pip compile requirements.in -o requirements.txt # uv pip compile requirements.in -o requirements.txt
aiohappyeyeballs==2.6.1 aiohappyeyeballs==2.4.3
# via aiohttp # via aiohttp
aiohttp==3.12.14 aiohttp==3.10.10
# via -r requirements.in # via -r requirements.in
aiosignal==1.4.0 aiosignal==1.3.1
# via aiohttp # via aiohttp
annotated-types==0.7.0 annotated-types==0.7.0
# via pydantic # via pydantic
@@ -29,6 +29,7 @@ filelock==3.16.1
# huggingface-hub # huggingface-hub
# torch # torch
# transformers # transformers
# triton
frozenlist==1.5.0 frozenlist==1.5.0
# via # via
# aiohttp # aiohttp
@@ -110,9 +111,7 @@ prometheus-client==0.21.0
prometheus-fastapi-instrumentator==7.0.0 prometheus-fastapi-instrumentator==7.0.0
# via -r requirements.in # via -r requirements.in
propcache==0.2.0 propcache==0.2.0
# via # via yarl
# aiohttp
# yarl
py-consul==1.5.3 py-consul==1.5.3
# via -r requirements.in # via -r requirements.in
pydantic==2.9.2 pydantic==2.9.2
@@ -156,9 +155,7 @@ triton==3.3.0
# via torch # via torch
typing-extensions==4.12.2 typing-extensions==4.12.2
# via # via
# aiosignal
# anyio # anyio
# exceptiongroup
# fastapi # fastapi
# huggingface-hub # huggingface-hub
# multidict # multidict
@@ -171,5 +168,5 @@ urllib3==2.5.0
# via requests # via requests
uvicorn==0.32.0 uvicorn==0.32.0
# via -r requirements.in # via -r requirements.in
yarl==1.18.3 yarl==1.16.0
# via aiohttp # via aiohttp

View File

@@ -763,7 +763,4 @@ class LegacyConfigMixin(ConfigMixin):
# resolve remapping # resolve remapping
remapped_class = _fetch_remapped_cls_from_config(config, cls) remapped_class = _fetch_remapped_cls_from_config(config, cls)
if remapped_class is cls: return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
return super(LegacyConfigMixin, remapped_class).from_config(config, return_unused_kwargs, **kwargs)
else:
return remapped_class.from_config(config, return_unused_kwargs, **kwargs)

View File

@@ -24,7 +24,6 @@ from typing_extensions import Self
from .. import __version__ from .. import __version__
from ..quantizers import DiffusersAutoQuantizer from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, logging from ..utils import deprecate, is_accelerate_available, logging
from ..utils.torch_utils import empty_device_cache
from .single_file_utils import ( from .single_file_utils import (
SingleFileComponentError, SingleFileComponentError,
convert_animatediff_checkpoint_to_diffusers, convert_animatediff_checkpoint_to_diffusers,
@@ -431,7 +430,6 @@ class FromOriginalModelMixin:
keep_in_fp32_modules=keep_in_fp32_modules, keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys, unexpected_keys=unexpected_keys,
) )
empty_device_cache()
else: else:
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False) _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)

View File

@@ -46,7 +46,6 @@ from ..utils import (
) )
from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
from ..utils.hub_utils import _get_model_file from ..utils.hub_utils import _get_model_file
from ..utils.torch_utils import empty_device_cache
if is_transformers_available(): if is_transformers_available():
@@ -1690,7 +1689,6 @@ def create_diffusers_clip_model_from_ldm(
if is_accelerate_available(): if is_accelerate_available():
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
empty_device_cache()
else: else:
model.load_state_dict(diffusers_format_checkpoint, strict=False) model.load_state_dict(diffusers_format_checkpoint, strict=False)
@@ -2150,7 +2148,6 @@ def create_diffusers_t5_model_from_checkpoint(
if is_accelerate_available(): if is_accelerate_available():
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
empty_device_cache()
else: else:
model.load_state_dict(diffusers_format_checkpoint) model.load_state_dict(diffusers_format_checkpoint)

View File

@@ -18,8 +18,11 @@ from ..models.embeddings import (
MultiIPAdapterImageProjection, MultiIPAdapterImageProjection,
) )
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import is_accelerate_available, is_torch_version, logging from ..utils import (
from ..utils.torch_utils import empty_device_cache is_accelerate_available,
is_torch_version,
logging,
)
if is_accelerate_available(): if is_accelerate_available():
@@ -81,7 +84,6 @@ class FluxTransformer2DLoadersMixin:
else: else:
device_map = {"": self.device} device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype) load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache()
return image_projection return image_projection
@@ -156,8 +158,6 @@ class FluxTransformer2DLoadersMixin:
key_id += 1 key_id += 1
empty_device_cache()
return attn_procs return attn_procs
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):

View File

@@ -18,7 +18,6 @@ from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
from ..models.embeddings import IPAdapterTimeImageProjection from ..models.embeddings import IPAdapterTimeImageProjection
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import is_accelerate_available, is_torch_version, logging from ..utils import is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import empty_device_cache
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@@ -81,8 +80,6 @@ class SD3Transformer2DLoadersMixin:
attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
) )
empty_device_cache()
return attn_procs return attn_procs
def _convert_ip_adapter_image_proj_to_diffusers( def _convert_ip_adapter_image_proj_to_diffusers(
@@ -150,7 +147,6 @@ class SD3Transformer2DLoadersMixin:
else: else:
device_map = {"": self.device} device_map = {"": self.device}
load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype) load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache()
return image_proj return image_proj

View File

@@ -43,7 +43,6 @@ from ..utils import (
is_torch_version, is_torch_version,
logging, logging,
) )
from ..utils.torch_utils import empty_device_cache
from .lora_base import _func_optionally_disable_offloading from .lora_base import _func_optionally_disable_offloading
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
from .utils import AttnProcsLayers from .utils import AttnProcsLayers
@@ -754,7 +753,6 @@ class UNet2DConditionLoadersMixin:
else: else:
device_map = {"": self.device} device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype) load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache()
return image_projection return image_projection
@@ -852,8 +850,6 @@ class UNet2DConditionLoadersMixin:
key_id += 2 key_id += 2
empty_device_cache()
return attn_procs return attn_procs
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):

View File

@@ -16,10 +16,9 @@
import importlib import importlib
import inspect import inspect
import math
import os import os
from array import array from array import array
from collections import OrderedDict, defaultdict from collections import OrderedDict
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from zipfile import is_zipfile from zipfile import is_zipfile
@@ -39,7 +38,6 @@ from ..utils import (
_get_model_file, _get_model_file,
deprecate, deprecate,
is_accelerate_available, is_accelerate_available,
is_accelerate_version,
is_gguf_available, is_gguf_available,
is_torch_available, is_torch_available,
is_torch_version, is_torch_version,
@@ -254,10 +252,6 @@ def load_model_dict_into_meta(
param = param.to(dtype) param = param.to(dtype)
set_module_kwargs["dtype"] = dtype set_module_kwargs["dtype"] = dtype
if is_accelerate_version(">", "1.8.1"):
set_module_kwargs["non_blocking"] = True
set_module_kwargs["clear_cache"] = False
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model. # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29 # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
@@ -526,60 +520,3 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights
return parsed_parameters return parsed_parameters
def _find_mismatched_keys(state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes):
mismatched_keys = []
if not ignore_mismatched_sizes:
return mismatched_keys
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
# If the checkpoint is sharded, we may not have the key here.
if checkpoint_key not in state_dict:
continue
if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
return mismatched_keys
def _expand_device_map(device_map, param_names):
"""
Expand a device map to return the correspondence parameter name to device.
"""
new_device_map = {}
for module, device in device_map.items():
new_device_map.update(
{p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
)
return new_device_map
# Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
def _caching_allocator_warmup(model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype) -> None:
"""
This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
device. It allows to have one large call to Malloc, instead of recursively calling it later when loading the model,
which is actually the loading speed bottleneck. Calling this function allows to cut the model loading time by a
very large margin.
"""
# Remove disk and cpu devices, and cast to proper torch.device
accelerator_device_map = {
param: torch.device(device)
for param, device in expanded_device_map.items()
if str(device) not in ["cpu", "disk"]
}
parameter_count = defaultdict(lambda: 0)
for param_name, device in accelerator_device_map.items():
try:
param = model.get_parameter(param_name)
except AttributeError:
param = model.get_buffer(param_name)
parameter_count[device] += math.prod(param.shape)
# This will kick off the caching allocator to avoid having to Malloc afterwards
for device, param_count in parameter_count.items():
_ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)

View File

@@ -62,14 +62,10 @@ from ..utils.hub_utils import (
load_or_create_model_card, load_or_create_model_card,
populate_model_card, populate_model_card,
) )
from ..utils.torch_utils import empty_device_cache
from .model_loading_utils import ( from .model_loading_utils import (
_caching_allocator_warmup,
_determine_device_map, _determine_device_map,
_expand_device_map,
_fetch_index_file, _fetch_index_file,
_fetch_index_file_legacy, _fetch_index_file_legacy,
_find_mismatched_keys,
_load_state_dict_into_model, _load_state_dict_into_model,
load_model_dict_into_meta, load_model_dict_into_meta,
load_state_dict, load_state_dict,
@@ -1473,6 +1469,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
for pat in cls._keys_to_ignore_on_load_unexpected: for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
mismatched_keys = []
assign_to_params_buffers = None
error_msgs = []
# Deal with offload # Deal with offload
if device_map is not None and "disk" in device_map.values(): if device_map is not None and "disk" in device_map.values():
if offload_folder is None: if offload_folder is None:
@@ -1481,27 +1482,18 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
" for them. Alternatively, make sure you have `safetensors` installed if the model you are using" " for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
" offers the weights in this format." " offers the weights in this format."
) )
else: if offload_folder is not None:
os.makedirs(offload_folder, exist_ok=True) os.makedirs(offload_folder, exist_ok=True)
if offload_state_dict is None: if offload_state_dict is None:
offload_state_dict = True offload_state_dict = True
# If a device map has been used, we can speedup the load time by warming up the device caching allocator.
# If we don't warmup, each tensor allocation on device calls to the allocator for memory (effectively, a
# lot of individual calls to device malloc). We can, however, preallocate the memory required by the
# tensors using their expected shape and not performing any initialization of the memory (empty data).
# When the actual device allocations happen, the allocator already has a pool of unused device memory
# that it can re-use for faster loading of the model.
# TODO: add support for warmup with hf_quantizer
if device_map is not None and hf_quantizer is None:
expanded_device_map = _expand_device_map(device_map, expected_keys)
_caching_allocator_warmup(model, expanded_device_map, dtype)
offload_index = {} if device_map is not None and "disk" in device_map.values() else None offload_index = {} if device_map is not None and "disk" in device_map.values() else None
state_dict_folder, state_dict_index = None, None
if offload_state_dict: if offload_state_dict:
state_dict_folder = tempfile.mkdtemp() state_dict_folder = tempfile.mkdtemp()
state_dict_index = {} state_dict_index = {}
else:
state_dict_folder = None
state_dict_index = None
if state_dict is not None: if state_dict is not None:
# load_state_dict will manage the case where we pass a dict instead of a file # load_state_dict will manage the case where we pass a dict instead of a file
@@ -1511,14 +1503,38 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
if len(resolved_model_file) > 1: if len(resolved_model_file) > 1:
resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards") resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
mismatched_keys = []
assign_to_params_buffers = None
error_msgs = []
for shard_file in resolved_model_file: for shard_file in resolved_model_file:
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries) state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
def _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
ignore_mismatched_sizes,
):
mismatched_keys = []
if ignore_mismatched_sizes:
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
# If the checkpoint is sharded, we may not have the key here.
if checkpoint_key not in state_dict:
continue
if (
model_key in model_state_dict
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
):
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
return mismatched_keys
mismatched_keys += _find_mismatched_keys( mismatched_keys += _find_mismatched_keys(
state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes state_dict,
model_state_dict,
loaded_keys,
ignore_mismatched_sizes,
) )
if low_cpu_mem_usage: if low_cpu_mem_usage:
@@ -1538,9 +1554,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
else: else:
if assign_to_params_buffers is None: if assign_to_params_buffers is None:
assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict) assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
empty_device_cache() error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
if offload_index is not None and len(offload_index) > 0: if offload_index is not None and len(offload_index) > 0:
save_offload_index(offload_index, offload_folder) save_offload_index(offload_index, offload_folder)
@@ -1877,9 +1892,4 @@ class LegacyModelMixin(ModelMixin):
# resolve remapping # resolve remapping
remapped_class = _fetch_remapped_cls_from_config(config, cls) remapped_class = _fetch_remapped_cls_from_config(config, cls)
if remapped_class is cls: return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
return super(LegacyModelMixin, remapped_class).from_pretrained(
pretrained_model_name_or_path, **kwargs_copy
)
else:
return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)

View File

@@ -187,15 +187,9 @@ class CosmosAttnProcessor2_0:
key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2) key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
# 4. Prepare for GQA # 4. Prepare for GQA
if torch.onnx.is_in_onnx_export(): query_idx = torch.tensor(query.size(3), device=query.device)
query_idx = torch.tensor(query.size(3), device=query.device) key_idx = torch.tensor(key.size(3), device=key.device)
key_idx = torch.tensor(key.size(3), device=key.device) value_idx = torch.tensor(value.size(3), device=value.device)
value_idx = torch.tensor(value.size(3), device=value.device)
else:
query_idx = query.size(3)
key_idx = key.size(3)
value_idx = value.size(3)
key = key.repeat_interleave(query_idx // key_idx, dim=3) key = key.repeat_interleave(query_idx // key_idx, dim=3)
value = value.repeat_interleave(query_idx // value_idx, dim=3) value = value.repeat_interleave(query_idx // value_idx, dim=3)

View File

@@ -490,7 +490,6 @@ class FluxTransformer2DModel(
encoder_hidden_states, encoder_hidden_states,
temb, temb,
image_rotary_emb, image_rotary_emb,
joint_attention_kwargs,
) )
else: else:
@@ -522,7 +521,6 @@ class FluxTransformer2DModel(
encoder_hidden_states, encoder_hidden_states,
temb, temb,
image_rotary_emb, image_rotary_emb,
joint_attention_kwargs,
) )
else: else:

View File

@@ -479,22 +479,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
return list(combined_dict.values()) return list(combined_dict.values())
@property
def input_names(self) -> List[str]:
return [input_param.name for input_param in self.inputs]
@property
def intermediate_input_names(self) -> List[str]:
return [input_param.name for input_param in self.intermediate_inputs]
@property
def intermediate_output_names(self) -> List[str]:
return [output_param.name for output_param in self.intermediate_outputs]
@property
def output_names(self) -> List[str]:
return [output_param.name for output_param in self.outputs]
class PipelineBlock(ModularPipelineBlocks): class PipelineBlock(ModularPipelineBlocks):
""" """
@@ -2841,8 +2825,3 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
type_hint=type_hint, type_hint=type_hint,
**spec_dict, **spec_dict,
) )
def set_progress_bar_config(self, **kwargs):
for sub_block_name, sub_block in self.blocks.sub_blocks.items():
if hasattr(sub_block, "set_progress_bar_config"):
sub_block.set_progress_bar_config(**kwargs)

View File

@@ -744,6 +744,8 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
timestep=None, timestep=None,
is_strength_max=True, is_strength_max=True,
add_noise=True, add_noise=True,
return_noise=False,
return_image_latents=False,
): ):
shape = ( shape = (
batch_size, batch_size,
@@ -766,7 +768,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
if image.shape[1] == 4: if image.shape[1] == 4:
image_latents = image.to(device=device, dtype=dtype) image_latents = image.to(device=device, dtype=dtype)
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
elif latents is None and not is_strength_max: elif return_image_latents or (latents is None and not is_strength_max):
image = image.to(device=device, dtype=dtype) image = image.to(device=device, dtype=dtype)
image_latents = self._encode_vae_image(components, image=image, generator=generator) image_latents = self._encode_vae_image(components, image=image, generator=generator)
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
@@ -784,7 +786,13 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = image_latents.to(device) latents = image_latents.to(device)
outputs = (latents, noise, image_latents) outputs = (latents,)
if return_noise:
outputs += (noise,)
if return_image_latents:
outputs += (image_latents,)
return outputs return outputs
@@ -856,7 +864,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor
block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor
block_state.latents, block_state.noise, block_state.image_latents = self.prepare_latents_inpaint( block_state.latents, block_state.noise = self.prepare_latents_inpaint(
components, components,
block_state.batch_size * block_state.num_images_per_prompt, block_state.batch_size * block_state.num_images_per_prompt,
components.num_channels_latents, components.num_channels_latents,
@@ -870,6 +878,8 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
timestep=block_state.latent_timestep, timestep=block_state.latent_timestep,
is_strength_max=block_state.is_strength_max, is_strength_max=block_state.is_strength_max,
add_noise=block_state.add_noise, add_noise=block_state.add_noise,
return_noise=True,
return_image_latents=False,
) )
# 7. Prepare mask latent variables # 7. Prepare mask latent variables

View File

@@ -19,6 +19,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import PIL.Image import PIL.Image
import torch import torch
import torch.nn.functional as F
from transformers import ( from transformers import (
CLIPImageProcessor, CLIPImageProcessor,
CLIPTextModel, CLIPTextModel,
@@ -37,13 +38,7 @@ from ...loaders import (
StableDiffusionXLLoraLoaderMixin, StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin, TextualInversionLoaderMixin,
) )
from ...models import ( from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, ImageProjection, UNet2DConditionModel
AutoencoderKL,
ControlNetUnionModel,
ImageProjection,
MultiControlNetUnionModel,
UNet2DConditionModel,
)
from ...models.attention_processor import ( from ...models.attention_processor import (
AttnProcessor2_0, AttnProcessor2_0,
XFormersAttnProcessor, XFormersAttnProcessor,
@@ -267,9 +262,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
tokenizer: CLIPTokenizer, tokenizer: CLIPTokenizer,
tokenizer_2: CLIPTokenizer, tokenizer_2: CLIPTokenizer,
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
controlnet: Union[ controlnet: ControlNetUnionModel,
ControlNetUnionModel, List[ControlNetUnionModel], Tuple[ControlNetUnionModel], MultiControlNetUnionModel
],
scheduler: KarrasDiffusionSchedulers, scheduler: KarrasDiffusionSchedulers,
requires_aesthetics_score: bool = False, requires_aesthetics_score: bool = False,
force_zeros_for_empty_prompt: bool = True, force_zeros_for_empty_prompt: bool = True,
@@ -279,8 +272,8 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
): ):
super().__init__() super().__init__()
if isinstance(controlnet, (list, tuple)): if not isinstance(controlnet, ControlNetUnionModel):
controlnet = MultiControlNetUnionModel(controlnet) raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.")
self.register_modules( self.register_modules(
vae=vae, vae=vae,
@@ -656,7 +649,6 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
controlnet_conditioning_scale=1.0, controlnet_conditioning_scale=1.0,
control_guidance_start=0.0, control_guidance_start=0.0,
control_guidance_end=1.0, control_guidance_end=1.0,
control_mode=None,
callback_on_step_end_tensor_inputs=None, callback_on_step_end_tensor_inputs=None,
): ):
if strength < 0 or strength > 1: if strength < 0 or strength > 1:
@@ -730,44 +722,28 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
) )
# `prompt` needs more sophisticated handling when there are multiple
# conditionings.
if isinstance(self.controlnet, MultiControlNetUnionModel):
if isinstance(prompt, list):
logger.warning(
f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
" prompts. The conditionings will be fixed across the prompts."
)
# Check `image` # Check `image`
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
self.controlnet, torch._dynamo.eval_frame.OptimizedModule
if isinstance(controlnet, ControlNetUnionModel): )
for image_ in image: if (
self.check_image(image_, prompt, prompt_embeds) isinstance(self.controlnet, ControlNetModel)
elif isinstance(controlnet, MultiControlNetUnionModel): or is_compiled
if not isinstance(image, list): and isinstance(self.controlnet._orig_mod, ControlNetModel)
raise TypeError("For multiple controlnets: `image` must be type `list`") ):
elif not all(isinstance(i, list) for i in image): self.check_image(image, prompt, prompt_embeds)
raise ValueError("For multiple controlnets: elements of `image` must be list of conditionings.") elif (
elif len(image) != len(self.controlnet.nets): isinstance(self.controlnet, ControlNetUnionModel)
raise ValueError( or is_compiled
f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
) ):
self.check_image(image, prompt, prompt_embeds)
for images_ in image: else:
for image_ in images_: assert False
self.check_image(image_, prompt, prompt_embeds)
if not isinstance(control_guidance_start, (tuple, list)): if not isinstance(control_guidance_start, (tuple, list)):
control_guidance_start = [control_guidance_start] control_guidance_start = [control_guidance_start]
if isinstance(controlnet, MultiControlNetUnionModel):
if len(control_guidance_start) != len(self.controlnet.nets):
raise ValueError(
f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
)
if not isinstance(control_guidance_end, (tuple, list)): if not isinstance(control_guidance_end, (tuple, list)):
control_guidance_end = [control_guidance_end] control_guidance_end = [control_guidance_end]
@@ -786,15 +762,6 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
if end > 1.0: if end > 1.0:
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
# Check `control_mode`
if isinstance(controlnet, ControlNetUnionModel):
if max(control_mode) >= controlnet.config.num_control_type:
raise ValueError(f"control_mode: must be lower than {controlnet.config.num_control_type}.")
elif isinstance(controlnet, MultiControlNetUnionModel):
for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets):
if max(_control_mode) >= _controlnet.config.num_control_type:
raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.")
if ip_adapter_image is not None and ip_adapter_image_embeds is not None: if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError( raise ValueError(
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
@@ -1082,7 +1049,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
prompt: Union[str, List[str]] = None, prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None, prompt_2: Optional[Union[str, List[str]]] = None,
image: PipelineImageInput = None, image: PipelineImageInput = None,
control_image: Union[PipelineImageInput, List[PipelineImageInput]] = None, control_image: PipelineImageInput = None,
height: Optional[int] = None, height: Optional[int] = None,
width: Optional[int] = None, width: Optional[int] = None,
strength: float = 0.8, strength: float = 0.8,
@@ -1107,7 +1074,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
guess_mode: bool = False, guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0, control_guidance_end: Union[float, List[float]] = 1.0,
control_mode: Optional[Union[int, List[int], List[List[int]]]] = None, control_mode: Optional[Union[int, List[int]]] = None,
original_size: Tuple[int, int] = None, original_size: Tuple[int, int] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0), crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Tuple[int, int] = None, target_size: Tuple[int, int] = None,
@@ -1137,13 +1104,13 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
`List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The initial image will be used as the starting point for the image generation process. Can also accept The initial image will be used as the starting point for the image generation process. Can also accept
image latents as `image`, if passing latents directly, it will not be encoded again. image latents as `image`, if passing latents directly, it will not be encoded again.
control_image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*): control_image (`PipelineImageInput`):
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also
as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in
images must be passed as a list such that each element of the list can be correctly batched for input init, images must be passed as a list such that each element of the list can be correctly batched for
to a single ControlNet. input to a single controlnet.
height (`int`, *optional*, defaults to the size of control_image): height (`int`, *optional*, defaults to the size of control_image):
The height in pixels of the generated image. Anything below 512 pixels won't work well for The height in pixels of the generated image. Anything below 512 pixels won't work well for
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
@@ -1217,21 +1184,16 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
`self.processor` in `self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
the corresponding scale as a list. corresponding scale as a list.
guess_mode (`bool`, *optional*, defaults to `False`): guess_mode (`bool`, *optional*, defaults to `False`):
In this mode, the ControlNet encoder will try best to recognize the content of the input image even if In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
The percentage of total steps at which the ControlNet starts applying. The percentage of total steps at which the controlnet starts applying.
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
The percentage of total steps at which the ControlNet stops applying. The percentage of total steps at which the controlnet stops applying.
control_mode (`int` or `List[int]` or `List[List[int]], *optional*):
The control condition types for the ControlNet. See the ControlNet's model card forinformation on the
available control modes. If multiple ControlNets are specified in `init`, control_mode should be a list
where each ControlNet should have its corresponding control mode list. Should reflect the order of
conditions in control_image
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
@@ -1311,6 +1273,12 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
if not isinstance(control_image, list): if not isinstance(control_image, list):
control_image = [control_image] control_image = [control_image]
else: else:
@@ -1319,56 +1287,37 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
if not isinstance(control_mode, list): if not isinstance(control_mode, list):
control_mode = [control_mode] control_mode = [control_mode]
if isinstance(controlnet, MultiControlNetUnionModel): if len(control_image) != len(control_mode):
control_image = [[item] for item in control_image] raise ValueError("Expected len(control_image) == len(control_type)")
control_mode = [[item] for item in control_mode]
# align format for control guidance num_control_type = controlnet.config.num_control_type
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
control_guidance_start, control_guidance_end = (
mult * [control_guidance_start],
mult * [control_guidance_end],
)
if isinstance(controlnet_conditioning_scale, float):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult
# 1. Check inputs # 1. Check inputs
self.check_inputs( control_type = [0 for _ in range(num_control_type)]
prompt, for _image, control_idx in zip(control_image, control_mode):
prompt_2, control_type[control_idx] = 1
control_image, self.check_inputs(
strength, prompt,
num_inference_steps, prompt_2,
callback_steps, _image,
negative_prompt, strength,
negative_prompt_2, num_inference_steps,
prompt_embeds, callback_steps,
negative_prompt_embeds, negative_prompt,
pooled_prompt_embeds, negative_prompt_2,
negative_pooled_prompt_embeds, prompt_embeds,
ip_adapter_image, negative_prompt_embeds,
ip_adapter_image_embeds, pooled_prompt_embeds,
controlnet_conditioning_scale, negative_pooled_prompt_embeds,
control_guidance_start, ip_adapter_image,
control_guidance_end, ip_adapter_image_embeds,
control_mode, controlnet_conditioning_scale,
callback_on_step_end_tensor_inputs, control_guidance_start,
) control_guidance_end,
callback_on_step_end_tensor_inputs,
)
if isinstance(controlnet, ControlNetUnionModel): control_type = torch.Tensor(control_type)
control_type = torch.zeros(controlnet.config.num_control_type).scatter_(0, torch.tensor(control_mode), 1)
elif isinstance(controlnet, MultiControlNetUnionModel):
control_type = [
torch.zeros(controlnet_.config.num_control_type).scatter_(0, torch.tensor(control_mode_), 1)
for control_mode_, controlnet_ in zip(control_mode, self.controlnet.nets)
]
self._guidance_scale = guidance_scale self._guidance_scale = guidance_scale
self._clip_skip = clip_skip self._clip_skip = clip_skip
@@ -1385,11 +1334,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
device = self._execution_device device = self._execution_device
global_pool_conditions = ( global_pool_conditions = controlnet.config.global_pool_conditions
controlnet.config.global_pool_conditions
if isinstance(controlnet, ControlNetUnionModel)
else controlnet.nets[0].config.global_pool_conditions
)
guess_mode = guess_mode or global_pool_conditions guess_mode = guess_mode or global_pool_conditions
# 3.1. Encode input prompt # 3.1. Encode input prompt
@@ -1427,55 +1372,22 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
self.do_classifier_free_guidance, self.do_classifier_free_guidance,
) )
# 4.1 Prepare image # 4. Prepare image and controlnet_conditioning_image
image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
# 4.2 Prepare control images for idx, _ in enumerate(control_image):
if isinstance(controlnet, ControlNetUnionModel): control_image[idx] = self.prepare_control_image(
control_images = [] image=control_image[idx],
width=width,
for image_ in control_image: height=height,
image_ = self.prepare_control_image( batch_size=batch_size * num_images_per_prompt,
image=image_, num_images_per_prompt=num_images_per_prompt,
width=width, device=device,
height=height, dtype=controlnet.dtype,
batch_size=batch_size * num_images_per_prompt, do_classifier_free_guidance=self.do_classifier_free_guidance,
num_images_per_prompt=num_images_per_prompt, guess_mode=guess_mode,
device=device, )
dtype=controlnet.dtype, height, width = control_image[idx].shape[-2:]
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
control_images.append(image_)
control_image = control_images
height, width = control_image[0].shape[-2:]
elif isinstance(controlnet, MultiControlNetUnionModel):
control_images = []
for control_image_ in control_image:
images = []
for image_ in control_image_:
image_ = self.prepare_control_image(
image=image_,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
images.append(image_)
control_images.append(images)
control_image = control_images
height, width = control_image[0][0].shape[-2:]
# 5. Prepare timesteps # 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device) self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -1502,11 +1414,10 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
# 7.1 Create tensor stating which controlnets to keep # 7.1 Create tensor stating which controlnets to keep
controlnet_keep = [] controlnet_keep = []
for i in range(len(timesteps)): for i in range(len(timesteps)):
keeps = [ controlnet_keep.append(
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) 1.0
for s, e in zip(control_guidance_start, control_guidance_end) - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end)
] )
controlnet_keep.append(keeps)
# 7.2 Prepare added time ids & embeddings # 7.2 Prepare added time ids & embeddings
original_size = original_size or (height, width) original_size = original_size or (height, width)
@@ -1549,25 +1460,12 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
prompt_embeds = prompt_embeds.to(device) prompt_embeds = prompt_embeds.to(device)
add_text_embeds = add_text_embeds.to(device) add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device) add_time_ids = add_time_ids.to(device)
control_type = (
control_type_repeat_factor = ( control_type.reshape(1, -1)
batch_size * num_images_per_prompt * (2 if self.do_classifier_free_guidance else 1) .to(device, dtype=prompt_embeds.dtype)
.repeat(batch_size * num_images_per_prompt * 2, 1)
) )
if isinstance(controlnet, ControlNetUnionModel):
control_type = (
control_type.reshape(1, -1)
.to(self._execution_device, dtype=prompt_embeds.dtype)
.repeat(control_type_repeat_factor, 1)
)
elif isinstance(controlnet, MultiControlNetUnionModel):
control_type = [
_control_type.reshape(1, -1)
.to(self._execution_device, dtype=prompt_embeds.dtype)
.repeat(control_type_repeat_factor, 1)
for _control_type in control_type
]
# 8. Denoising loop # 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar:

View File

@@ -383,8 +383,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
# set timesteps # set timesteps
self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_timesteps(num_inference_steps)
# scale the initial noise by the standard deviation required by the scheduler latents = latents * np.float64(self.scheduler.init_noise_sigma)
latents = latents * self.scheduler.init_noise_sigma
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.

View File

@@ -483,7 +483,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_timesteps(num_inference_steps)
# scale the initial noise by the standard deviation required by the scheduler # scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma latents = latents * np.float64(self.scheduler.init_noise_sigma)
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.

View File

@@ -481,7 +481,7 @@ class OnnxStableDiffusionUpscalePipeline(DiffusionPipeline):
timesteps = self.scheduler.timesteps timesteps = self.scheduler.timesteps
# Scale the initial noise by the standard deviation required by the scheduler # Scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma latents = latents * np.float64(self.scheduler.init_noise_sigma)
# 5. Add noise to image # 5. Add noise to image
noise_level = np.array([noise_level]).astype(np.int64) noise_level = np.array([noise_level]).astype(np.int64)

View File

@@ -184,14 +184,5 @@ def get_device():
def empty_device_cache(device_type: Optional[str] = None): def empty_device_cache(device_type: Optional[str] = None):
if device_type is None: if device_type is None:
device_type = get_device() device_type = get_device()
if device_type in ["cpu"]:
return
device_mod = getattr(torch, device_type, torch.cuda) device_mod = getattr(torch, device_type, torch.cuda)
device_mod.empty_cache() device_mod.empty_cache()
def device_synchronize(device_type: Optional[str] = None):
if device_type is None:
device_type = get_device()
device_mod = getattr(torch, device_type, torch.cuda)
device_mod.synchronize()

View File

@@ -1,511 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import tempfile
import unittest
from typing import Any, Dict
import numpy as np
import torch
from PIL import Image
from diffusers import (
ClassifierFreeGuidance,
ComponentsManager,
ModularPipeline,
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
)
from diffusers.loaders import ModularIPAdapterMixin
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
require_torch_accelerator,
torch_device,
)
from ...models.unets.test_models_unet_2d_condition import (
create_ip_adapter_state_dict,
)
from ..test_modular_pipelines_common import (
ModularPipelineTesterMixin,
)
enable_full_determinism()
class SDXLModularTests:
"""
This mixin defines method to create pipeline, base input and base test across all SDXL modular tests.
"""
pipeline_class = StableDiffusionXLModularPipeline
pipeline_blocks_class = StableDiffusionXLAutoBlocks
repo = "hf-internal-testing/tiny-sdxl-modular"
params = frozenset(
[
"prompt",
"height",
"width",
"negative_prompt",
"cross_attention_kwargs",
"image",
"mask_image",
]
)
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
pipeline.load_default_components(torch_dtype=torch_dtype)
return pipeline
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"output_type": "np",
}
return inputs
def _test_stable_diffusion_xl_euler(self, expected_image_shape, expected_slice, expected_max_diff=1e-2):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
sd_pipe = self.get_pipeline()
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs, output="images")
image_slice = image[0, -3:, -3:, -1]
assert image.shape == expected_image_shape
assert np.abs(image_slice.flatten() - expected_slice).max() < expected_max_diff, (
"Image Slice does not match expected slice"
)
class SDXLModularIPAdapterTests:
"""
This mixin is designed to test IP Adapter.
"""
def test_pipeline_inputs_and_blocks(self):
blocks = self.pipeline_blocks_class()
parameters = blocks.input_names
assert issubclass(self.pipeline_class, ModularIPAdapterMixin)
assert "ip_adapter_image" in parameters, (
"`ip_adapter_image` argument must be supported by the `__call__` method"
)
assert "ip_adapter" in blocks.sub_blocks, "pipeline must contain an IPAdapter block"
_ = blocks.sub_blocks.pop("ip_adapter")
parameters = blocks.input_names
intermediate_parameters = blocks.intermediate_input_names
assert "ip_adapter_image" not in parameters, (
"`ip_adapter_image` argument must be removed from the `__call__` method"
)
assert "ip_adapter_image_embeds" not in intermediate_parameters, (
"`ip_adapter_image_embeds` argument must be supported by the `__call__` method"
)
def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
return torch.randn((1, 1, cross_attention_dim), device=torch_device)
def _get_dummy_faceid_image_embeds(self, cross_attention_dim: int = 32):
return torch.randn((1, 1, 1, cross_attention_dim), device=torch_device)
def _get_dummy_masks(self, input_size: int = 64):
_masks = torch.zeros((1, 1, input_size, input_size), device=torch_device)
_masks[0, :, :, : int(input_size / 2)] = 1
return _masks
def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
blocks = self.pipeline_blocks_class()
_ = blocks.sub_blocks.pop("ip_adapter")
parameters = blocks.input_names
if "image" in parameters and "strength" in parameters:
inputs["num_inference_steps"] = 4
inputs["output_type"] = "np"
return inputs
def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
r"""Tests for IP-Adapter.
The following scenarios are tested:
- Single IP-Adapter with scale=0 should produce same output as no IP-Adapter.
- Multi IP-Adapter with scale=0 should produce same output as no IP-Adapter.
- Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
- Multi IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
"""
# Raising the tolerance for this test when it's run on a CPU because we
# compare against static slices and that can be shaky (with a VVVV low probability).
expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
blocks = self.pipeline_blocks_class()
_ = blocks.sub_blocks.pop("ip_adapter")
pipe = blocks.init_pipeline(self.repo)
pipe.load_default_components(torch_dtype=torch.float32)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
cross_attention_dim = pipe.unet.config.get("cross_attention_dim")
# forward pass without ip adapter
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
if expected_pipe_slice is None:
output_without_adapter = pipe(**inputs, output="images")
else:
output_without_adapter = expected_pipe_slice
# 1. Single IP-Adapter test cases
adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
pipe.unet._load_ip_adapter_weights(adapter_state_dict)
# forward pass with single ip adapter, but scale=0 which should have no effect
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
pipe.set_ip_adapter_scale(0.0)
output_without_adapter_scale = pipe(**inputs, output="images")
if expected_pipe_slice is not None:
output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
# forward pass with single ip adapter, but with scale of adapter weights
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
pipe.set_ip_adapter_scale(42.0)
output_with_adapter_scale = pipe(**inputs, output="images")
if expected_pipe_slice is not None:
output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
assert max_diff_without_adapter_scale < expected_max_diff, (
"Output without ip-adapter must be same as normal inference"
)
assert max_diff_with_adapter_scale > 1e-2, "Output with ip-adapter must be different from normal inference"
# 2. Multi IP-Adapter test cases
adapter_state_dict_1 = create_ip_adapter_state_dict(pipe.unet)
adapter_state_dict_2 = create_ip_adapter_state_dict(pipe.unet)
pipe.unet._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2])
# forward pass with multi ip adapter, but scale=0 which should have no effect
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
pipe.set_ip_adapter_scale([0.0, 0.0])
output_without_multi_adapter_scale = pipe(**inputs, output="images")
if expected_pipe_slice is not None:
output_without_multi_adapter_scale = output_without_multi_adapter_scale[0, -3:, -3:, -1].flatten()
# forward pass with multi ip adapter, but with scale of adapter weights
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
pipe.set_ip_adapter_scale([42.0, 42.0])
output_with_multi_adapter_scale = pipe(**inputs, output="images")
if expected_pipe_slice is not None:
output_with_multi_adapter_scale = output_with_multi_adapter_scale[0, -3:, -3:, -1].flatten()
max_diff_without_multi_adapter_scale = np.abs(
output_without_multi_adapter_scale - output_without_adapter
).max()
max_diff_with_multi_adapter_scale = np.abs(output_with_multi_adapter_scale - output_without_adapter).max()
assert max_diff_without_multi_adapter_scale < expected_max_diff, (
"Output without multi-ip-adapter must be same as normal inference"
)
assert max_diff_with_multi_adapter_scale > 1e-2, (
"Output with multi-ip-adapter scale must be different from normal inference"
)
class SDXLModularControlNetTests:
"""
This mixin is designed to test ControlNet.
"""
def test_pipeline_inputs(self):
blocks = self.pipeline_blocks_class()
parameters = blocks.input_names
assert "control_image" in parameters, "`control_image` argument must be supported by the `__call__` method"
assert "controlnet_conditioning_scale" in parameters, (
"`controlnet_conditioning_scale` argument must be supported by the `__call__` method"
)
def _modify_inputs_for_controlnet_test(self, inputs: Dict[str, Any]):
controlnet_embedder_scale_factor = 2
image = torch.randn(
(1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
device=torch_device,
)
inputs["control_image"] = image
return inputs
def test_controlnet(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
r"""Tests for ControlNet.
The following scenarios are tested:
- Single ControlNet with scale=0 should produce same output as no ControlNet.
- Single ControlNet with scale!=0 should produce different output compared to no ControlNet.
"""
# Raising the tolerance for this test when it's run on a CPU because we
# compare against static slices and that can be shaky (with a VVVV low probability).
expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
pipe = self.get_pipeline()
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
# forward pass without controlnet
inputs = self.get_dummy_inputs(torch_device)
output_without_controlnet = pipe(**inputs, output="images")
output_without_controlnet = output_without_controlnet[0, -3:, -3:, -1].flatten()
# forward pass with single controlnet, but scale=0 which should have no effect
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
inputs["controlnet_conditioning_scale"] = 0.0
output_without_controlnet_scale = pipe(**inputs, output="images")
output_without_controlnet_scale = output_without_controlnet_scale[0, -3:, -3:, -1].flatten()
# forward pass with single controlnet, but with scale of adapter weights
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
inputs["controlnet_conditioning_scale"] = 42.0
output_with_controlnet_scale = pipe(**inputs, output="images")
output_with_controlnet_scale = output_with_controlnet_scale[0, -3:, -3:, -1].flatten()
max_diff_without_controlnet_scale = np.abs(output_without_controlnet_scale - output_without_controlnet).max()
max_diff_with_controlnet_scale = np.abs(output_with_controlnet_scale - output_without_controlnet).max()
assert max_diff_without_controlnet_scale < expected_max_diff, (
"Output without controlnet must be same as normal inference"
)
assert max_diff_with_controlnet_scale > 1e-2, "Output with controlnet must be different from normal inference"
def test_controlnet_cfg(self):
pipe = self.get_pipeline()
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
# forward pass with CFG not applied
guider = ClassifierFreeGuidance(guidance_scale=1.0)
pipe.update_components(guider=guider)
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
out_no_cfg = pipe(**inputs, output="images")
# forward pass with CFG applied
guider = ClassifierFreeGuidance(guidance_scale=7.5)
pipe.update_components(guider=guider)
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
out_cfg = pipe(**inputs, output="images")
assert out_cfg.shape == out_no_cfg.shape
max_diff = np.abs(out_cfg - out_no_cfg).max()
assert max_diff > 1e-2, "Output with CFG must be different from normal inference"
class SDXLModularGuiderTests:
def test_guider_cfg(self):
pipe = self.get_pipeline()
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
# forward pass with CFG not applied
guider = ClassifierFreeGuidance(guidance_scale=1.0)
pipe.update_components(guider=guider)
inputs = self.get_dummy_inputs(torch_device)
out_no_cfg = pipe(**inputs, output="images")
# forward pass with CFG applied
guider = ClassifierFreeGuidance(guidance_scale=7.5)
pipe.update_components(guider=guider)
inputs = self.get_dummy_inputs(torch_device)
out_cfg = pipe(**inputs, output="images")
assert out_cfg.shape == out_no_cfg.shape
max_diff = np.abs(out_cfg - out_no_cfg).max()
assert max_diff > 1e-2, "Output with CFG must be different from normal inference"
class SDXLModularPipelineFastTests(
SDXLModularTests,
SDXLModularIPAdapterTests,
SDXLModularControlNetTests,
SDXLModularGuiderTests,
ModularPipelineTesterMixin,
unittest.TestCase,
):
"""Test cases for Stable Diffusion XL modular pipeline fast tests."""
def test_stable_diffusion_xl_euler(self):
self._test_stable_diffusion_xl_euler(
expected_image_shape=(1, 64, 64, 3),
expected_slice=[
0.5966781,
0.62939394,
0.48465094,
0.51573336,
0.57593524,
0.47035995,
0.53410417,
0.51436996,
0.47313565,
],
expected_max_diff=1e-2,
)
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
@require_torch_accelerator
def test_stable_diffusion_xl_offloads(self):
pipes = []
sd_pipe = self.get_pipeline().to(torch_device)
pipes.append(sd_pipe)
cm = ComponentsManager()
cm.enable_auto_cpu_offload(device=torch_device)
sd_pipe = self.get_pipeline(components_manager=cm)
pipes.append(sd_pipe)
image_slices = []
for pipe in pipes:
inputs = self.get_dummy_inputs(torch_device)
image = pipe(**inputs, output="images")
image_slices.append(image[0, -3:, -3:, -1].flatten())
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
def test_stable_diffusion_xl_save_from_pretrained(self):
pipes = []
sd_pipe = self.get_pipeline().to(torch_device)
pipes.append(sd_pipe)
with tempfile.TemporaryDirectory() as tmpdirname:
sd_pipe.save_pretrained(tmpdirname)
sd_pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
sd_pipe.load_default_components(torch_dtype=torch.float32)
sd_pipe.to(torch_device)
pipes.append(sd_pipe)
image_slices = []
for pipe in pipes:
inputs = self.get_dummy_inputs(torch_device)
image = pipe(**inputs, output="images")
image_slices.append(image[0, -3:, -3:, -1].flatten())
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
class SDXLImg2ImgModularPipelineFastTests(
SDXLModularTests,
SDXLModularIPAdapterTests,
SDXLModularControlNetTests,
SDXLModularGuiderTests,
ModularPipelineTesterMixin,
unittest.TestCase,
):
"""Test cases for Stable Diffusion XL image-to-image modular pipeline fast tests."""
def get_dummy_inputs(self, device, seed=0):
inputs = super().get_dummy_inputs(device, seed)
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
image = image / 2 + 0.5
inputs["image"] = image
inputs["strength"] = 0.8
return inputs
def test_stable_diffusion_xl_euler(self):
self._test_stable_diffusion_xl_euler(
expected_image_shape=(1, 64, 64, 3),
expected_slice=[
0.56943184,
0.4702148,
0.48048905,
0.6235963,
0.551138,
0.49629188,
0.60031277,
0.5688907,
0.43996853,
],
expected_max_diff=1e-2,
)
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
class SDXLInpaintingModularPipelineFastTests(
SDXLModularTests,
SDXLModularIPAdapterTests,
SDXLModularControlNetTests,
SDXLModularGuiderTests,
ModularPipelineTesterMixin,
unittest.TestCase,
):
"""Test cases for Stable Diffusion XL inpainting modular pipeline fast tests."""
def get_dummy_inputs(self, device, seed=0):
inputs = super().get_dummy_inputs(device, seed)
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
image = image.cpu().permute(0, 2, 3, 1)[0]
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
# create mask
image[8:, 8:, :] = 255
mask_image = Image.fromarray(np.uint8(image)).convert("L").resize((64, 64))
inputs["image"] = init_image
inputs["mask_image"] = mask_image
inputs["strength"] = 1.0
return inputs
def test_stable_diffusion_xl_euler(self):
self._test_stable_diffusion_xl_euler(
expected_image_shape=(1, 64, 64, 3),
expected_slice=[
0.40872607,
0.38842705,
0.34893104,
0.47837183,
0.43792963,
0.5332134,
0.3716843,
0.47274873,
0.45000193,
],
expected_max_diff=1e-2,
)
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)

View File

@@ -1,330 +0,0 @@
import gc
import unittest
from typing import Callable, Union
import numpy as np
import torch
import diffusers
from diffusers.utils import logging
from diffusers.utils.dummy_pt_objects import ModularPipeline, ModularPipelineBlocks
from diffusers.utils.testing_utils import (
backend_empty_cache,
numpy_cosine_similarity_distance,
require_accelerator,
require_torch,
torch_device,
)
def to_np(tensor):
if isinstance(tensor, torch.Tensor):
tensor = tensor.detach().cpu().numpy()
return tensor
@require_torch
class ModularPipelineTesterMixin:
"""
This mixin is designed to be used with unittest.TestCase classes.
It provides a set of common tests for each modular pipeline,
including:
- test_pipeline_call_signature: check if the pipeline's __call__ method has all required parameters
- test_inference_batch_consistent: check if the pipeline's __call__ method can handle batch inputs
- test_inference_batch_single_identical: check if the pipeline's __call__ method can handle single input
- test_float16_inference: check if the pipeline's __call__ method can handle float16 inputs
- test_to_device: check if the pipeline's __call__ method can handle different devices
"""
# Canonical parameters that are passed to `__call__` regardless
# of the type of pipeline. They are always optional and have common
# sense default values.
optional_params = frozenset(
[
"num_inference_steps",
"num_images_per_prompt",
"latents",
"output_type",
]
)
# this is modular specific: generator needs to be a intermediate input because it's mutable
intermediate_params = frozenset(
[
"generator",
]
)
def get_generator(self, seed):
device = torch_device if torch_device != "mps" else "cpu"
generator = torch.Generator(device).manual_seed(seed)
return generator
@property
def pipeline_class(self) -> Union[Callable, ModularPipeline]:
raise NotImplementedError(
"You need to set the attribute `pipeline_class = ClassNameOfPipeline` in the child test class. "
"See existing pipeline tests for reference."
)
@property
def repo(self) -> str:
raise NotImplementedError(
"You need to set the attribute `repo` in the child test class. See existing pipeline tests for reference."
)
@property
def pipeline_blocks_class(self) -> Union[Callable, ModularPipelineBlocks]:
raise NotImplementedError(
"You need to set the attribute `pipeline_blocks_class = ClassNameOfPipelineBlocks` in the child test class. "
"See existing pipeline tests for reference."
)
def get_pipeline(self):
raise NotImplementedError(
"You need to implement `get_pipeline(self)` in the child test class. "
"See existing pipeline tests for reference."
)
def get_dummy_inputs(self, device, seed=0):
raise NotImplementedError(
"You need to implement `get_dummy_inputs(self, device, seed)` in the child test class. "
"See existing pipeline tests for reference."
)
@property
def params(self) -> frozenset:
raise NotImplementedError(
"You need to set the attribute `params` in the child test class. "
"`params` are checked for if all values are present in `__call__`'s signature."
" You can set `params` using one of the common set of parameters defined in `pipeline_params.py`"
" e.g., `TEXT_TO_IMAGE_PARAMS` defines the common parameters used in text to "
"image pipelines, including prompts and prompt embedding overrides."
"If your pipeline's set of arguments has minor changes from one of the common sets of arguments, "
"do not make modifications to the existing common sets of arguments. I.e. a text to image pipeline "
"with non-configurable height and width arguments should set the attribute as "
"`params = TEXT_TO_IMAGE_PARAMS - {'height', 'width'}`. "
"See existing pipeline tests for reference."
)
@property
def batch_params(self) -> frozenset:
raise NotImplementedError(
"You need to set the attribute `batch_params` in the child test class. "
"`batch_params` are the parameters required to be batched when passed to the pipeline's "
"`__call__` method. `pipeline_params.py` provides some common sets of parameters such as "
"`TEXT_TO_IMAGE_BATCH_PARAMS`, `IMAGE_VARIATION_BATCH_PARAMS`, etc... If your pipeline's "
"set of batch arguments has minor changes from one of the common sets of batch arguments, "
"do not make modifications to the existing common sets of batch arguments. I.e. a text to "
"image pipeline `negative_prompt` is not batched should set the attribute as "
"`batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {'negative_prompt'}`. "
"See existing pipeline tests for reference."
)
def setUp(self):
# clean up the VRAM before each test
super().setUp()
torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test in case of CUDA runtime errors
super().tearDown()
torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)
def test_pipeline_call_signature(self):
pipe = self.get_pipeline()
input_parameters = pipe.blocks.input_names
intermediate_parameters = pipe.blocks.intermediate_input_names
optional_parameters = pipe.default_call_parameters
def _check_for_parameters(parameters, expected_parameters, param_type):
remaining_parameters = {param for param in parameters if param not in expected_parameters}
assert (
len(remaining_parameters) == 0
), f"Required {param_type} parameters not present: {remaining_parameters}"
_check_for_parameters(self.params, input_parameters, "input")
_check_for_parameters(self.intermediate_params, intermediate_parameters, "intermediate")
_check_for_parameters(self.optional_params, optional_parameters, "optional")
def test_inference_batch_consistent(self, batch_sizes=[2], batch_generator=True):
pipe = self.get_pipeline()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
inputs["generator"] = self.get_generator(0)
logger = logging.get_logger(pipe.__module__)
logger.setLevel(level=diffusers.logging.FATAL)
# prepare batched inputs
batched_inputs = []
for batch_size in batch_sizes:
batched_input = {}
batched_input.update(inputs)
for name in self.batch_params:
if name not in inputs:
continue
value = inputs[name]
batched_input[name] = batch_size * [value]
if batch_generator and "generator" in inputs:
batched_input["generator"] = [self.get_generator(i) for i in range(batch_size)]
if "batch_size" in inputs:
batched_input["batch_size"] = batch_size
batched_inputs.append(batched_input)
logger.setLevel(level=diffusers.logging.WARNING)
for batch_size, batched_input in zip(batch_sizes, batched_inputs):
output = pipe(**batched_input, output="images")
assert len(output) == batch_size, "Output is different from expected batch size"
def test_inference_batch_single_identical(
self,
batch_size=2,
expected_max_diff=1e-4,
):
pipe = self.get_pipeline()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
# Reset generator in case it is has been used in self.get_dummy_inputs
inputs["generator"] = self.get_generator(0)
logger = logging.get_logger(pipe.__module__)
logger.setLevel(level=diffusers.logging.FATAL)
# batchify inputs
batched_inputs = {}
batched_inputs.update(inputs)
for name in self.batch_params:
if name not in inputs:
continue
value = inputs[name]
batched_inputs[name] = batch_size * [value]
if "generator" in inputs:
batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
if "batch_size" in inputs:
batched_inputs["batch_size"] = batch_size
output = pipe(**inputs, output="images")
output_batch = pipe(**batched_inputs, output="images")
assert output_batch.shape[0] == batch_size
max_diff = np.abs(to_np(output_batch[0]) - to_np(output[0])).max()
assert max_diff < expected_max_diff, "Batch inference results different from single inference results"
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
@require_accelerator
def test_float16_inference(self, expected_max_diff=5e-2):
pipe = self.get_pipeline()
pipe.to(torch_device, torch.float32)
pipe.set_progress_bar_config(disable=None)
pipe_fp16 = self.get_pipeline()
pipe_fp16.to(torch_device, torch.float16)
pipe_fp16.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
# Reset generator in case it is used inside dummy inputs
if "generator" in inputs:
inputs["generator"] = self.get_generator(0)
output = pipe(**inputs, output="images")
fp16_inputs = self.get_dummy_inputs(torch_device)
# Reset generator in case it is used inside dummy inputs
if "generator" in fp16_inputs:
fp16_inputs["generator"] = self.get_generator(0)
output_fp16 = pipe_fp16(**fp16_inputs, output="images")
if isinstance(output, torch.Tensor):
output = output.cpu()
output_fp16 = output_fp16.cpu()
max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
assert max_diff < expected_max_diff, "FP16 inference is different from FP32 inference"
@require_accelerator
def test_to_device(self):
pipe = self.get_pipeline()
pipe.set_progress_bar_config(disable=None)
pipe.to("cpu")
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
assert all(device == "cpu" for device in model_devices), "All pipeline components are not on CPU"
pipe.to(torch_device)
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
assert all(
device == torch_device for device in model_devices
), "All pipeline components are not on accelerator device"
def test_inference_is_not_nan_cpu(self):
pipe = self.get_pipeline()
pipe.set_progress_bar_config(disable=None)
pipe.to("cpu")
output = pipe(**self.get_dummy_inputs("cpu"), output="images")
assert np.isnan(to_np(output)).sum() == 0, "CPU Inference returns NaN"
@require_accelerator
def test_inference_is_not_nan(self):
pipe = self.get_pipeline()
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
output = pipe(**self.get_dummy_inputs(torch_device), output="images")
assert np.isnan(to_np(output)).sum() == 0, "Accelerator Inference returns NaN"
def test_num_images_per_prompt(self):
pipe = self.get_pipeline()
if "num_images_per_prompt" not in pipe.blocks.input_names:
return
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
batch_sizes = [1, 2]
num_images_per_prompts = [1, 2]
for batch_size in batch_sizes:
for num_images_per_prompt in num_images_per_prompts:
inputs = self.get_dummy_inputs(torch_device)
for key in inputs.keys():
if key in self.batch_params:
inputs[key] = batch_size * [inputs[key]]
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output="images")
assert images.shape[0] == batch_size * num_images_per_prompt
@require_accelerator
def test_components_auto_cpu_offload(self):
base_pipe = self.get_pipeline().to(torch_device)
for component in base_pipe.components:
assert component.device == torch_device
cm = ComponentsManager()
cm.enable_auto_cpu_offload(device=torch_device)
offload_pipe = self.get_pipeline(components_manager=cm)

View File

@@ -155,7 +155,7 @@ class FluxPipelineFastTests(
# Outputs should be different here # Outputs should be different here
# For some reasons, they don't show large differences # For some reasons, they don't show large differences
self.assertGreater(max_diff, 1e-6, "Outputs should be different for different prompts.") assert max_diff > 1e-6
def test_fused_qkv_projections(self): def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -187,17 +187,14 @@ class FluxPipelineFastTests(
image = pipe(**inputs).images image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1] image_slice_disabled = image[0, -3:, -3:, -1]
self.assertTrue( assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), "Fusion of QKV projections shouldn't affect the outputs."
("Fusion of QKV projections shouldn't affect the outputs."),
) )
self.assertTrue( assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
("Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."),
) )
self.assertTrue( assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), "Original outputs should match when fused QKV projections are disabled."
("Original outputs should match when fused QKV projections are disabled."),
) )
def test_flux_image_output_shape(self): def test_flux_image_output_shape(self):
@@ -212,11 +209,7 @@ class FluxPipelineFastTests(
inputs.update({"height": height, "width": width}) inputs.update({"height": height, "width": width})
image = pipe(**inputs).images[0] image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape output_height, output_width, _ = image.shape
self.assertEqual( assert (output_height, output_width) == (expected_height, expected_width)
(output_height, output_width),
(expected_height, expected_width),
f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}",
)
def test_flux_true_cfg(self): def test_flux_true_cfg(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
@@ -227,9 +220,7 @@ class FluxPipelineFastTests(
inputs["negative_prompt"] = "bad quality" inputs["negative_prompt"] = "bad quality"
inputs["true_cfg_scale"] = 2.0 inputs["true_cfg_scale"] = 2.0
true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0] true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
self.assertFalse( assert not np.allclose(no_true_cfg_out, true_cfg_out)
np.allclose(no_true_cfg_out, true_cfg_out), "Outputs should be different when true_cfg_scale is set."
)
@nightly @nightly
@@ -278,17 +269,45 @@ class FluxPipelineSlowTests(unittest.TestCase):
image = pipe(**inputs).images[0] image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10] image_slice = image[0, :10, :10]
# fmt: off
expected_slice = np.array( expected_slice = np.array(
[0.3242, 0.3203, 0.3164, 0.3164, 0.3125, 0.3125, 0.3281, 0.3242, 0.3203, 0.3301, 0.3262, 0.3242, 0.3281, 0.3242, 0.3203, 0.3262, 0.3262, 0.3164, 0.3262, 0.3281, 0.3184, 0.3281, 0.3281, 0.3203, 0.3281, 0.3281, 0.3164, 0.3320, 0.3320, 0.3203], [
0.3242,
0.3203,
0.3164,
0.3164,
0.3125,
0.3125,
0.3281,
0.3242,
0.3203,
0.3301,
0.3262,
0.3242,
0.3281,
0.3242,
0.3203,
0.3262,
0.3262,
0.3164,
0.3262,
0.3281,
0.3184,
0.3281,
0.3281,
0.3203,
0.3281,
0.3281,
0.3164,
0.3320,
0.3320,
0.3203,
],
dtype=np.float32, dtype=np.float32,
) )
# fmt: on
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten()) max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
self.assertLess(
max_diff, 1e-4, f"Image slice is different from expected slice: {image_slice} != {expected_slice}" assert max_diff < 1e-4
)
@slow @slow
@@ -358,14 +377,42 @@ class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
image = pipe(**inputs).images[0] image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10] image_slice = image[0, :10, :10]
# fmt: off
expected_slice = np.array( expected_slice = np.array(
[0.1855, 0.1680, 0.1406, 0.1953, 0.1699, 0.1465, 0.2012, 0.1738, 0.1484, 0.2051, 0.1797, 0.1523, 0.2012, 0.1719, 0.1445, 0.2070, 0.1777, 0.1465, 0.2090, 0.1836, 0.1484, 0.2129, 0.1875, 0.1523, 0.2090, 0.1816, 0.1484, 0.2110, 0.1836, 0.1543], [
0.1855,
0.1680,
0.1406,
0.1953,
0.1699,
0.1465,
0.2012,
0.1738,
0.1484,
0.2051,
0.1797,
0.1523,
0.2012,
0.1719,
0.1445,
0.2070,
0.1777,
0.1465,
0.2090,
0.1836,
0.1484,
0.2129,
0.1875,
0.1523,
0.2090,
0.1816,
0.1484,
0.2110,
0.1836,
0.1543,
],
dtype=np.float32, dtype=np.float32,
) )
# fmt: on
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten()) max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
self.assertLess(
max_diff, 1e-4, f"Image slice is different from expected slice: {image_slice} != {expected_slice}" assert max_diff < 1e-4, f"{image_slice} != {expected_slice}"
)

View File

@@ -20,6 +20,12 @@ TEXT_TO_IMAGE_PARAMS = frozenset(
] ]
) )
TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
IMAGE_VARIATION_PARAMS = frozenset( IMAGE_VARIATION_PARAMS = frozenset(
[ [
"image", "image",
@@ -29,6 +35,8 @@ IMAGE_VARIATION_PARAMS = frozenset(
] ]
) )
IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])
TEXT_GUIDED_IMAGE_VARIATION_PARAMS = frozenset( TEXT_GUIDED_IMAGE_VARIATION_PARAMS = frozenset(
[ [
"prompt", "prompt",
@@ -42,6 +50,8 @@ TEXT_GUIDED_IMAGE_VARIATION_PARAMS = frozenset(
] ]
) )
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset( TEXT_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
[ [
# Text guided image variation with an image mask # Text guided image variation with an image mask
@@ -57,6 +67,8 @@ TEXT_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
] ]
) )
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])
IMAGE_INPAINTING_PARAMS = frozenset( IMAGE_INPAINTING_PARAMS = frozenset(
[ [
# image variation with an image mask # image variation with an image mask
@@ -68,6 +80,8 @@ IMAGE_INPAINTING_PARAMS = frozenset(
] ]
) )
IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])
IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset( IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
[ [
"example_image", "example_image",
@@ -79,12 +93,20 @@ IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
] ]
) )
UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"]) IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])
CLASS_CONDITIONED_IMAGE_GENERATION_PARAMS = frozenset(["class_labels"]) CLASS_CONDITIONED_IMAGE_GENERATION_PARAMS = frozenset(["class_labels"])
CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS = frozenset(["class_labels"]) CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS = frozenset(["class_labels"])
UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"])
UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])
UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])
TEXT_TO_AUDIO_PARAMS = frozenset( TEXT_TO_AUDIO_PARAMS = frozenset(
[ [
"prompt", "prompt",
@@ -97,38 +119,11 @@ TEXT_TO_AUDIO_PARAMS = frozenset(
] ]
) )
TOKENS_TO_AUDIO_GENERATION_PARAMS = frozenset(["input_tokens"])
UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
# image params
TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
# batch params
TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])
IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])
IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])
UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])
UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])
TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"]) TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
TOKENS_TO_AUDIO_GENERATION_PARAMS = frozenset(["input_tokens"])
TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS = frozenset(["input_tokens"]) TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS = frozenset(["input_tokens"])
VIDEO_TO_VIDEO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt", "video"])
# callback params
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"]) TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])
VIDEO_TO_VIDEO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt", "video"])

View File

@@ -873,11 +873,11 @@ class ExtendedSerializationTest(BaseBnb4BitSerializationTests):
@require_torch_version_greater("2.7.1") @require_torch_version_greater("2.7.1")
@require_bitsandbytes_version_greater("0.45.5") @require_bitsandbytes_version_greater("0.45.5")
class Bnb4BitCompileTests(QuantCompileTests, unittest.TestCase): class Bnb4BitCompileTests(QuantCompileTests):
@property @property
def quantization_config(self): def quantization_config(self):
return PipelineQuantizationConfig( return PipelineQuantizationConfig(
quant_backend="bitsandbytes_4bit", quant_backend="bitsandbytes_8bit",
quant_kwargs={ quant_kwargs={
"load_in_4bit": True, "load_in_4bit": True,
"bnb_4bit_quant_type": "nf4", "bnb_4bit_quant_type": "nf4",
@@ -888,7 +888,12 @@ class Bnb4BitCompileTests(QuantCompileTests, unittest.TestCase):
def test_torch_compile(self): def test_torch_compile(self):
torch._dynamo.config.capture_dynamic_output_shape_ops = True torch._dynamo.config.capture_dynamic_output_shape_ops = True
super().test_torch_compile() super()._test_torch_compile(quantization_config=self.quantization_config)
def test_torch_compile_with_cpu_offload(self):
super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
def test_torch_compile_with_group_offload_leaf(self): def test_torch_compile_with_group_offload_leaf(self):
super()._test_torch_compile_with_group_offload_leaf(use_stream=True) super()._test_torch_compile_with_group_offload_leaf(
quantization_config=self.quantization_config, use_stream=True
)

View File

@@ -838,7 +838,7 @@ class BaseBnb8bitSerializationTests(Base8bitTests):
@require_torch_version_greater_equal("2.6.0") @require_torch_version_greater_equal("2.6.0")
@require_bitsandbytes_version_greater("0.45.5") @require_bitsandbytes_version_greater("0.45.5")
class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase): class Bnb8BitCompileTests(QuantCompileTests):
@property @property
def quantization_config(self): def quantization_config(self):
return PipelineQuantizationConfig( return PipelineQuantizationConfig(
@@ -849,11 +849,15 @@ class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
def test_torch_compile(self): def test_torch_compile(self):
torch._dynamo.config.capture_dynamic_output_shape_ops = True torch._dynamo.config.capture_dynamic_output_shape_ops = True
super()._test_torch_compile(torch_dtype=torch.float16) super()._test_torch_compile(quantization_config=self.quantization_config, torch_dtype=torch.float16)
def test_torch_compile_with_cpu_offload(self): def test_torch_compile_with_cpu_offload(self):
super()._test_torch_compile_with_cpu_offload(torch_dtype=torch.float16) super()._test_torch_compile_with_cpu_offload(
quantization_config=self.quantization_config, torch_dtype=torch.float16
)
@pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.") @pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.")
def test_torch_compile_with_group_offload_leaf(self): def test_torch_compile_with_group_offload_leaf(self):
super()._test_torch_compile_with_group_offload_leaf(torch_dtype=torch.float16, use_stream=True) super()._test_torch_compile_with_group_offload_leaf(
quantization_config=self.quantization_config, torch_dtype=torch.float16, use_stream=True
)

View File

@@ -654,7 +654,7 @@ class WanVACEGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
@require_torch_version_greater("2.7.1") @require_torch_version_greater("2.7.1")
class GGUFCompileTests(QuantCompileTests, unittest.TestCase): class GGUFCompileTests(QuantCompileTests):
torch_dtype = torch.bfloat16 torch_dtype = torch.bfloat16
gguf_ckpt = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf" gguf_ckpt = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
@@ -662,6 +662,15 @@ class GGUFCompileTests(QuantCompileTests, unittest.TestCase):
def quantization_config(self): def quantization_config(self):
return GGUFQuantizationConfig(compute_dtype=self.torch_dtype) return GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
def test_torch_compile(self):
super()._test_torch_compile(quantization_config=self.quantization_config)
def test_torch_compile_with_cpu_offload(self):
super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
def test_torch_compile_with_group_offload_leaf(self):
super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)
def _init_pipeline(self, *args, **kwargs): def _init_pipeline(self, *args, **kwargs):
transformer = FluxTransformer2DModel.from_single_file( transformer = FluxTransformer2DModel.from_single_file(
self.gguf_ckpt, quantization_config=self.quantization_config, torch_dtype=self.torch_dtype self.gguf_ckpt, quantization_config=self.quantization_config, torch_dtype=self.torch_dtype

View File

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import gc import gc
import inspect import unittest
import torch import torch
@@ -23,7 +23,7 @@ from diffusers.utils.testing_utils import backend_empty_cache, require_torch_gpu
@require_torch_gpu @require_torch_gpu
@slow @slow
class QuantCompileTests: class QuantCompileTests(unittest.TestCase):
@property @property
def quantization_config(self): def quantization_config(self):
raise NotImplementedError( raise NotImplementedError(
@@ -50,26 +50,30 @@ class QuantCompileTests:
) )
return pipe return pipe
def _test_torch_compile(self, torch_dtype=torch.bfloat16): def _test_torch_compile(self, quantization_config, torch_dtype=torch.bfloat16):
pipe = self._init_pipeline(self.quantization_config, torch_dtype).to("cuda") pipe = self._init_pipeline(quantization_config, torch_dtype).to("cuda")
# `fullgraph=True` ensures no graph breaks # import to ensure fullgraph True
pipe.transformer.compile(fullgraph=True) pipe.transformer.compile(fullgraph=True)
# small resolutions to ensure speedy execution. for _ in range(2):
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256) # small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
def _test_torch_compile_with_cpu_offload(self, torch_dtype=torch.bfloat16): def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype=torch.bfloat16):
pipe = self._init_pipeline(self.quantization_config, torch_dtype) pipe = self._init_pipeline(quantization_config, torch_dtype)
pipe.enable_model_cpu_offload() pipe.enable_model_cpu_offload()
pipe.transformer.compile() pipe.transformer.compile()
# small resolutions to ensure speedy execution. for _ in range(2):
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256) # small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
def _test_torch_compile_with_group_offload_leaf(self, torch_dtype=torch.bfloat16, *, use_stream: bool = False): def _test_torch_compile_with_group_offload_leaf(
torch._dynamo.config.cache_size_limit = 1000 self, quantization_config, torch_dtype=torch.bfloat16, *, use_stream: bool = False
):
torch._dynamo.config.cache_size_limit = 10000
pipe = self._init_pipeline(self.quantization_config, torch_dtype) pipe = self._init_pipeline(quantization_config, torch_dtype)
group_offload_kwargs = { group_offload_kwargs = {
"onload_device": torch.device("cuda"), "onload_device": torch.device("cuda"),
"offload_device": torch.device("cpu"), "offload_device": torch.device("cpu"),
@@ -83,17 +87,6 @@ class QuantCompileTests:
if torch.device(component.device).type == "cpu": if torch.device(component.device).type == "cpu":
component.to("cuda") component.to("cuda")
# small resolutions to ensure speedy execution. for _ in range(2):
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256) # small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
def test_torch_compile(self):
self._test_torch_compile()
def test_torch_compile_with_cpu_offload(self):
self._test_torch_compile_with_cpu_offload()
def test_torch_compile_with_group_offload_leaf(self, use_stream=False):
for cls in inspect.getmro(self.__class__):
if "test_torch_compile_with_group_offload_leaf" in cls.__dict__ and cls is not QuantCompileTests:
return
self._test_torch_compile_with_group_offload_leaf(use_stream=use_stream)

View File

@@ -630,7 +630,7 @@ class TorchAoSerializationTest(unittest.TestCase):
@require_torchao_version_greater_or_equal("0.7.0") @require_torchao_version_greater_or_equal("0.7.0")
class TorchAoCompileTest(QuantCompileTests, unittest.TestCase): class TorchAoCompileTest(QuantCompileTests):
@property @property
def quantization_config(self): def quantization_config(self):
return PipelineQuantizationConfig( return PipelineQuantizationConfig(
@@ -639,15 +639,17 @@ class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
}, },
) )
def test_torch_compile(self):
super()._test_torch_compile(quantization_config=self.quantization_config)
@unittest.skip( @unittest.skip(
"Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work " "Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
"when compiling." "when compiling."
) )
def test_torch_compile_with_cpu_offload(self): def test_torch_compile_with_cpu_offload(self):
# RuntimeError: _apply(): Couldn't swap Linear.weight # RuntimeError: _apply(): Couldn't swap Linear.weight
super().test_torch_compile_with_cpu_offload() super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
@parameterized.expand([False, True])
@unittest.skip( @unittest.skip(
""" """
For `use_stream=False`: For `use_stream=False`:
@@ -657,7 +659,8 @@ class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
Using non-default stream requires ability to pin tensors. AQT does not seem to support this yet in TorchAO. Using non-default stream requires ability to pin tensors. AQT does not seem to support this yet in TorchAO.
""" """
) )
def test_torch_compile_with_group_offload_leaf(self, use_stream): @parameterized.expand([False, True])
def test_torch_compile_with_group_offload_leaf(self):
# For use_stream=False: # For use_stream=False:
# If we run group offloading without compilation, we will see: # If we run group offloading without compilation, we will see:
# RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match. # RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.
@@ -670,7 +673,7 @@ class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
# For use_stream=True: # For use_stream=True:
# NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), kwarg_types={} # NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), kwarg_types={}
super()._test_torch_compile_with_group_offload_leaf(use_stream=use_stream) super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners