Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-08 05:24:20 +08:00)

Compare commits: cnet-union...modular-te (28 commits)
| SHA1 |
|---|
| 3aabef5de4 |
| 39be374591 |
| 54e17f3084 |
| 80702d222d |
| 625cc8ede8 |
| a2a9e4eadb |
| 0998bd75ad |
| 5f560d05a2 |
| 4b7a9e9fa9 |
| d8fa2de36f |
| 4df2739a5e |
| d92855ddf0 |
| 0a5c90ed47 |
| aa14f090f8 |
| c5d6e0b537 |
| 39831599f1 |
| b73c738392 |
| 06fd427797 |
| 48a551251d |
| 0fa58127f8 |
| b165cf3742 |
| 6398fbc391 |
| 3c8b67b371 |
| 9feb946432 |
| c90352754a |
| 7a935a0bbe |
| 941b7fc084 |
| 76a62ac9cc |
.github/workflows/pr_modular_tests.yml (vendored, new file, +141 lines)

@@ -0,0 +1,141 @@
+name: Fast PR tests for Modular
+
+on:
+  pull_request:
+    branches: [main]
+    paths:
+      - "src/diffusers/modular_pipelines/**.py"
+      - "src/diffusers/models/modeling_utils.py"
+      - "src/diffusers/models/model_loading_utils.py"
+      - "src/diffusers/pipelines/pipeline_utils.py"
+      - "src/diffusers/pipeline_loading_utils.py"
+      - "src/diffusers/loaders/lora_base.py"
+      - "src/diffusers/loaders/lora_pipeline.py"
+      - "src/diffusers/loaders/peft.py"
+      - "tests/modular_pipelines/**.py"
+      - ".github/**.yml"
+      - "utils/**.py"
+      - "setup.py"
+  push:
+    branches:
+      - ci-*
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+  cancel-in-progress: true
+
+env:
+  DIFFUSERS_IS_CI: yes
+  HF_HUB_ENABLE_HF_TRANSFER: 1
+  OMP_NUM_THREADS: 4
+  MKL_NUM_THREADS: 4
+  PYTEST_TIMEOUT: 60
+
+jobs:
+  check_code_quality:
+    runs-on: ubuntu-22.04
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.10"
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install .[quality]
+      - name: Check quality
+        run: make quality
+      - name: Check if failure
+        if: ${{ failure() }}
+        run: |
+          echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY
+
+  check_repository_consistency:
+    needs: check_code_quality
+    runs-on: ubuntu-22.04
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.10"
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install .[quality]
+      - name: Check repo consistency
+        run: |
+          python utils/check_copies.py
+          python utils/check_dummies.py
+          python utils/check_support_list.py
+          make deps_table_check_updated
+      - name: Check if failure
+        if: ${{ failure() }}
+        run: |
+          echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
+
+  run_fast_tests:
+    needs: [check_code_quality, check_repository_consistency]
+    strategy:
+      fail-fast: false
+      matrix:
+        config:
+          - name: Fast PyTorch Modular Pipeline CPU tests
+            framework: pytorch_pipelines
+            runner: aws-highmemory-32-plus
+            image: diffusers/diffusers-pytorch-cpu
+            report: torch_cpu_modular_pipelines
+
+    name: ${{ matrix.config.name }}
+
+    runs-on:
+      group: ${{ matrix.config.runner }}
+
+    container:
+      image: ${{ matrix.config.image }}
+      options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
+
+    defaults:
+      run:
+        shell: bash
+
+    steps:
+      - name: Checkout diffusers
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 2
+
+      - name: Install dependencies
+        run: |
+          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+          python -m uv pip install -e [quality,test]
+          pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
+          pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
+
+      - name: Environment
+        run: |
+          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+          python utils/print_env.py
+
+      - name: Run fast PyTorch Pipeline CPU tests
+        if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
+        run: |
+          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+          python -m pytest -n 8 --max-worker-restart=0 --dist=loadfile \
+            -s -v -k "not Flax and not Onnx" \
+            --make-reports=tests_${{ matrix.config.report }} \
+            tests/modular_pipelines
+
+      - name: Failure short reports
+        if: ${{ failure() }}
+        run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
+
+      - name: Test suite reports artifacts
+        if: ${{ always() }}
+        uses: actions/upload-artifact@v4
+        with:
+          name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
+          path: reports
@@ -174,39 +174,36 @@ Feel free to open an issue if dynamic compilation doesn't work as expected for a

 ### Regional compilation

-[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by only compiling the *small and frequently-repeated block(s)* of a model - typically a transformer layer - and enables reusing compiled artifacts for every subsequent occurrence.
-For many diffusion architectures, this delivers the same runtime speedups as full-graph compilation and reduces compile time by 8–10x.
+[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by compiling **only the small, frequently-repeated block(s)** of a model, typically a Transformer layer, enabling reuse of compiled artifacts for every subsequent occurrence.
+For many diffusion architectures this delivers the *same* runtime speed-ups as full-graph compilation yet cuts compile time by **8–10x**.

-To make this effortless, [`ModelMixin`] exposes the [`ModelMixin.compile_repeated_blocks`] API, a helper that wraps `torch.compile` around any sub-modules you designate as repeatable:
+Use the [`~ModelMixin.compile_repeated_blocks`] method, a helper that wraps `torch.compile`, on any component such as the transformer model as shown below.

 ```py
 # pip install -U diffusers
 import torch
 from diffusers import StableDiffusionXLPipeline

-pipe = StableDiffusionXLPipeline.from_pretrained(
+pipeline = StableDiffusionXLPipeline.from_pretrained(
     "stabilityai/stable-diffusion-xl-base-1.0",
     torch_dtype=torch.float16,
 ).to("cuda")

-# Compile only the repeated Transformer layers inside the UNet
-pipe.unet.compile_repeated_blocks(fullgraph=True)
+# compile only the repeated transformer layers inside the UNet
+pipeline.unet.compile_repeated_blocks(fullgraph=True)
 ```

-To enable a new model with regional compilation, add a `_repeated_blocks` attribute to your model class containing the class names (as strings) of the blocks you want compiled:
+To enable regional compilation for a new model, add a `_repeated_blocks` attribute to a model class containing the class names (as strings) of the blocks you want to compile.

 ```py
 class MyUNet(ModelMixin):
     _repeated_blocks = ("Transformer2DModel",)  # ← compiled by default
 ```

-For more examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).
-
-**Relation to Accelerate compile_regions** There is also a separate API in [accelerate](https://huggingface.co/docs/accelerate/index) - [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78). It takes a fully automatic approach: it walks the module, picks candidate blocks, then compiles the remaining graph separately. That hands-off experience is handy for quick experiments, but it also leaves fewer knobs when you want to fine-tune which blocks are compiled or adjust compilation flags.
+> [!TIP]
+> For more regional compilation examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).
+
+There is also a [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78) method in [Accelerate](https://huggingface.co/docs/accelerate/index) that automatically selects candidate blocks in a model to compile. The remaining graph is compiled separately. This hands-off approach is useful for quick experiments, but it leaves fewer options for choosing which blocks to compile or adjusting compilation flags.

 ```py
 # pip install -U accelerate
@@ -219,8 +216,8 @@ pipeline = StableDiffusionXLPipeline.from_pretrained(
 ).to("cuda")
 pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph=True)
 ```
-`compile_repeated_blocks`, by contrast, is intentionally explicit. You list the repeated blocks once (via `_repeated_blocks`) and the helper compiles exactly those, nothing more. In practice this small dose of control hits a sweet spot for diffusion models: predictable behavior, easy reasoning about cache reuse, and still a one-liner for users.
+
+[`~ModelMixin.compile_repeated_blocks`] is intentionally explicit. List the blocks to repeat in `_repeated_blocks` and the helper only compiles those blocks. It offers predictable behavior and easy reasoning about cache reuse in one line of code.

 ### Graph breaks
@@ -296,3 +293,9 @@ An input is projected into three subspaces, represented by the projection matrices
 ```py
 pipeline.fuse_qkv_projections()
 ```
+
+## Resources
+
+- Read the [Presenting Flux Fast: Making Flux go brrr on H100s](https://pytorch.org/blog/presenting-flux-fast-making-flux-go-brrr-on-h100s/) blog post to learn more about how you can combine all of these optimizations with [TorchInductor](https://docs.pytorch.org/docs/stable/torch.compiler.html) and [AOTInductor](https://docs.pytorch.org/docs/stable/torch.compiler_aot_inductor.html) for a ~2.5x speedup using recipes from [flux-fast](https://github.com/huggingface/flux-fast).
+
+These recipes support AMD hardware and [Flux.1 Kontext Dev](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev).
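The Accelerate code block above is elided mid-example by the diff. For orientation, here is a hedged, self-contained sketch of what a complete `compile_regions` call might look like; it assumes a recent accelerate release exports the helper from `accelerate.utils` (per the source file linked above).

```py
# Hedged sketch, not the verbatim doc example.
import torch
from accelerate.utils import compile_regions
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")

# compile_regions walks the UNet, compiles repeated candidate blocks,
# and compiles the remaining graph separately
pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph=True)
```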
@@ -14,6 +14,9 @@ specific language governing permissions and limitations under the License.

 Optimizing models often involves trade-offs between [inference speed](./fp16) and [memory-usage](./memory). For instance, while [caching](./cache) can boost inference speed, it also increases memory consumption since it needs to store the outputs of intermediate attention layers. A more balanced optimization strategy combines quantizing a model, [torch.compile](./fp16#torchcompile) and various [offloading methods](./memory#offloading).

+> [!TIP]
+> Check the [torch.compile](./fp16#torchcompile) guide to learn more about compilation and how it can be applied here. For example, regional compilation can significantly reduce compilation time without giving up any speedups.
+
 For image generation, combining quantization and [model offloading](./memory#model-offloading) can often give the best trade-off between quality, speed, and memory. Group offloading is not as effective for image generation because it is usually not possible to *fully* overlap data transfer if the compute kernel finishes faster. This results in some communication overhead between the CPU and GPU.

 For video generation, combining quantization and [group-offloading](./memory#group-offloading) tends to be better because video models are more compute-bound.
@@ -25,7 +28,7 @@ The table below provides a comparison of optimization strategy combinations and
 | quantization | 32.602 | 14.9453 |
 | quantization, torch.compile | 25.847 | 14.9448 |
 | quantization, torch.compile, model CPU offloading | 32.312 | 12.2369 |
-<small>These results are benchmarked on Flux with a RTX 4090. The transformer and text_encoder components are quantized. Refer to the <a href="https://gist.github.com/sayakpaul/0db9d8eeeb3d2a0e5ed7cf0d9ca19b7d" benchmarking script</a> if you're interested in evaluating your own model.</small>
+<small>These results are benchmarked on Flux with a RTX 4090. The transformer and text_encoder components are quantized. Refer to the [benchmarking script](https://gist.github.com/sayakpaul/0db9d8eeeb3d2a0e5ed7cf0d9ca19b7d) if you're interested in evaluating your own model.</small>

 This guide will show you how to compile and offload a quantized model with [bitsandbytes](../quantization/bitsandbytes#torchcompile). Make sure you are using [PyTorch nightly](https://pytorch.org/get-started/locally/) and the latest version of bitsandbytes.
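As a sketch of the combination this guide describes (quantization, compilation, and offloading together), the following hedged example uses the diffusers bitsandbytes integration on Flux. The exact knobs are illustrative, not prescriptive, and compiling a bitsandbytes-quantized model may require the PyTorch nightly mentioned above.

```py
# Hedged sketch: 4-bit quantization + regional compilation + model CPU offloading.
import torch
from diffusers import BitsAndBytesConfig, FluxPipeline, FluxTransformer2DModel

quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
)
pipeline.enable_model_cpu_offload()                           # offloading
pipeline.transformer.compile_repeated_blocks(fullgraph=True)  # regional compilation

image = pipeline("a tiny astronaut hatching from an egg on the moon").images[0]
```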
@@ -1330,7 +1330,7 @@ def main(args):
                 # controlnet(s) inference
                 controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
                 controlnet_image = vae.encode(controlnet_image).latent_dist.sample()
-                controlnet_image = controlnet_image * vae.config.scaling_factor
+                controlnet_image = (controlnet_image - vae.config.shift_factor) * vae.config.scaling_factor

                 control_block_res_samples = controlnet(
                     hidden_states=noisy_model_input,
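The fix above matters for VAEs that define a `shift_factor` (the SD3-family autoencoder does): latents must be recentered before scaling, with the inverse applied at decode time. A hedged numeric illustration, using values in the style of the SD3 VAE config (not authoritative):

```py
import torch

scaling_factor, shift_factor = 1.5305, 0.0609  # illustrative config values

raw_latents = torch.randn(1, 16, 64, 64)       # as sampled from vae.encode(...)
normalized = (raw_latents - shift_factor) * scaling_factor

# the inverse transform, applied before vae.decode(...)
recovered = normalized / scaling_factor + shift_factor
assert torch.allclose(recovered, raw_latents, atol=1e-6)
```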
@@ -1,10 +1,10 @@
 # This file was autogenerated by uv via the following command:
 #    uv pip compile requirements.in -o requirements.txt
-aiohappyeyeballs==2.4.3
+aiohappyeyeballs==2.6.1
     # via aiohttp
-aiohttp==3.10.10
+aiohttp==3.12.14
     # via -r requirements.in
-aiosignal==1.3.1
+aiosignal==1.4.0
     # via aiohttp
 annotated-types==0.7.0
     # via pydantic
@@ -29,7 +29,6 @@ filelock==3.16.1
     #   huggingface-hub
     #   torch
     #   transformers
-    #   triton
 frozenlist==1.5.0
     # via
     #   aiohttp
@@ -111,7 +110,9 @@ prometheus-client==0.21.0
 prometheus-fastapi-instrumentator==7.0.0
     # via -r requirements.in
 propcache==0.2.0
-    # via yarl
+    # via
+    #   aiohttp
+    #   yarl
 py-consul==1.5.3
     # via -r requirements.in
 pydantic==2.9.2
@@ -155,7 +156,9 @@ triton==3.3.0
     # via torch
 typing-extensions==4.12.2
     # via
+    #   aiosignal
     #   anyio
     #   exceptiongroup
     #   fastapi
     #   huggingface-hub
+    #   multidict
@@ -168,5 +171,5 @@ urllib3==2.5.0
     # via requests
 uvicorn==0.32.0
     # via -r requirements.in
-yarl==1.16.0
+yarl==1.18.3
     # via aiohttp
@@ -763,4 +763,7 @@ class LegacyConfigMixin(ConfigMixin):
         # resolve remapping
         remapped_class = _fetch_remapped_cls_from_config(config, cls)

-        return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
+        if remapped_class is cls:
+            return super(LegacyConfigMixin, remapped_class).from_config(config, return_unused_kwargs, **kwargs)
+        else:
+            return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
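The new branch guards against infinite recursion: when the config remaps to the class itself, calling `remapped_class.from_config` would re-enter `LegacyConfigMixin.from_config`, so the parent implementation is invoked instead. A hedged, stripped-down sketch of the pattern (`resolve_remap` is a hypothetical stand-in for `_fetch_remapped_cls_from_config`):

```py
def resolve_remap(config, cls):
    # hypothetical stand-in for diffusers' _fetch_remapped_cls_from_config
    return config.get("remap", cls)

class ConfigMixinSketch:
    @classmethod
    def from_config(cls, config):
        return cls()

class LegacySketch(ConfigMixinSketch):
    @classmethod
    def from_config(cls, config):
        remapped = resolve_remap(config, cls)
        if remapped is cls:
            # calling remapped.from_config here would re-enter this method
            # forever; delegate to the parent implementation instead
            return super(LegacySketch, remapped).from_config(config)
        return remapped.from_config(config)

assert isinstance(LegacySketch.from_config({}), LegacySketch)
```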
@@ -24,6 +24,7 @@ from typing_extensions import Self
 from .. import __version__
 from ..quantizers import DiffusersAutoQuantizer
 from ..utils import deprecate, is_accelerate_available, logging
+from ..utils.torch_utils import empty_device_cache
 from .single_file_utils import (
     SingleFileComponentError,
     convert_animatediff_checkpoint_to_diffusers,
@@ -430,6 +431,7 @@ class FromOriginalModelMixin:
                 keep_in_fp32_modules=keep_in_fp32_modules,
                 unexpected_keys=unexpected_keys,
             )
+            empty_device_cache()
         else:
             _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
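Several hunks in this diff add the same pattern: after weights are materialized onto an accelerator, `empty_device_cache` releases the transient allocator blocks. A hedged usage sketch (the import path is confirmed by the diff itself; the helper is a device-agnostic equivalent of `torch.cuda.empty_cache`):

```py
import torch
from diffusers.utils.torch_utils import empty_device_cache

if torch.cuda.is_available():
    scratch = torch.empty(64 * 1024 * 1024, device="cuda")  # large, short-lived buffer
    del scratch
    # return cached blocks to the driver instead of holding them in the pool
    empty_device_cache()
```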
@@ -46,6 +46,7 @@ from ..utils import (
 )
 from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
 from ..utils.hub_utils import _get_model_file
+from ..utils.torch_utils import empty_device_cache


 if is_transformers_available():
@@ -1689,6 +1690,7 @@ def create_diffusers_clip_model_from_ldm(

     if is_accelerate_available():
         load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+        empty_device_cache()
     else:
         model.load_state_dict(diffusers_format_checkpoint, strict=False)
@@ -2148,6 +2150,7 @@ def create_diffusers_t5_model_from_checkpoint(

     if is_accelerate_available():
         load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+        empty_device_cache()
     else:
         model.load_state_dict(diffusers_format_checkpoint)
@@ -18,11 +18,8 @@ from ..models.embeddings import (
     MultiIPAdapterImageProjection,
 )
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
-from ..utils import (
-    is_accelerate_available,
-    is_torch_version,
-    logging,
-)
+from ..utils import is_accelerate_available, is_torch_version, logging
+from ..utils.torch_utils import empty_device_cache


 if is_accelerate_available():
@@ -84,6 +81,7 @@ class FluxTransformer2DLoadersMixin:
         else:
             device_map = {"": self.device}
             load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
+            empty_device_cache()

         return image_projection
@@ -158,6 +156,8 @@ class FluxTransformer2DLoadersMixin:

             key_id += 1

+        empty_device_cache()
+
         return attn_procs

     def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
@@ -18,6 +18,7 @@ from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
 from ..models.embeddings import IPAdapterTimeImageProjection
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
 from ..utils import is_accelerate_available, is_torch_version, logging
+from ..utils.torch_utils import empty_device_cache


 logger = logging.get_logger(__name__)
@@ -80,6 +81,8 @@ class SD3Transformer2DLoadersMixin:
                 attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
             )

+        empty_device_cache()
+
         return attn_procs

     def _convert_ip_adapter_image_proj_to_diffusers(
@@ -147,6 +150,7 @@ class SD3Transformer2DLoadersMixin:
         else:
             device_map = {"": self.device}
             load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
+            empty_device_cache()

         return image_proj
@@ -43,6 +43,7 @@ from ..utils import (
     is_torch_version,
     logging,
 )
+from ..utils.torch_utils import empty_device_cache
 from .lora_base import _func_optionally_disable_offloading
 from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
 from .utils import AttnProcsLayers
@@ -753,6 +754,7 @@ class UNet2DConditionLoadersMixin:
         else:
             device_map = {"": self.device}
             load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
+            empty_device_cache()

         return image_projection
@@ -850,6 +852,8 @@ class UNet2DConditionLoadersMixin:

             key_id += 2

+        empty_device_cache()
+
         return attn_procs

     def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
@@ -752,7 +752,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             condition = self.controlnet_cond_embedding(cond)
             feat_seq = torch.mean(condition, dim=(2, 3))
             feat_seq = feat_seq + self.task_embedding[control_idx]
-            if from_multi:
+            if from_multi or len(control_type_idx) == 1:
                 inputs.append(feat_seq.unsqueeze(1))
                 condition_list.append(condition)
             else:
@@ -772,7 +772,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale):
             alpha = self.spatial_ch_projs(x[:, idx])
             alpha = alpha.unsqueeze(-1).unsqueeze(-1)
-            if from_multi:
+            if from_multi or len(control_type_idx) == 1:
                 controlnet_cond_fuser += condition + alpha
             else:
                 controlnet_cond_fuser += condition + alpha * scale
@@ -819,11 +819,11 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         # 6. scaling
         if guess_mode and not self.config.global_pool_conditions:
             scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0
-            if from_multi:
+            if from_multi or len(control_type_idx) == 1:
                 scales = scales * conditioning_scale[0]
             down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
             mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
-        elif from_multi:
+        elif from_multi or len(control_type_idx) == 1:
             down_block_res_samples = [sample * conditioning_scale[0] for sample in down_block_res_samples]
             mid_block_res_sample = mid_block_res_sample * conditioning_scale[0]
@@ -16,9 +16,10 @@

 import importlib
 import inspect
+import math
 import os
 from array import array
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Union
 from zipfile import is_zipfile
@@ -38,6 +39,7 @@ from ..utils import (
     _get_model_file,
     deprecate,
     is_accelerate_available,
+    is_accelerate_version,
     is_gguf_available,
     is_torch_available,
     is_torch_version,
@@ -252,6 +254,10 @@ def load_model_dict_into_meta(
             param = param.to(dtype)
             set_module_kwargs["dtype"] = dtype

+        if is_accelerate_version(">", "1.8.1"):
+            set_module_kwargs["non_blocking"] = True
+            set_module_kwargs["clear_cache"] = False
+
         # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
         # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
         # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
@@ -520,3 +526,60 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
             parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights

     return parsed_parameters
+
+
+def _find_mismatched_keys(state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes):
+    mismatched_keys = []
+    if not ignore_mismatched_sizes:
+        return mismatched_keys
+    for checkpoint_key in loaded_keys:
+        model_key = checkpoint_key
+        # If the checkpoint is sharded, we may not have the key here.
+        if checkpoint_key not in state_dict:
+            continue
+
+        if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
+            mismatched_keys.append(
+                (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+            )
+            del state_dict[checkpoint_key]
+    return mismatched_keys
+
+
+def _expand_device_map(device_map, param_names):
+    """
+    Expand a device map to return the correspondence parameter name to device.
+    """
+    new_device_map = {}
+    for module, device in device_map.items():
+        new_device_map.update(
+            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
+        )
+    return new_device_map
+
+
+# Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
+def _caching_allocator_warmup(model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype) -> None:
+    """
+    This function warms up the caching allocator based on the size of the model tensors that will reside on each
+    device. It allows having one large call to malloc, instead of recursively calling it later when loading the
+    model, which is actually the loading speed bottleneck. Calling this function allows to cut the model loading
+    time by a very large margin.
+    """
+    # Remove disk and cpu devices, and cast to proper torch.device
+    accelerator_device_map = {
+        param: torch.device(device)
+        for param, device in expanded_device_map.items()
+        if str(device) not in ["cpu", "disk"]
+    }
+    parameter_count = defaultdict(lambda: 0)
+    for param_name, device in accelerator_device_map.items():
+        try:
+            param = model.get_parameter(param_name)
+        except AttributeError:
+            param = model.get_buffer(param_name)
+        parameter_count[device] += math.prod(param.shape)
+
+    # This will kick off the caching allocator to avoid having to Malloc afterwards
+    for device, param_count in parameter_count.items():
+        _ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
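A hedged illustration of the two new helpers: `_expand_device_map` turns a module-level map into a per-parameter map, and the warmup then makes a single large allocation per device so later parameter copies reuse the cached pool. The parameter names below are made up.

```py
import torch

device_map = {"": "cuda:0"}  # module-level: everything on one GPU
param_names = ["conv_in.weight", "conv_in.bias", "blocks.0.to_q.weight"]

# what _expand_device_map computes for this input: one entry per parameter
expanded = {
    p: device
    for module, device in device_map.items()
    for p in param_names
    if p == module or p.startswith(f"{module}.") or module == ""
}
assert expanded == {p: "cuda:0" for p in param_names}

# the warmup then allocates one big empty tensor per device; the caching
# allocator keeps that memory around, so the real per-tensor copies avoid
# repeated calls into device malloc
if torch.cuda.is_available():
    total_elements = 1_000_000  # stand-in for the summed parameter counts
    _ = torch.empty(total_elements, dtype=torch.float16, device="cuda:0", requires_grad=False)
```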
@@ -62,10 +62,14 @@ from ..utils.hub_utils import (
     load_or_create_model_card,
     populate_model_card,
 )
+from ..utils.torch_utils import empty_device_cache
 from .model_loading_utils import (
+    _caching_allocator_warmup,
     _determine_device_map,
+    _expand_device_map,
     _fetch_index_file,
     _fetch_index_file_legacy,
+    _find_mismatched_keys,
     _load_state_dict_into_model,
     load_model_dict_into_meta,
     load_state_dict,
@@ -1469,11 +1473,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             for pat in cls._keys_to_ignore_on_load_unexpected:
                 unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

-        mismatched_keys = []
-
-        assign_to_params_buffers = None
-        error_msgs = []
-
         # Deal with offload
         if device_map is not None and "disk" in device_map.values():
             if offload_folder is None:
@@ -1482,18 +1481,27 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                     " for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
                     " offers the weights in this format."
                 )
-            if offload_folder is not None:
+            else:
                 os.makedirs(offload_folder, exist_ok=True)
             if offload_state_dict is None:
                 offload_state_dict = True

+        # If a device map has been used, we can speedup the load time by warming up the device caching allocator.
+        # If we don't warmup, each tensor allocation on device calls to the allocator for memory (effectively, a
+        # lot of individual calls to device malloc). We can, however, preallocate the memory required by the
+        # tensors using their expected shape and not performing any initialization of the memory (empty data).
+        # When the actual device allocations happen, the allocator already has a pool of unused device memory
+        # that it can re-use for faster loading of the model.
+        # TODO: add support for warmup with hf_quantizer
+        if device_map is not None and hf_quantizer is None:
+            expanded_device_map = _expand_device_map(device_map, expected_keys)
+            _caching_allocator_warmup(model, expanded_device_map, dtype)
+
         offload_index = {} if device_map is not None and "disk" in device_map.values() else None
-        state_dict_folder, state_dict_index = None, None
         if offload_state_dict:
             state_dict_folder = tempfile.mkdtemp()
             state_dict_index = {}
+        else:
+            state_dict_folder = None
+            state_dict_index = None

         if state_dict is not None:
             # load_state_dict will manage the case where we pass a dict instead of a file
@@ -1503,38 +1511,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         if len(resolved_model_file) > 1:
             resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")

+        mismatched_keys = []
+        assign_to_params_buffers = None
+        error_msgs = []
+
         for shard_file in resolved_model_file:
             state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)

-            def _find_mismatched_keys(
-                state_dict,
-                model_state_dict,
-                loaded_keys,
-                ignore_mismatched_sizes,
-            ):
-                mismatched_keys = []
-                if ignore_mismatched_sizes:
-                    for checkpoint_key in loaded_keys:
-                        model_key = checkpoint_key
-                        # If the checkpoint is sharded, we may not have the key here.
-                        if checkpoint_key not in state_dict:
-                            continue
-
-                        if (
-                            model_key in model_state_dict
-                            and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
-                        ):
-                            mismatched_keys.append(
-                                (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
-                            )
-                            del state_dict[checkpoint_key]
-                    return mismatched_keys
-
             mismatched_keys += _find_mismatched_keys(
-                state_dict,
-                model_state_dict,
-                loaded_keys,
-                ignore_mismatched_sizes,
+                state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes
             )

             if low_cpu_mem_usage:
@@ -1554,9 +1538,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             else:
                 if assign_to_params_buffers is None:
                     assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)

                 error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)

+        empty_device_cache()
+
         if offload_index is not None and len(offload_index) > 0:
             save_offload_index(offload_index, offload_folder)
             offload_index = None
@@ -1892,4 +1877,9 @@ class LegacyModelMixin(ModelMixin):
         # resolve remapping
         remapped_class = _fetch_remapped_cls_from_config(config, cls)

-        return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
+        if remapped_class is cls:
+            return super(LegacyModelMixin, remapped_class).from_pretrained(
+                pretrained_model_name_or_path, **kwargs_copy
+            )
+        else:
+            return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
@@ -187,9 +187,15 @@ class CosmosAttnProcessor2_0:
         key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)

         # 4. Prepare for GQA
-        query_idx = torch.tensor(query.size(3), device=query.device)
-        key_idx = torch.tensor(key.size(3), device=key.device)
-        value_idx = torch.tensor(value.size(3), device=value.device)
+        if torch.onnx.is_in_onnx_export():
+            query_idx = torch.tensor(query.size(3), device=query.device)
+            key_idx = torch.tensor(key.size(3), device=key.device)
+            value_idx = torch.tensor(value.size(3), device=value.device)
+        else:
+            query_idx = query.size(3)
+            key_idx = key.size(3)
+            value_idx = value.size(3)
         key = key.repeat_interleave(query_idx // key_idx, dim=3)
         value = value.repeat_interleave(query_idx // value_idx, dim=3)
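A hedged sketch of the GQA expansion the hunk guards: with heads on dim 3 (the layout the processor uses; the exact meaning of the other dims is assumed here for illustration), each key/value head is repeated `q_heads // kv_heads` times. Wrapping the sizes in tensors only matters for ONNX export, where plain Python ints would be baked into the exported graph as constants.

```py
import torch

query = torch.randn(2, 128, 64, 8)  # 8 query heads on dim 3 (illustrative layout)
key = torch.randn(2, 128, 64, 2)    # 2 kv heads shared across the 8 query heads
value = torch.randn(2, 128, 64, 2)

key = key.repeat_interleave(query.size(3) // key.size(3), dim=3)        # each kv head repeated 4x
value = value.repeat_interleave(query.size(3) // value.size(3), dim=3)
assert key.shape == query.shape == value.shape
```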
@@ -490,6 +490,7 @@ class FluxTransformer2DModel(
                     encoder_hidden_states,
                     temb,
                     image_rotary_emb,
+                    joint_attention_kwargs,
                 )

             else:
@@ -521,6 +522,7 @@ class FluxTransformer2DModel(
                     encoder_hidden_states,
                     temb,
                     image_rotary_emb,
+                    joint_attention_kwargs,
                 )

             else:
@@ -479,6 +479,22 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):

         return list(combined_dict.values())

+    @property
+    def input_names(self) -> List[str]:
+        return [input_param.name for input_param in self.inputs]
+
+    @property
+    def intermediate_input_names(self) -> List[str]:
+        return [input_param.name for input_param in self.intermediate_inputs]
+
+    @property
+    def intermediate_output_names(self) -> List[str]:
+        return [output_param.name for output_param in self.intermediate_outputs]
+
+    @property
+    def output_names(self) -> List[str]:
+        return [output_param.name for output_param in self.outputs]
+

 class PipelineBlock(ModularPipelineBlocks):
     """
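A hedged usage sketch for the new convenience properties; `MyBlocks` is a hypothetical stand-in for any concrete `ModularPipelineBlocks` subclass that declares inputs and outputs.

```py
blocks = MyBlocks()  # hypothetical ModularPipelineBlocks subclass

print(blocks.input_names)                # plain names of user-facing inputs
print(blocks.intermediate_input_names)   # names consumed from earlier blocks
print(blocks.intermediate_output_names)  # names produced for later blocks
print(blocks.output_names)               # names of the final outputs
```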
@@ -2825,3 +2841,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
             type_hint=type_hint,
             **spec_dict,
         )
+
+    def set_progress_bar_config(self, **kwargs):
+        for sub_block_name, sub_block in self.blocks.sub_blocks.items():
+            if hasattr(sub_block, "set_progress_bar_config"):
+                sub_block.set_progress_bar_config(**kwargs)
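A hedged one-liner showing why the method exists: it fans the usual diffusers progress-bar options out to every sub-block (`pipe` stands in for an assembled `ModularPipeline`).

```py
# e.g. silence all per-block progress bars at once
pipe.set_progress_bar_config(disable=True)
```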
@@ -744,8 +744,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
         timestep=None,
         is_strength_max=True,
         add_noise=True,
-        return_noise=False,
-        return_image_latents=False,
     ):
         shape = (
             batch_size,
@@ -768,7 +766,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
         if image.shape[1] == 4:
             image_latents = image.to(device=device, dtype=dtype)
             image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
-        elif return_image_latents or (latents is None and not is_strength_max):
+        elif latents is None and not is_strength_max:
             image = image.to(device=device, dtype=dtype)
             image_latents = self._encode_vae_image(components, image=image, generator=generator)
             image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
@@ -786,13 +784,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
             noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
             latents = image_latents.to(device)

-        outputs = (latents,)
-
-        if return_noise:
-            outputs += (noise,)
-
-        if return_image_latents:
-            outputs += (image_latents,)
+        outputs = (latents, noise, image_latents)

         return outputs
@@ -864,7 +856,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
         block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor
         block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor

-        block_state.latents, block_state.noise = self.prepare_latents_inpaint(
+        block_state.latents, block_state.noise, block_state.image_latents = self.prepare_latents_inpaint(
             components,
             block_state.batch_size * block_state.num_images_per_prompt,
             components.num_channels_latents,
@@ -878,8 +870,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
             timestep=block_state.latent_timestep,
             is_strength_max=block_state.is_strength_max,
             add_noise=block_state.add_noise,
-            return_noise=True,
-            return_image_latents=False,
         )

         # 7. Prepare mask latent variables
@@ -18,7 +18,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import numpy as np
 import PIL.Image
 import torch
-import torch.nn.functional as F
 from transformers import (
     CLIPImageProcessor,
     CLIPTextModel,
@@ -35,7 +34,13 @@ from ...loaders import (
     StableDiffusionXLLoraLoaderMixin,
     TextualInversionLoaderMixin,
 )
-from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, ImageProjection, UNet2DConditionModel
+from ...models import (
+    AutoencoderKL,
+    ControlNetUnionModel,
+    ImageProjection,
+    MultiControlNetUnionModel,
+    UNet2DConditionModel,
+)
 from ...models.attention_processor import (
     AttnProcessor2_0,
     XFormersAttnProcessor,
@@ -230,7 +235,9 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
         tokenizer: CLIPTokenizer,
         tokenizer_2: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        controlnet: ControlNetUnionModel,
+        controlnet: Union[
+            ControlNetUnionModel, List[ControlNetUnionModel], Tuple[ControlNetUnionModel], MultiControlNetUnionModel
+        ],
         scheduler: KarrasDiffusionSchedulers,
         requires_aesthetics_score: bool = False,
         force_zeros_for_empty_prompt: bool = True,
@@ -240,8 +247,8 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
     ):
         super().__init__()

-        if not isinstance(controlnet, ControlNetUnionModel):
-            raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.")
+        if isinstance(controlnet, (list, tuple)):
+            controlnet = MultiControlNetUnionModel(controlnet)

         self.register_modules(
             vae=vae,
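With this change a list or tuple of ControlNetUnion models is wrapped into a `MultiControlNetUnionModel` automatically. A hedged construction sketch (the checkpoint names are illustrative):

```py
import torch
from diffusers import ControlNetUnionModel, StableDiffusionXLControlNetUnionInpaintPipeline

controlnet = ControlNetUnionModel.from_pretrained(
    "xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetUnionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=[controlnet, controlnet],  # list -> MultiControlNetUnionModel
    torch_dtype=torch.float16,
)
```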
@@ -660,6 +667,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
         controlnet_conditioning_scale=1.0,
         control_guidance_start=0.0,
         control_guidance_end=1.0,
+        control_mode=None,
         callback_on_step_end_tensor_inputs=None,
         padding_mask_crop=None,
     ):
@@ -747,25 +755,34 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
                 "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
             )

-        # Check `image`
-        is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
-            self.controlnet, torch._dynamo.eval_frame.OptimizedModule
-        )
-        if (
-            isinstance(self.controlnet, ControlNetModel)
-            or is_compiled
-            and isinstance(self.controlnet._orig_mod, ControlNetModel)
-        ):
-            self.check_image(image, prompt, prompt_embeds)
-        elif (
-            isinstance(self.controlnet, ControlNetUnionModel)
-            or is_compiled
-            and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
-        ):
-            self.check_image(image, prompt, prompt_embeds)
-        else:
-            assert False
+        # `prompt` needs more sophisticated handling when there are multiple
+        # conditionings.
+        if isinstance(self.controlnet, MultiControlNetUnionModel):
+            if isinstance(prompt, list):
+                logger.warning(
+                    f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
+                    " prompts. The conditionings will be fixed across the prompts."
+                )
+
+        # Check `image`
+        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+        if isinstance(controlnet, ControlNetUnionModel):
+            for image_ in image:
+                self.check_image(image_, prompt, prompt_embeds)
+        elif isinstance(controlnet, MultiControlNetUnionModel):
+            if not isinstance(image, list):
+                raise TypeError("For multiple controlnets: `image` must be type `list`")
+            elif not all(isinstance(i, list) for i in image):
+                raise ValueError("For multiple controlnets: elements of `image` must be list of conditionings.")
+            elif len(image) != len(self.controlnet.nets):
+                raise ValueError(
+                    f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
+                )
+
+            for images_ in image:
+                for image_ in images_:
+                    self.check_image(image_, prompt, prompt_embeds)

         if not isinstance(control_guidance_start, (tuple, list)):
             control_guidance_start = [control_guidance_start]
@@ -778,6 +795,12 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
                 f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
             )

+        if isinstance(controlnet, MultiControlNetUnionModel):
+            if len(control_guidance_start) != len(self.controlnet.nets):
+                raise ValueError(
+                    f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+                )
+
         for start, end in zip(control_guidance_start, control_guidance_end):
             if start >= end:
                 raise ValueError(
@@ -788,6 +811,28 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
             if end > 1.0:
                 raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")

+        # Check `control_mode`
+        if isinstance(controlnet, ControlNetUnionModel):
+            if max(control_mode) >= controlnet.config.num_control_type:
+                raise ValueError(f"control_mode: must be lower than {controlnet.config.num_control_type}.")
+        elif isinstance(controlnet, MultiControlNetUnionModel):
+            for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets):
+                if max(_control_mode) >= _controlnet.config.num_control_type:
+                    raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.")
+
+        # Equal number of `image` and `control_mode` elements
+        if isinstance(controlnet, ControlNetUnionModel):
+            if len(image) != len(control_mode):
+                raise ValueError("Expected len(control_image) == len(control_mode)")
+        elif isinstance(controlnet, MultiControlNetUnionModel):
+            if not all(isinstance(i, list) for i in control_mode):
+                raise ValueError(
+                    "For multiple controlnets: elements of control_mode must be lists representing conditioning mode."
+                )
+            elif sum(len(x) for x in image) != sum(len(x) for x in control_mode):
+                raise ValueError("Expected len(control_image) == len(control_mode)")
+
         if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
             raise ValueError(
                 "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
@@ -1117,7 +1162,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
         prompt_2: Optional[Union[str, List[str]]] = None,
         image: PipelineImageInput = None,
         mask_image: PipelineImageInput = None,
-        control_image: PipelineImageInput = None,
+        control_image: Union[PipelineImageInput, List[PipelineImageInput]] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         padding_mask_crop: Optional[int] = None,
@@ -1145,7 +1190,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
         guess_mode: bool = False,
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
-        control_mode: Optional[Union[int, List[int]]] = None,
+        control_mode: Optional[Union[int, List[int], List[List[int]]]] = None,
         guidance_rescale: float = 0.0,
         original_size: Tuple[int, int] = None,
         crops_coords_top_left: Tuple[int, int] = (0, 0),
@@ -1177,6 +1222,13 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
                 repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
                 to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
                 instead of 3, so the expected shape would be `(B, H, W, 1)`.
+            control_image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*):
+                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
+                specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
+                as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
+                width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
+                images must be passed as a list such that each element of the list can be correctly batched for input
+                to a single ControlNet.
             height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -1269,6 +1321,22 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+                to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
+                the corresponding scale as a list.
+            guess_mode (`bool`, *optional*, defaults to `False`):
+                The ControlNet encoder tries to recognize the content of the input image even if you remove all
+                prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
+            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+                The percentage of total steps at which the ControlNet starts applying.
+            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+                The percentage of total steps at which the ControlNet stops applying.
+            control_mode (`int` or `List[int]` or `List[List[int]]`, *optional*):
+                The control condition types for the ControlNet. See the ControlNet's model card for information on
+                the available control modes. If multiple ControlNets are specified in `init`, `control_mode` should
+                be a list where each ControlNet has its corresponding control mode list, reflecting the order of
+                conditions in `control_image`.
             original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
                 If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
                 `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
@@ -1333,22 +1401,6 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(

         controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

-        # align format for control guidance
-        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
-            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
-        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
-            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
-
-        # # 0.0 Default height and width to unet
-        # height = height or self.unet.config.sample_size * self.vae_scale_factor
-        # width = width or self.unet.config.sample_size * self.vae_scale_factor
-
-        # 0.1 align format for control guidance
-        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
-            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
-        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
-            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
-
         if not isinstance(control_image, list):
             control_image = [control_image]
         else:
@@ -1357,40 +1409,59 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
         if not isinstance(control_mode, list):
             control_mode = [control_mode]

-        if len(control_image) != len(control_mode):
-            raise ValueError("Expected len(control_image) == len(control_type)")
+        if isinstance(controlnet, MultiControlNetUnionModel):
+            control_image = [[item] for item in control_image]
+            control_mode = [[item] for item in control_mode]

-        num_control_type = controlnet.config.num_control_type
-
-        # 1. Check inputs
-        control_type = [0 for _ in range(num_control_type)]
-        for _image, control_idx in zip(control_image, control_mode):
-            control_type[control_idx] = 1
-            self.check_inputs(
-                prompt,
-                prompt_2,
-                _image,
-                mask_image,
-                strength,
-                num_inference_steps,
-                callback_steps,
-                output_type,
-                negative_prompt,
-                negative_prompt_2,
-                prompt_embeds,
-                negative_prompt_embeds,
-                ip_adapter_image,
-                ip_adapter_image_embeds,
-                pooled_prompt_embeds,
-                negative_pooled_prompt_embeds,
-                controlnet_conditioning_scale,
-                control_guidance_start,
-                control_guidance_end,
-                callback_on_step_end_tensor_inputs,
-                padding_mask_crop,
-            )
-
-        control_type = torch.Tensor(control_type)
+        # align format for control guidance
+        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
+            control_guidance_start, control_guidance_end = (
+                mult * [control_guidance_start],
+                mult * [control_guidance_end],
+            )
+
+        if isinstance(controlnet_conditioning_scale, float):
+            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
+            controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult
+
+        # 1. Check inputs
+        self.check_inputs(
+            prompt,
+            prompt_2,
+            control_image,
+            mask_image,
+            strength,
+            num_inference_steps,
+            callback_steps,
+            output_type,
+            negative_prompt,
+            negative_prompt_2,
+            prompt_embeds,
+            negative_prompt_embeds,
+            ip_adapter_image,
+            ip_adapter_image_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+            controlnet_conditioning_scale,
+            control_guidance_start,
+            control_guidance_end,
+            control_mode,
+            callback_on_step_end_tensor_inputs,
+            padding_mask_crop,
+        )
+
+        if isinstance(controlnet, ControlNetUnionModel):
+            control_type = torch.zeros(controlnet.config.num_control_type).scatter_(0, torch.tensor(control_mode), 1)
+        elif isinstance(controlnet, MultiControlNetUnionModel):
+            control_type = [
+                torch.zeros(controlnet_.config.num_control_type).scatter_(0, torch.tensor(control_mode_), 1)
+                for control_mode_, controlnet_ in zip(control_mode, self.controlnet.nets)
+            ]

         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
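The scatter-based construction replaces the manual loop: each active control-mode index is marked in a multi-hot vector of length `num_control_type`. A hedged numeric example:

```py
import torch

num_control_type = 6   # illustrative; comes from the ControlNetUnion config
control_mode = [0, 3]  # two active condition types
control_type = torch.zeros(num_control_type).scatter_(0, torch.tensor(control_mode), 1)
print(control_type)    # tensor([1., 0., 0., 1., 0., 0.])
```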
@@ -1483,21 +1554,55 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
             init_image = init_image.to(dtype=torch.float32)

         # 5.2 Prepare control images
-        for idx, _ in enumerate(control_image):
-            control_image[idx] = self.prepare_control_image(
-                image=control_image[idx],
-                width=width,
-                height=height,
-                batch_size=batch_size * num_images_per_prompt,
-                num_images_per_prompt=num_images_per_prompt,
-                device=device,
-                dtype=controlnet.dtype,
-                crops_coords=crops_coords,
-                resize_mode=resize_mode,
-                do_classifier_free_guidance=self.do_classifier_free_guidance,
-                guess_mode=guess_mode,
-            )
-            height, width = control_image[idx].shape[-2:]
+        if isinstance(controlnet, ControlNetUnionModel):
+            control_images = []
+
+            for image_ in control_image:
+                image_ = self.prepare_control_image(
+                    image=image_,
+                    width=width,
+                    height=height,
+                    batch_size=batch_size * num_images_per_prompt,
+                    num_images_per_prompt=num_images_per_prompt,
+                    device=device,
+                    dtype=controlnet.dtype,
+                    crops_coords=crops_coords,
+                    resize_mode=resize_mode,
+                    do_classifier_free_guidance=self.do_classifier_free_guidance,
+                    guess_mode=guess_mode,
+                )
+
+                control_images.append(image_)
+
+            control_image = control_images
+            height, width = control_image[0].shape[-2:]
+
+        elif isinstance(controlnet, MultiControlNetUnionModel):
+            control_images = []
+
+            for control_image_ in control_image:
+                images = []
+
+                for image_ in control_image_:
+                    image_ = self.prepare_control_image(
+                        image=image_,
+                        width=width,
+                        height=height,
+                        batch_size=batch_size * num_images_per_prompt,
+                        num_images_per_prompt=num_images_per_prompt,
+                        device=device,
+                        dtype=controlnet.dtype,
+                        crops_coords=crops_coords,
+                        resize_mode=resize_mode,
+                        do_classifier_free_guidance=self.do_classifier_free_guidance,
+                        guess_mode=guess_mode,
+                    )
+
+                    images.append(image_)
+                control_images.append(images)
+
+            control_image = control_images
+            height, width = control_image[0][0].shape[-2:]

         # 5.3 Prepare mask
         mask = self.mask_processor.preprocess(
@@ -1559,10 +1664,11 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
         # 8.2 Create tensor stating which controlnets to keep
         controlnet_keep = []
         for i in range(len(timesteps)):
-            controlnet_keep.append(
-                1.0
-                - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end)
-            )
+            keeps = [
+                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+                for s, e in zip(control_guidance_start, control_guidance_end)
+            ]
+            controlnet_keep.append(keeps)

         # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         height, width = latents.shape[-2:]
@@ -1627,11 +1733,24 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
            timesteps = timesteps[:num_inference_steps]

-        control_type = (
-            control_type.reshape(1, -1)
-            .to(device, dtype=prompt_embeds.dtype)
-            .repeat(batch_size * num_images_per_prompt * 2, 1)
-        )
+        control_type_repeat_factor = (
+            batch_size * num_images_per_prompt * (2 if self.do_classifier_free_guidance else 1)
+        )
+
+        if isinstance(controlnet, ControlNetUnionModel):
+            control_type = (
+                control_type.reshape(1, -1)
+                .to(self._execution_device, dtype=prompt_embeds.dtype)
+                .repeat(control_type_repeat_factor, 1)
+            )
+        elif isinstance(controlnet, MultiControlNetUnionModel):
+            control_type = [
+                _control_type.reshape(1, -1)
+                .to(self._execution_device, dtype=prompt_embeds.dtype)
+                .repeat(control_type_repeat_factor, 1)
+                for _control_type in control_type
+            ]

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
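The `control_type_repeat_factor` introduced here replaces a hard-coded `* 2` that assumed classifier-free guidance was always enabled. A small sketch of the effect (the one-hot vector is illustrative):

import torch

batch_size, num_images_per_prompt = 2, 1
control_type = torch.tensor([0.0, 1.0, 0.0, 1.0, 0.0, 0.0])  # example one-hot vector

for do_classifier_free_guidance in (False, True):
    # Rows double only under CFG, matching the doubled prompt-embedding batch.
    repeat_factor = batch_size * num_images_per_prompt * (2 if do_classifier_free_guidance else 1)
    repeated = control_type.reshape(1, -1).repeat(repeat_factor, 1)
    print(do_classifier_free_guidance, repeated.shape)  # (2, 6) without CFG, (4, 6) with CFG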
@@ -1452,17 +1452,21 @@ class StableDiffusionXLControlNetUnionPipeline(
        is_controlnet_compiled = is_compiled_module(self.controlnet)
        is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")

+        control_type_repeat_factor = (
+            batch_size * num_images_per_prompt * (2 if self.do_classifier_free_guidance else 1)
+        )
+
        if isinstance(controlnet, ControlNetUnionModel):
            control_type = (
                control_type.reshape(1, -1)
                .to(self._execution_device, dtype=prompt_embeds.dtype)
-                .repeat(batch_size * num_images_per_prompt * 2, 1)
+                .repeat(control_type_repeat_factor, 1)
            )
-        if isinstance(controlnet, MultiControlNetUnionModel):
+        elif isinstance(controlnet, MultiControlNetUnionModel):
            control_type = [
                _control_type.reshape(1, -1)
                .to(self._execution_device, dtype=prompt_embeds.dtype)
-                .repeat(batch_size * num_images_per_prompt * 2, 1)
+                .repeat(control_type_repeat_factor, 1)
                for _control_type in control_type
            ]
@@ -19,7 +19,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL.Image
import torch
-import torch.nn.functional as F
from transformers import (
    CLIPImageProcessor,
    CLIPTextModel,
@@ -38,7 +37,13 @@ from ...loaders import (
    StableDiffusionXLLoraLoaderMixin,
    TextualInversionLoaderMixin,
)
-from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, ImageProjection, UNet2DConditionModel
+from ...models import (
+    AutoencoderKL,
+    ControlNetUnionModel,
+    ImageProjection,
+    MultiControlNetUnionModel,
+    UNet2DConditionModel,
+)
from ...models.attention_processor import (
    AttnProcessor2_0,
    XFormersAttnProcessor,
@@ -262,7 +267,9 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
        tokenizer: CLIPTokenizer,
        tokenizer_2: CLIPTokenizer,
        unet: UNet2DConditionModel,
-        controlnet: ControlNetUnionModel,
+        controlnet: Union[
+            ControlNetUnionModel, List[ControlNetUnionModel], Tuple[ControlNetUnionModel], MultiControlNetUnionModel
+        ],
        scheduler: KarrasDiffusionSchedulers,
        requires_aesthetics_score: bool = False,
        force_zeros_for_empty_prompt: bool = True,
@@ -272,8 +279,8 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
    ):
        super().__init__()

-        if not isinstance(controlnet, ControlNetUnionModel):
-            raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.")
+        if isinstance(controlnet, (list, tuple)):
+            controlnet = MultiControlNetUnionModel(controlnet)

        self.register_modules(
            vae=vae,
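With the check above, callers can pass a plain list or tuple of union controlnets and have it normalized into a MultiControlNetUnionModel. A rough sketch using stand-in modules, assuming the wrapper simply stores its sub-networks on `.nets` as the rest of this diff suggests:

import torch
from diffusers.models import MultiControlNetUnionModel

# Stand-ins for real ControlNetUnionModel instances, just to show the wrapping.
controlnet = [torch.nn.Identity(), torch.nn.Identity()]
if isinstance(controlnet, (list, tuple)):
    controlnet = MultiControlNetUnionModel(controlnet)

print(type(controlnet).__name__, len(controlnet.nets))  # MultiControlNetUnionModel 2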
@@ -649,6 +656,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
        controlnet_conditioning_scale=1.0,
        control_guidance_start=0.0,
        control_guidance_end=1.0,
+        control_mode=None,
        callback_on_step_end_tensor_inputs=None,
    ):
        if strength < 0 or strength > 1:
@@ -722,28 +730,44 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
            )

+        # `prompt` needs more sophisticated handling when there are multiple
+        # conditionings.
+        if isinstance(self.controlnet, MultiControlNetUnionModel):
+            if isinstance(prompt, list):
+                logger.warning(
+                    f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
+                    " prompts. The conditionings will be fixed across the prompts."
+                )
+
        # Check `image`
-        is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
-            self.controlnet, torch._dynamo.eval_frame.OptimizedModule
-        )
-        if (
-            isinstance(self.controlnet, ControlNetModel)
-            or is_compiled
-            and isinstance(self.controlnet._orig_mod, ControlNetModel)
-        ):
-            self.check_image(image, prompt, prompt_embeds)
-        elif (
-            isinstance(self.controlnet, ControlNetUnionModel)
-            or is_compiled
-            and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
-        ):
-            self.check_image(image, prompt, prompt_embeds)
-        else:
-            assert False
+        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+        if isinstance(controlnet, ControlNetUnionModel):
+            for image_ in image:
+                self.check_image(image_, prompt, prompt_embeds)
+        elif isinstance(controlnet, MultiControlNetUnionModel):
+            if not isinstance(image, list):
+                raise TypeError("For multiple controlnets: `image` must be type `list`")
+            elif not all(isinstance(i, list) for i in image):
+                raise ValueError("For multiple controlnets: elements of `image` must be list of conditionings.")
+            elif len(image) != len(self.controlnet.nets):
+                raise ValueError(
+                    f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
+                )
+
+            for images_ in image:
+                for image_ in images_:
+                    self.check_image(image_, prompt, prompt_embeds)

        if not isinstance(control_guidance_start, (tuple, list)):
            control_guidance_start = [control_guidance_start]

+        if isinstance(controlnet, MultiControlNetUnionModel):
+            if len(control_guidance_start) != len(self.controlnet.nets):
+                raise ValueError(
+                    f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+                )
+
        if not isinstance(control_guidance_end, (tuple, list)):
            control_guidance_end = [control_guidance_end]

@@ -762,6 +786,15 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
            if end > 1.0:
                raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")

+        # Check `control_mode`
+        if isinstance(controlnet, ControlNetUnionModel):
+            if max(control_mode) >= controlnet.config.num_control_type:
+                raise ValueError(f"control_mode: must be lower than {controlnet.config.num_control_type}.")
+        elif isinstance(controlnet, MultiControlNetUnionModel):
+            for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets):
+                if max(_control_mode) >= _controlnet.config.num_control_type:
+                    raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.")
+
        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
            raise ValueError(
                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
            )
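The new `control_mode` check rejects any mode index that falls outside the union checkpoint's control-type table. A standalone sketch of that rule with made-up numbers:

# Sketch only: a union checkpoint trained on 6 condition types accepts indices 0-5.
num_control_type = 6

for control_mode in ([1, 3], [1, 7]):
    try:
        if max(control_mode) >= num_control_type:
            raise ValueError(f"control_mode: must be lower than {num_control_type}.")
        print(control_mode, "accepted")
    except ValueError as exc:
        print(control_mode, "rejected:", exc)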
@@ -1049,7 +1082,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        image: PipelineImageInput = None,
-        control_image: PipelineImageInput = None,
+        control_image: Union[PipelineImageInput, List[PipelineImageInput]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        strength: float = 0.8,
@@ -1074,7 +1107,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
        guess_mode: bool = False,
        control_guidance_start: Union[float, List[float]] = 0.0,
        control_guidance_end: Union[float, List[float]] = 1.0,
-        control_mode: Optional[Union[int, List[int]]] = None,
+        control_mode: Optional[Union[int, List[int], List[List[int]]]] = None,
        original_size: Tuple[int, int] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Tuple[int, int] = None,
@@ -1104,13 +1137,13 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
                `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
                The initial image will be used as the starting point for the image generation process. Can also accept
                image latents as `image`; if latents are passed directly, they will not be encoded again.
-            control_image (`PipelineImageInput`):
-                The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
-                the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also
-                be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
-                and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in
-                init, images must be passed as a list such that each element of the list can be correctly batched for
-                input to a single controlnet.
+            control_image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*):
+                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
+                specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
+                as an image. The dimensions of the output image default to `image`'s dimensions. If height and/or
+                width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
+                images must be passed as a list such that each element of the list can be correctly batched for input
+                to a single ControlNet.
            height (`int`, *optional*, defaults to the size of control_image):
                The height in pixels of the generated image. Anything below 512 pixels won't work well for
                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
@@ -1184,16 +1217,21 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
-                The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
-                to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
-                corresponding scale as a list.
+                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+                to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
+                the corresponding scale as a list.
            guess_mode (`bool`, *optional*, defaults to `False`):
                In this mode, the ControlNet encoder will try its best to recognize the content of the input image
                even if you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
-                The percentage of total steps at which the controlnet starts applying.
+                The percentage of total steps at which the ControlNet starts applying.
            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
-                The percentage of total steps at which the controlnet stops applying.
+                The percentage of total steps at which the ControlNet stops applying.
+            control_mode (`int` or `List[int]` or `List[List[int]]`, *optional*):
+                The control condition types for the ControlNet. See the ControlNet's model card for information on
+                the available control modes. If multiple ControlNets are specified in `init`, `control_mode` should
+                be a list where each ControlNet has its corresponding control mode list, reflecting the order of the
+                conditions in `control_image`.
            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
@@ -1273,12 +1311,6 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(

        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

-        # align format for control guidance
-        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
-            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
-        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
-            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
-
        if not isinstance(control_image, list):
            control_image = [control_image]
        else:
@@ -1287,37 +1319,56 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
        if not isinstance(control_mode, list):
            control_mode = [control_mode]

-        if len(control_image) != len(control_mode):
-            raise ValueError("Expected len(control_image) == len(control_type)")
+        if isinstance(controlnet, MultiControlNetUnionModel):
+            control_image = [[item] for item in control_image]
+            control_mode = [[item] for item in control_mode]

-        num_control_type = controlnet.config.num_control_type
-
-        # 1. Check inputs
-        control_type = [0 for _ in range(num_control_type)]
-        for _image, control_idx in zip(control_image, control_mode):
-            control_type[control_idx] = 1
-        self.check_inputs(
-            prompt,
-            prompt_2,
-            _image,
-            strength,
-            num_inference_steps,
-            callback_steps,
-            negative_prompt,
-            negative_prompt_2,
-            prompt_embeds,
-            negative_prompt_embeds,
-            pooled_prompt_embeds,
-            negative_pooled_prompt_embeds,
-            ip_adapter_image,
-            ip_adapter_image_embeds,
-            controlnet_conditioning_scale,
-            control_guidance_start,
-            control_guidance_end,
-            callback_on_step_end_tensor_inputs,
-        )
-
-        control_type = torch.Tensor(control_type)
+        # align format for control guidance
+        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
+            control_guidance_start, control_guidance_end = (
+                mult * [control_guidance_start],
+                mult * [control_guidance_end],
+            )
+
+        if isinstance(controlnet_conditioning_scale, float):
+            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
+            controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult
+
+        # 1. Check inputs
+        self.check_inputs(
+            prompt,
+            prompt_2,
+            control_image,
+            strength,
+            num_inference_steps,
+            callback_steps,
+            negative_prompt,
+            negative_prompt_2,
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+            ip_adapter_image,
+            ip_adapter_image_embeds,
+            controlnet_conditioning_scale,
+            control_guidance_start,
+            control_guidance_end,
+            control_mode,
+            callback_on_step_end_tensor_inputs,
+        )
+
+        if isinstance(controlnet, ControlNetUnionModel):
+            control_type = torch.zeros(controlnet.config.num_control_type).scatter_(0, torch.tensor(control_mode), 1)
+        elif isinstance(controlnet, MultiControlNetUnionModel):
+            control_type = [
+                torch.zeros(controlnet_.config.num_control_type).scatter_(0, torch.tensor(control_mode_), 1)
+                for control_mode_, controlnet_ in zip(control_mode, self.controlnet.nets)
+            ]

        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
global_pool_conditions = controlnet.config.global_pool_conditions
|
||||
global_pool_conditions = (
|
||||
controlnet.config.global_pool_conditions
|
||||
if isinstance(controlnet, ControlNetUnionModel)
|
||||
else controlnet.nets[0].config.global_pool_conditions
|
||||
)
|
||||
guess_mode = guess_mode or global_pool_conditions
|
||||
|
||||
# 3.1. Encode input prompt
|
||||
@@ -1372,22 +1427,55 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
            self.do_classifier_free_guidance,
        )

-        # 4. Prepare image and controlnet_conditioning_image
+        # 4.1 Prepare image
        image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)

-        for idx, _ in enumerate(control_image):
-            control_image[idx] = self.prepare_control_image(
-                image=control_image[idx],
-                width=width,
-                height=height,
-                batch_size=batch_size * num_images_per_prompt,
-                num_images_per_prompt=num_images_per_prompt,
-                device=device,
-                dtype=controlnet.dtype,
-                do_classifier_free_guidance=self.do_classifier_free_guidance,
-                guess_mode=guess_mode,
-            )
-            height, width = control_image[idx].shape[-2:]
+        # 4.2 Prepare control images
+        if isinstance(controlnet, ControlNetUnionModel):
+            control_images = []
+
+            for image_ in control_image:
+                image_ = self.prepare_control_image(
+                    image=image_,
+                    width=width,
+                    height=height,
+                    batch_size=batch_size * num_images_per_prompt,
+                    num_images_per_prompt=num_images_per_prompt,
+                    device=device,
+                    dtype=controlnet.dtype,
+                    do_classifier_free_guidance=self.do_classifier_free_guidance,
+                    guess_mode=guess_mode,
+                )
+
+                control_images.append(image_)
+
+            control_image = control_images
+            height, width = control_image[0].shape[-2:]
+
+        elif isinstance(controlnet, MultiControlNetUnionModel):
+            control_images = []
+
+            for control_image_ in control_image:
+                images = []
+
+                for image_ in control_image_:
+                    image_ = self.prepare_control_image(
+                        image=image_,
+                        width=width,
+                        height=height,
+                        batch_size=batch_size * num_images_per_prompt,
+                        num_images_per_prompt=num_images_per_prompt,
+                        device=device,
+                        dtype=controlnet.dtype,
+                        do_classifier_free_guidance=self.do_classifier_free_guidance,
+                        guess_mode=guess_mode,
+                    )
+
+                    images.append(image_)
+                control_images.append(images)
+
+            control_image = control_images
+            height, width = control_image[0][0].shape[-2:]

        # 5. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -1414,10 +1502,11 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
        # 7.1 Create tensor stating which controlnets to keep
        controlnet_keep = []
        for i in range(len(timesteps)):
-            controlnet_keep.append(
-                1.0
-                - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end)
-            )
+            keeps = [
+                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+                for s, e in zip(control_guidance_start, control_guidance_end)
+            ]
+            controlnet_keep.append(keeps)

        # 7.2 Prepare added time ids & embeddings
        original_size = original_size or (height, width)
@@ -1460,12 +1549,25 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device)
-        control_type = (
-            control_type.reshape(1, -1)
-            .to(device, dtype=prompt_embeds.dtype)
-            .repeat(batch_size * num_images_per_prompt * 2, 1)
-        )
+
+        control_type_repeat_factor = (
+            batch_size * num_images_per_prompt * (2 if self.do_classifier_free_guidance else 1)
+        )
+
+        if isinstance(controlnet, ControlNetUnionModel):
+            control_type = (
+                control_type.reshape(1, -1)
+                .to(self._execution_device, dtype=prompt_embeds.dtype)
+                .repeat(control_type_repeat_factor, 1)
+            )
+        elif isinstance(controlnet, MultiControlNetUnionModel):
+            control_type = [
+                _control_type.reshape(1, -1)
+                .to(self._execution_device, dtype=prompt_embeds.dtype)
+                .repeat(control_type_repeat_factor, 1)
+                for _control_type in control_type
+            ]

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -383,7 +383,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

-        latents = latents * np.float64(self.scheduler.init_noise_sigma)
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.

@@ -483,7 +483,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
        self.scheduler.set_timesteps(num_inference_steps)

        # scale the initial noise by the standard deviation required by the scheduler
-        latents = latents * np.float64(self.scheduler.init_noise_sigma)
+        latents = latents * self.scheduler.init_noise_sigma

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.

@@ -481,7 +481,7 @@ class OnnxStableDiffusionUpscalePipeline(DiffusionPipeline):
        timesteps = self.scheduler.timesteps

        # Scale the initial noise by the standard deviation required by the scheduler
-        latents = latents * np.float64(self.scheduler.init_noise_sigma)
+        latents = latents * self.scheduler.init_noise_sigma

        # 5. Add noise to image
        noise_level = np.array([noise_level]).astype(np.int64)
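One plausible reading of these three one-line changes: multiplying a float32 NumPy array by an `np.float64` scalar can silently promote the latents to float64 (under NumPy 2's promotion rules), which float32-typed ONNX inputs then reject, while a plain Python float leaves the dtype alone. A small sketch with a stand-in sigma value:

import numpy as np

latents = np.zeros((1, 4, 8, 8), dtype=np.float32)
sigma = 14.6146  # stand-in for scheduler.init_noise_sigma

print((latents * np.float64(sigma)).dtype)  # float64 under NumPy 2's promotion rules
print((latents * sigma).dtype)              # float32: a Python float never upcasts the array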
@@ -184,5 +184,14 @@ def get_device():
def empty_device_cache(device_type: Optional[str] = None):
    if device_type is None:
        device_type = get_device()
    if device_type in ["cpu"]:
        return
    device_mod = getattr(torch, device_type, torch.cuda)
    device_mod.empty_cache()


def device_synchronize(device_type: Optional[str] = None):
    if device_type is None:
        device_type = get_device()
    device_mod = getattr(torch, device_type, torch.cuda)
    device_mod.synchronize()
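A usage sketch for the two helpers above; the import path is an assumption based on these helpers sitting next to `get_device`:

from diffusers.utils.torch_utils import device_synchronize, empty_device_cache  # assumed import path

device_synchronize()        # wait for pending kernels on the active accelerator
empty_device_cache()        # no-op on "cpu", otherwise frees cached device memory
empty_device_cache("cuda")  # or target an explicit device type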
0  tests/modular_pipelines/__init__.py  Normal file
@@ -0,0 +1,511 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import tempfile
import unittest
from typing import Any, Dict

import numpy as np
import torch
from PIL import Image

from diffusers import (
    ClassifierFreeGuidance,
    ComponentsManager,
    ModularPipeline,
    StableDiffusionXLAutoBlocks,
    StableDiffusionXLModularPipeline,
)
from diffusers.loaders import ModularIPAdapterMixin
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    floats_tensor,
    require_torch_accelerator,
    torch_device,
)

from ...models.unets.test_models_unet_2d_condition import (
    create_ip_adapter_state_dict,
)
from ..test_modular_pipelines_common import (
    ModularPipelineTesterMixin,
)


enable_full_determinism()

class SDXLModularTests:
    """
    This mixin defines methods to create the pipeline, base inputs, and base tests shared across all SDXL modular
    tests.
    """

    pipeline_class = StableDiffusionXLModularPipeline
    pipeline_blocks_class = StableDiffusionXLAutoBlocks
    repo = "hf-internal-testing/tiny-sdxl-modular"
    params = frozenset(
        [
            "prompt",
            "height",
            "width",
            "negative_prompt",
            "cross_attention_kwargs",
            "image",
            "mask_image",
        ]
    )
    batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])

    def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
        pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
        pipeline.load_default_components(torch_dtype=torch_dtype)
        return pipeline

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "output_type": "np",
        }
        return inputs

    def _test_stable_diffusion_xl_euler(self, expected_image_shape, expected_slice, expected_max_diff=1e-2):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        sd_pipe = self.get_pipeline()
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = sd_pipe(**inputs, output="images")
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == expected_image_shape

        assert np.abs(image_slice.flatten() - expected_slice).max() < expected_max_diff, (
            "Image slice does not match expected slice"
        )


|
||||
"""
|
||||
This mixin is designed to test IP Adapter.
|
||||
"""
|
||||
|
||||
def test_pipeline_inputs_and_blocks(self):
|
||||
blocks = self.pipeline_blocks_class()
|
||||
parameters = blocks.input_names
|
||||
|
||||
assert issubclass(self.pipeline_class, ModularIPAdapterMixin)
|
||||
assert "ip_adapter_image" in parameters, (
|
||||
"`ip_adapter_image` argument must be supported by the `__call__` method"
|
||||
)
|
||||
assert "ip_adapter" in blocks.sub_blocks, "pipeline must contain an IPAdapter block"
|
||||
|
||||
_ = blocks.sub_blocks.pop("ip_adapter")
|
||||
parameters = blocks.input_names
|
||||
intermediate_parameters = blocks.intermediate_input_names
|
||||
assert "ip_adapter_image" not in parameters, (
|
||||
"`ip_adapter_image` argument must be removed from the `__call__` method"
|
||||
)
|
||||
assert "ip_adapter_image_embeds" not in intermediate_parameters, (
|
||||
"`ip_adapter_image_embeds` argument must be supported by the `__call__` method"
|
||||
)
|
||||
|
||||
def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
|
||||
return torch.randn((1, 1, cross_attention_dim), device=torch_device)
|
||||
|
||||
def _get_dummy_faceid_image_embeds(self, cross_attention_dim: int = 32):
|
||||
return torch.randn((1, 1, 1, cross_attention_dim), device=torch_device)
|
||||
|
||||
def _get_dummy_masks(self, input_size: int = 64):
|
||||
_masks = torch.zeros((1, 1, input_size, input_size), device=torch_device)
|
||||
_masks[0, :, :, : int(input_size / 2)] = 1
|
||||
return _masks
|
||||
|
||||
def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
|
||||
blocks = self.pipeline_blocks_class()
|
||||
_ = blocks.sub_blocks.pop("ip_adapter")
|
||||
parameters = blocks.input_names
|
||||
if "image" in parameters and "strength" in parameters:
|
||||
inputs["num_inference_steps"] = 4
|
||||
|
||||
inputs["output_type"] = "np"
|
||||
return inputs
|
||||
|
||||
def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
|
||||
r"""Tests for IP-Adapter.
|
||||
|
||||
The following scenarios are tested:
|
||||
- Single IP-Adapter with scale=0 should produce same output as no IP-Adapter.
|
||||
- Multi IP-Adapter with scale=0 should produce same output as no IP-Adapter.
|
||||
- Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
|
||||
- Multi IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
|
||||
"""
|
||||
# Raising the tolerance for this test when it's run on a CPU because we
|
||||
# compare against static slices and that can be shaky (with a VVVV low probability).
|
||||
expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
|
||||
|
||||
blocks = self.pipeline_blocks_class()
|
||||
_ = blocks.sub_blocks.pop("ip_adapter")
|
||||
pipe = blocks.init_pipeline(self.repo)
|
||||
pipe.load_default_components(torch_dtype=torch.float32)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
cross_attention_dim = pipe.unet.config.get("cross_attention_dim")
|
||||
|
||||
# forward pass without ip adapter
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
if expected_pipe_slice is None:
|
||||
output_without_adapter = pipe(**inputs, output="images")
|
||||
else:
|
||||
output_without_adapter = expected_pipe_slice
|
||||
|
||||
# 1. Single IP-Adapter test cases
|
||||
adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
|
||||
pipe.unet._load_ip_adapter_weights(adapter_state_dict)
|
||||
|
||||
# forward pass with single ip adapter, but scale=0 which should have no effect
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
|
||||
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
|
||||
pipe.set_ip_adapter_scale(0.0)
|
||||
output_without_adapter_scale = pipe(**inputs, output="images")
|
||||
if expected_pipe_slice is not None:
|
||||
output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
# forward pass with single ip adapter, but with scale of adapter weights
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
|
||||
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
|
||||
pipe.set_ip_adapter_scale(42.0)
|
||||
output_with_adapter_scale = pipe(**inputs, output="images")
|
||||
if expected_pipe_slice is not None:
|
||||
output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
|
||||
max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
|
||||
|
||||
assert max_diff_without_adapter_scale < expected_max_diff, (
|
||||
"Output without ip-adapter must be same as normal inference"
|
||||
)
|
||||
assert max_diff_with_adapter_scale > 1e-2, "Output with ip-adapter must be different from normal inference"
|
||||
|
||||
# 2. Multi IP-Adapter test cases
|
||||
adapter_state_dict_1 = create_ip_adapter_state_dict(pipe.unet)
|
||||
adapter_state_dict_2 = create_ip_adapter_state_dict(pipe.unet)
|
||||
pipe.unet._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2])
|
||||
|
||||
# forward pass with multi ip adapter, but scale=0 which should have no effect
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
|
||||
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
|
||||
pipe.set_ip_adapter_scale([0.0, 0.0])
|
||||
output_without_multi_adapter_scale = pipe(**inputs, output="images")
|
||||
if expected_pipe_slice is not None:
|
||||
output_without_multi_adapter_scale = output_without_multi_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
# forward pass with multi ip adapter, but with scale of adapter weights
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
|
||||
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
|
||||
pipe.set_ip_adapter_scale([42.0, 42.0])
|
||||
output_with_multi_adapter_scale = pipe(**inputs, output="images")
|
||||
if expected_pipe_slice is not None:
|
||||
output_with_multi_adapter_scale = output_with_multi_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
max_diff_without_multi_adapter_scale = np.abs(
|
||||
output_without_multi_adapter_scale - output_without_adapter
|
||||
).max()
|
||||
max_diff_with_multi_adapter_scale = np.abs(output_with_multi_adapter_scale - output_without_adapter).max()
|
||||
assert max_diff_without_multi_adapter_scale < expected_max_diff, (
|
||||
"Output without multi-ip-adapter must be same as normal inference"
|
||||
)
|
||||
assert max_diff_with_multi_adapter_scale > 1e-2, (
|
||||
"Output with multi-ip-adapter scale must be different from normal inference"
|
||||
)
|
||||
|
||||
|
||||
class SDXLModularControlNetTests:
    """
    This mixin is designed to test ControlNet.
    """

    def test_pipeline_inputs(self):
        blocks = self.pipeline_blocks_class()
        parameters = blocks.input_names

        assert "control_image" in parameters, "`control_image` argument must be supported by the `__call__` method"
        assert "controlnet_conditioning_scale" in parameters, (
            "`controlnet_conditioning_scale` argument must be supported by the `__call__` method"
        )

    def _modify_inputs_for_controlnet_test(self, inputs: Dict[str, Any]):
        controlnet_embedder_scale_factor = 2
        image = torch.randn(
            (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
            device=torch_device,
        )
        inputs["control_image"] = image
        return inputs

    def test_controlnet(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
        r"""Tests for ControlNet.

        The following scenarios are tested:
        - Single ControlNet with scale=0 should produce the same output as no ControlNet.
        - Single ControlNet with scale!=0 should produce different output compared to no ControlNet.
        """
        # Raise the tolerance when this test runs on a CPU because we compare
        # against static slices, and that can be flaky (with a very low probability).
        expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff

        pipe = self.get_pipeline()
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        # forward pass without controlnet
        inputs = self.get_dummy_inputs(torch_device)
        output_without_controlnet = pipe(**inputs, output="images")
        output_without_controlnet = output_without_controlnet[0, -3:, -3:, -1].flatten()

        # forward pass with single controlnet, but scale=0 which should have no effect
        inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
        inputs["controlnet_conditioning_scale"] = 0.0
        output_without_controlnet_scale = pipe(**inputs, output="images")
        output_without_controlnet_scale = output_without_controlnet_scale[0, -3:, -3:, -1].flatten()

        # forward pass with single controlnet, with a non-zero conditioning scale
        inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
        inputs["controlnet_conditioning_scale"] = 42.0
        output_with_controlnet_scale = pipe(**inputs, output="images")
        output_with_controlnet_scale = output_with_controlnet_scale[0, -3:, -3:, -1].flatten()

        max_diff_without_controlnet_scale = np.abs(output_without_controlnet_scale - output_without_controlnet).max()
        max_diff_with_controlnet_scale = np.abs(output_with_controlnet_scale - output_without_controlnet).max()

        assert max_diff_without_controlnet_scale < expected_max_diff, (
            "Output without controlnet must be same as normal inference"
        )
        assert max_diff_with_controlnet_scale > 1e-2, "Output with controlnet must be different from normal inference"

    def test_controlnet_cfg(self):
        pipe = self.get_pipeline()
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        # forward pass with CFG not applied
        guider = ClassifierFreeGuidance(guidance_scale=1.0)
        pipe.update_components(guider=guider)

        inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
        out_no_cfg = pipe(**inputs, output="images")

        # forward pass with CFG applied
        guider = ClassifierFreeGuidance(guidance_scale=7.5)
        pipe.update_components(guider=guider)
        inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
        out_cfg = pipe(**inputs, output="images")

        assert out_cfg.shape == out_no_cfg.shape
        max_diff = np.abs(out_cfg - out_no_cfg).max()
        assert max_diff > 1e-2, "Output with CFG must be different from normal inference"


class SDXLModularGuiderTests:
    def test_guider_cfg(self):
        pipe = self.get_pipeline()
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        # forward pass with CFG not applied
        guider = ClassifierFreeGuidance(guidance_scale=1.0)
        pipe.update_components(guider=guider)

        inputs = self.get_dummy_inputs(torch_device)
        out_no_cfg = pipe(**inputs, output="images")

        # forward pass with CFG applied
        guider = ClassifierFreeGuidance(guidance_scale=7.5)
        pipe.update_components(guider=guider)
        inputs = self.get_dummy_inputs(torch_device)
        out_cfg = pipe(**inputs, output="images")

        assert out_cfg.shape == out_no_cfg.shape
        max_diff = np.abs(out_cfg - out_no_cfg).max()
        assert max_diff > 1e-2, "Output with CFG must be different from normal inference"


class SDXLModularPipelineFastTests(
    SDXLModularTests,
    SDXLModularIPAdapterTests,
    SDXLModularControlNetTests,
    SDXLModularGuiderTests,
    ModularPipelineTesterMixin,
    unittest.TestCase,
):
    """Test cases for Stable Diffusion XL modular pipeline fast tests."""

    def test_stable_diffusion_xl_euler(self):
        self._test_stable_diffusion_xl_euler(
            expected_image_shape=(1, 64, 64, 3),
            expected_slice=[
                0.5966781,
                0.62939394,
                0.48465094,
                0.51573336,
                0.57593524,
                0.47035995,
                0.53410417,
                0.51436996,
                0.47313565,
            ],
            expected_max_diff=1e-2,
        )

    def test_inference_batch_single_identical(self):
        super().test_inference_batch_single_identical(expected_max_diff=3e-3)

    @require_torch_accelerator
    def test_stable_diffusion_xl_offloads(self):
        pipes = []
        sd_pipe = self.get_pipeline().to(torch_device)
        pipes.append(sd_pipe)

        cm = ComponentsManager()
        cm.enable_auto_cpu_offload(device=torch_device)
        sd_pipe = self.get_pipeline(components_manager=cm)
        pipes.append(sd_pipe)

        image_slices = []
        for pipe in pipes:
            inputs = self.get_dummy_inputs(torch_device)
            image = pipe(**inputs, output="images")

            image_slices.append(image[0, -3:, -3:, -1].flatten())

        assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3

    def test_stable_diffusion_xl_save_from_pretrained(self):
        pipes = []
        sd_pipe = self.get_pipeline().to(torch_device)
        pipes.append(sd_pipe)

        with tempfile.TemporaryDirectory() as tmpdirname:
            sd_pipe.save_pretrained(tmpdirname)
            sd_pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
            sd_pipe.load_default_components(torch_dtype=torch.float32)
            sd_pipe.to(torch_device)
            pipes.append(sd_pipe)

        image_slices = []
        for pipe in pipes:
            inputs = self.get_dummy_inputs(torch_device)
            image = pipe(**inputs, output="images")

            image_slices.append(image[0, -3:, -3:, -1].flatten())

        assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3


class SDXLImg2ImgModularPipelineFastTests(
    SDXLModularTests,
    SDXLModularIPAdapterTests,
    SDXLModularControlNetTests,
    SDXLModularGuiderTests,
    ModularPipelineTesterMixin,
    unittest.TestCase,
):
    """Test cases for Stable Diffusion XL image-to-image modular pipeline fast tests."""

    def get_dummy_inputs(self, device, seed=0):
        inputs = super().get_dummy_inputs(device, seed)
        image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
        image = image / 2 + 0.5
        inputs["image"] = image
        inputs["strength"] = 0.8

        return inputs

    def test_stable_diffusion_xl_euler(self):
        self._test_stable_diffusion_xl_euler(
            expected_image_shape=(1, 64, 64, 3),
            expected_slice=[
                0.56943184,
                0.4702148,
                0.48048905,
                0.6235963,
                0.551138,
                0.49629188,
                0.60031277,
                0.5688907,
                0.43996853,
            ],
            expected_max_diff=1e-2,
        )

    def test_inference_batch_single_identical(self):
        super().test_inference_batch_single_identical(expected_max_diff=3e-3)


class SDXLInpaintingModularPipelineFastTests(
    SDXLModularTests,
    SDXLModularIPAdapterTests,
    SDXLModularControlNetTests,
    SDXLModularGuiderTests,
    ModularPipelineTesterMixin,
    unittest.TestCase,
):
    """Test cases for Stable Diffusion XL inpainting modular pipeline fast tests."""

    def get_dummy_inputs(self, device, seed=0):
        inputs = super().get_dummy_inputs(device, seed)
        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
        image = image.cpu().permute(0, 2, 3, 1)[0]
        init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
        # create mask
        image[8:, 8:, :] = 255
        mask_image = Image.fromarray(np.uint8(image)).convert("L").resize((64, 64))

        inputs["image"] = init_image
        inputs["mask_image"] = mask_image
        inputs["strength"] = 1.0

        return inputs

    def test_stable_diffusion_xl_euler(self):
        self._test_stable_diffusion_xl_euler(
            expected_image_shape=(1, 64, 64, 3),
            expected_slice=[
                0.40872607,
                0.38842705,
                0.34893104,
                0.47837183,
                0.43792963,
                0.5332134,
                0.3716843,
                0.47274873,
                0.45000193,
            ],
            expected_max_diff=1e-2,
        )

    def test_inference_batch_single_identical(self):
        super().test_inference_batch_single_identical(expected_max_diff=3e-3)
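The fast-test classes above all follow the same recipe: compose the SDXL mixins with ModularPipelineTesterMixin and unittest.TestCase, then override `get_dummy_inputs` for the task at hand. A minimal sketch of a new suite built the same way (the class and its tweak are hypothetical):

class MyModularFastTests(
    SDXLModularTests,
    SDXLModularGuiderTests,
    ModularPipelineTesterMixin,
    unittest.TestCase,
):
    def get_dummy_inputs(self, device, seed=0):
        inputs = super().get_dummy_inputs(device, seed)
        inputs["num_inference_steps"] = 1  # keep the suite fast
        return inputs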
330  tests/modular_pipelines/test_modular_pipelines_common.py  Normal file
@@ -0,0 +1,330 @@
import gc
import unittest
from typing import Callable, Union

import numpy as np
import torch

import diffusers
from diffusers import ComponentsManager  # needed by test_components_auto_cpu_offload below
from diffusers.utils import logging
from diffusers.utils.dummy_pt_objects import ModularPipeline, ModularPipelineBlocks
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    numpy_cosine_similarity_distance,
    require_accelerator,
    require_torch,
    torch_device,
)


def to_np(tensor):
    if isinstance(tensor, torch.Tensor):
        tensor = tensor.detach().cpu().numpy()

    return tensor


@require_torch
class ModularPipelineTesterMixin:
    """
    This mixin is designed to be used with unittest.TestCase classes.
    It provides a set of common tests for each modular pipeline,
    including:
    - test_pipeline_call_signature: check that the pipeline's `__call__` method exposes all required parameters
    - test_inference_batch_consistent: check that the pipeline's `__call__` method can handle batched inputs
    - test_inference_batch_single_identical: check that batched and single inputs produce identical results
    - test_float16_inference: check that the pipeline's `__call__` method can handle float16 inputs
    - test_to_device: check that the pipeline's `__call__` method can handle different devices
    """

    # Canonical parameters that are passed to `__call__` regardless
    # of the type of pipeline. They are always optional and have common
    # sense default values.
    optional_params = frozenset(
        [
            "num_inference_steps",
            "num_images_per_prompt",
            "latents",
            "output_type",
        ]
    )
    # this is modular specific: generator needs to be an intermediate input because it's mutable
    intermediate_params = frozenset(
        [
            "generator",
        ]
    )

    def get_generator(self, seed):
        device = torch_device if torch_device != "mps" else "cpu"
        generator = torch.Generator(device).manual_seed(seed)
        return generator

    @property
    def pipeline_class(self) -> Union[Callable, ModularPipeline]:
        raise NotImplementedError(
            "You need to set the attribute `pipeline_class = ClassNameOfPipeline` in the child test class. "
            "See existing pipeline tests for reference."
        )

    @property
    def repo(self) -> str:
        raise NotImplementedError(
            "You need to set the attribute `repo` in the child test class. See existing pipeline tests for reference."
        )

    @property
    def pipeline_blocks_class(self) -> Union[Callable, ModularPipelineBlocks]:
        raise NotImplementedError(
            "You need to set the attribute `pipeline_blocks_class = ClassNameOfPipelineBlocks` in the child test class. "
            "See existing pipeline tests for reference."
        )

    def get_pipeline(self):
        raise NotImplementedError(
            "You need to implement `get_pipeline(self)` in the child test class. "
            "See existing pipeline tests for reference."
        )

    def get_dummy_inputs(self, device, seed=0):
        raise NotImplementedError(
            "You need to implement `get_dummy_inputs(self, device, seed)` in the child test class. "
            "See existing pipeline tests for reference."
        )

    @property
    def params(self) -> frozenset:
        raise NotImplementedError(
            "You need to set the attribute `params` in the child test class. "
            "`params` are checked for if all values are present in `__call__`'s signature."
            " You can set `params` using one of the common set of parameters defined in `pipeline_params.py`"
            " e.g., `TEXT_TO_IMAGE_PARAMS` defines the common parameters used in text to "
            "image pipelines, including prompts and prompt embedding overrides."
            "If your pipeline's set of arguments has minor changes from one of the common sets of arguments, "
            "do not make modifications to the existing common sets of arguments. I.e. a text to image pipeline "
            "with non-configurable height and width arguments should set the attribute as "
            "`params = TEXT_TO_IMAGE_PARAMS - {'height', 'width'}`. "
            "See existing pipeline tests for reference."
        )

    @property
    def batch_params(self) -> frozenset:
        raise NotImplementedError(
            "You need to set the attribute `batch_params` in the child test class. "
            "`batch_params` are the parameters required to be batched when passed to the pipeline's "
            "`__call__` method. `pipeline_params.py` provides some common sets of parameters such as "
            "`TEXT_TO_IMAGE_BATCH_PARAMS`, `IMAGE_VARIATION_BATCH_PARAMS`, etc... If your pipeline's "
            "set of batch arguments has minor changes from one of the common sets of batch arguments, "
            "do not make modifications to the existing common sets of batch arguments. I.e. a text to "
            "image pipeline where `negative_prompt` is not batched should set the attribute as "
            "`batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {'negative_prompt'}`. "
            "See existing pipeline tests for reference."
        )

    def setUp(self):
        # clean up the VRAM before each test
        super().setUp()
        torch.compiler.reset()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        # clean up the VRAM after each test in case of CUDA runtime errors
        super().tearDown()
        torch.compiler.reset()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_pipeline_call_signature(self):
        pipe = self.get_pipeline()
        input_parameters = pipe.blocks.input_names
        intermediate_parameters = pipe.blocks.intermediate_input_names
        optional_parameters = pipe.default_call_parameters

        def _check_for_parameters(parameters, expected_parameters, param_type):
            remaining_parameters = {param for param in parameters if param not in expected_parameters}
            assert (
                len(remaining_parameters) == 0
            ), f"Required {param_type} parameters not present: {remaining_parameters}"

        _check_for_parameters(self.params, input_parameters, "input")
        _check_for_parameters(self.intermediate_params, intermediate_parameters, "intermediate")
        _check_for_parameters(self.optional_params, optional_parameters, "optional")

    def test_inference_batch_consistent(self, batch_sizes=[2], batch_generator=True):
        pipe = self.get_pipeline()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        inputs["generator"] = self.get_generator(0)

        logger = logging.get_logger(pipe.__module__)
        logger.setLevel(level=diffusers.logging.FATAL)

        # prepare batched inputs
        batched_inputs = []
        for batch_size in batch_sizes:
            batched_input = {}
            batched_input.update(inputs)

            for name in self.batch_params:
                if name not in inputs:
                    continue

                value = inputs[name]
                batched_input[name] = batch_size * [value]

            if batch_generator and "generator" in inputs:
                batched_input["generator"] = [self.get_generator(i) for i in range(batch_size)]

            if "batch_size" in inputs:
                batched_input["batch_size"] = batch_size

            batched_inputs.append(batched_input)

        logger.setLevel(level=diffusers.logging.WARNING)
        for batch_size, batched_input in zip(batch_sizes, batched_inputs):
            output = pipe(**batched_input, output="images")
            assert len(output) == batch_size, "Output is different from expected batch size"

def test_inference_batch_single_identical(
|
||||
self,
|
||||
batch_size=2,
|
||||
expected_max_diff=1e-4,
|
||||
):
|
||||
pipe = self.get_pipeline()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
# Reset generator in case it is has been used in self.get_dummy_inputs
|
||||
inputs["generator"] = self.get_generator(0)
|
||||
|
||||
logger = logging.get_logger(pipe.__module__)
|
||||
logger.setLevel(level=diffusers.logging.FATAL)
|
||||
|
||||
# batchify inputs
|
||||
batched_inputs = {}
|
||||
batched_inputs.update(inputs)
|
||||
|
||||
for name in self.batch_params:
|
||||
if name not in inputs:
|
||||
continue
|
||||
|
||||
value = inputs[name]
|
||||
batched_inputs[name] = batch_size * [value]
|
||||
|
||||
if "generator" in inputs:
|
||||
batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
|
||||
|
||||
if "batch_size" in inputs:
|
||||
batched_inputs["batch_size"] = batch_size
|
||||
|
||||
output = pipe(**inputs, output="images")
|
||||
output_batch = pipe(**batched_inputs, output="images")
|
||||
|
||||
assert output_batch.shape[0] == batch_size
|
||||
|
||||
max_diff = np.abs(to_np(output_batch[0]) - to_np(output[0])).max()
|
||||
assert max_diff < expected_max_diff, "Batch inference results different from single inference results"
|
||||
|
||||
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
|
||||
@require_accelerator
|
||||
def test_float16_inference(self, expected_max_diff=5e-2):
|
||||
pipe = self.get_pipeline()
|
||||
pipe.to(torch_device, torch.float32)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
pipe_fp16 = self.get_pipeline()
|
||||
pipe_fp16.to(torch_device, torch.float16)
|
||||
pipe_fp16.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
# Reset generator in case it is used inside dummy inputs
|
||||
if "generator" in inputs:
|
||||
inputs["generator"] = self.get_generator(0)
|
||||
output = pipe(**inputs, output="images")
|
||||
|
||||
fp16_inputs = self.get_dummy_inputs(torch_device)
|
||||
# Reset generator in case it is used inside dummy inputs
|
||||
if "generator" in fp16_inputs:
|
||||
fp16_inputs["generator"] = self.get_generator(0)
|
||||
output_fp16 = pipe_fp16(**fp16_inputs, output="images")
|
||||
|
||||
if isinstance(output, torch.Tensor):
|
||||
output = output.cpu()
|
||||
output_fp16 = output_fp16.cpu()
        max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
        assert max_diff < expected_max_diff, "FP16 inference is different from FP32 inference"

    @require_accelerator
    def test_to_device(self):
        pipe = self.get_pipeline()
        pipe.set_progress_bar_config(disable=None)

        pipe.to("cpu")
        model_devices = [
            component.device.type for component in pipe.components.values() if hasattr(component, "device")
        ]
        assert all(device == "cpu" for device in model_devices), "Not all pipeline components are on CPU"

        pipe.to(torch_device)
        model_devices = [
            component.device.type for component in pipe.components.values() if hasattr(component, "device")
        ]
        assert all(
            device == torch_device for device in model_devices
        ), "Not all pipeline components are on the accelerator device"

    def test_inference_is_not_nan_cpu(self):
        pipe = self.get_pipeline()
        pipe.set_progress_bar_config(disable=None)
        pipe.to("cpu")

        output = pipe(**self.get_dummy_inputs("cpu"), output="images")
        assert np.isnan(to_np(output)).sum() == 0, "CPU inference returns NaN"

    @require_accelerator
    def test_inference_is_not_nan(self):
        pipe = self.get_pipeline()
        pipe.set_progress_bar_config(disable=None)
        pipe.to(torch_device)

        output = pipe(**self.get_dummy_inputs(torch_device), output="images")
        assert np.isnan(to_np(output)).sum() == 0, "Accelerator inference returns NaN"

    def test_num_images_per_prompt(self):
        pipe = self.get_pipeline()
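
        # Not every modular pipeline exposes num_images_per_prompt; skip those that don't.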
        if "num_images_per_prompt" not in pipe.blocks.input_names:
            return

        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        batch_sizes = [1, 2]
        num_images_per_prompts = [1, 2]
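
        # Each (batch_size, num_images_per_prompt) combination should produce a
        # batch of batch_size * num_images_per_prompt images.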
        for batch_size in batch_sizes:
            for num_images_per_prompt in num_images_per_prompts:
                inputs = self.get_dummy_inputs(torch_device)

                for key in inputs.keys():
                    if key in self.batch_params:
                        inputs[key] = batch_size * [inputs[key]]

                images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output="images")

                assert images.shape[0] == batch_size * num_images_per_prompt

    @require_accelerator
    def test_components_auto_cpu_offload(self):
        base_pipe = self.get_pipeline().to(torch_device)
        for component in base_pipe.components.values():
            assert component.device == torch_device

        cm = ComponentsManager()
        cm.enable_auto_cpu_offload(device=torch_device)
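        # With auto CPU offload enabled, the manager is expected to keep components
        # on CPU and onload each one to the accelerator only while it is in use.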
        offload_pipe = self.get_pipeline(components_manager=cm)

@@ -155,7 +155,7 @@ class FluxPipelineFastTests(

         # Outputs should be different here
         # For some reason, they don't show large differences
-        assert max_diff > 1e-6
+        self.assertGreater(max_diff, 1e-6, "Outputs should be different for different prompts.")

     def test_fused_qkv_projections(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
@@ -187,14 +187,17 @@ class FluxPipelineFastTests(
         image = pipe(**inputs).images
         image_slice_disabled = image[0, -3:, -3:, -1]

-        assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
-            "Fusion of QKV projections shouldn't affect the outputs."
+        self.assertTrue(
+            np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3),
+            ("Fusion of QKV projections shouldn't affect the outputs."),
         )
-        assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
-            "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+        self.assertTrue(
+            np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3),
+            ("Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."),
         )
-        assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
-            "Original outputs should match when fused QKV projections are disabled."
+        self.assertTrue(
+            np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2),
+            ("Original outputs should match when fused QKV projections are disabled."),
         )

     def test_flux_image_output_shape(self):
@@ -209,7 +212,11 @@ class FluxPipelineFastTests(
         inputs.update({"height": height, "width": width})
         image = pipe(**inputs).images[0]
         output_height, output_width, _ = image.shape
-        assert (output_height, output_width) == (expected_height, expected_width)
+        self.assertEqual(
+            (output_height, output_width),
+            (expected_height, expected_width),
+            f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}",
+        )

     def test_flux_true_cfg(self):
         pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
@@ -220,7 +227,9 @@ class FluxPipelineFastTests(
         inputs["negative_prompt"] = "bad quality"
         inputs["true_cfg_scale"] = 2.0
         true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
-        assert not np.allclose(no_true_cfg_out, true_cfg_out)
+        self.assertFalse(
+            np.allclose(no_true_cfg_out, true_cfg_out), "Outputs should be different when true_cfg_scale is set."
+        )

@nightly
@@ -269,45 +278,17 @@ class FluxPipelineSlowTests(unittest.TestCase):

         image = pipe(**inputs).images[0]
         image_slice = image[0, :10, :10]
+        # fmt: off
         expected_slice = np.array(
-            [
-                0.3242,
-                0.3203,
-                0.3164,
-                0.3164,
-                0.3125,
-                0.3125,
-                0.3281,
-                0.3242,
-                0.3203,
-                0.3301,
-                0.3262,
-                0.3242,
-                0.3281,
-                0.3242,
-                0.3203,
-                0.3262,
-                0.3262,
-                0.3164,
-                0.3262,
-                0.3281,
-                0.3184,
-                0.3281,
-                0.3281,
-                0.3203,
-                0.3281,
-                0.3281,
-                0.3164,
-                0.3320,
-                0.3320,
-                0.3203,
-            ],
+            [0.3242, 0.3203, 0.3164, 0.3164, 0.3125, 0.3125, 0.3281, 0.3242, 0.3203, 0.3301, 0.3262, 0.3242, 0.3281, 0.3242, 0.3203, 0.3262, 0.3262, 0.3164, 0.3262, 0.3281, 0.3184, 0.3281, 0.3281, 0.3203, 0.3281, 0.3281, 0.3164, 0.3320, 0.3320, 0.3203],
             dtype=np.float32,
         )
+        # fmt: on

         max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())

-        assert max_diff < 1e-4
+        self.assertLess(
+            max_diff, 1e-4, f"Image slice is different from expected slice: {image_slice} != {expected_slice}"
+        )


@slow
@@ -377,42 +358,14 @@ class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
         image = pipe(**inputs).images[0]
         image_slice = image[0, :10, :10]

+        # fmt: off
         expected_slice = np.array(
-            [
-                0.1855,
-                0.1680,
-                0.1406,
-                0.1953,
-                0.1699,
-                0.1465,
-                0.2012,
-                0.1738,
-                0.1484,
-                0.2051,
-                0.1797,
-                0.1523,
-                0.2012,
-                0.1719,
-                0.1445,
-                0.2070,
-                0.1777,
-                0.1465,
-                0.2090,
-                0.1836,
-                0.1484,
-                0.2129,
-                0.1875,
-                0.1523,
-                0.2090,
-                0.1816,
-                0.1484,
-                0.2110,
-                0.1836,
-                0.1543,
-            ],
+            [0.1855, 0.1680, 0.1406, 0.1953, 0.1699, 0.1465, 0.2012, 0.1738, 0.1484, 0.2051, 0.1797, 0.1523, 0.2012, 0.1719, 0.1445, 0.2070, 0.1777, 0.1465, 0.2090, 0.1836, 0.1484, 0.2129, 0.1875, 0.1523, 0.2090, 0.1816, 0.1484, 0.2110, 0.1836, 0.1543],
             dtype=np.float32,
         )
+        # fmt: on

         max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())

-        assert max_diff < 1e-4, f"{image_slice} != {expected_slice}"
+        self.assertLess(
+            max_diff, 1e-4, f"Image slice is different from expected slice: {image_slice} != {expected_slice}"
+        )

@@ -20,12 +20,6 @@ TEXT_TO_IMAGE_PARAMS = frozenset(
     ]
 )

-TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
-
-TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
-
-IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
-
 IMAGE_VARIATION_PARAMS = frozenset(
     [
         "image",
@@ -35,8 +29,6 @@ IMAGE_VARIATION_PARAMS = frozenset(
     ]
 )

-IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])
-
 TEXT_GUIDED_IMAGE_VARIATION_PARAMS = frozenset(
     [
         "prompt",
@@ -50,8 +42,6 @@ TEXT_GUIDED_IMAGE_VARIATION_PARAMS = frozenset(
     ]
 )

-TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])
-
 TEXT_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
     [
         # Text guided image variation with an image mask
@@ -67,8 +57,6 @@ TEXT_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
     ]
 )

-TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])
-
 IMAGE_INPAINTING_PARAMS = frozenset(
     [
         # image variation with an image mask
@@ -80,8 +68,6 @@ IMAGE_INPAINTING_PARAMS = frozenset(
     ]
 )

-IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])
-
 IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
     [
         "example_image",
@@ -93,20 +79,12 @@ IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
     ]
 )

-IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])
-
-UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"])
-
 CLASS_CONDITIONED_IMAGE_GENERATION_PARAMS = frozenset(["class_labels"])

 CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS = frozenset(["class_labels"])

-UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])
-
-UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
-
-UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])
-
+UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"])
+
 TEXT_TO_AUDIO_PARAMS = frozenset(
     [
         "prompt",
@@ -119,11 +97,38 @@ TEXT_TO_AUDIO_PARAMS = frozenset(
     ]
 )

-TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
-
 TOKENS_TO_AUDIO_GENERATION_PARAMS = frozenset(["input_tokens"])

+UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
+
+# image params
+TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
+
+IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
+
+
+# batch params
+TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
+
+IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])
+
+TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])
+
+TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])
+
+IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])
+
+IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])
+
+UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])
+
+UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])
+
+TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
+
+TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS = frozenset(["input_tokens"])
+
-TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])
-
 VIDEO_TO_VIDEO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt", "video"])

+# callback params
+TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])

@@ -873,11 +873,11 @@ class ExtendedSerializationTest(BaseBnb4BitSerializationTests):

 @require_torch_version_greater("2.7.1")
 @require_bitsandbytes_version_greater("0.45.5")
-class Bnb4BitCompileTests(QuantCompileTests):
+class Bnb4BitCompileTests(QuantCompileTests, unittest.TestCase):
     @property
     def quantization_config(self):
         return PipelineQuantizationConfig(
-            quant_backend="bitsandbytes_8bit",
+            quant_backend="bitsandbytes_4bit",
             quant_kwargs={
                 "load_in_4bit": True,
                 "bnb_4bit_quant_type": "nf4",
@@ -888,12 +888,7 @@ class Bnb4BitCompileTests(QuantCompileTests):

     def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
-        super()._test_torch_compile(quantization_config=self.quantization_config)
-
-    def test_torch_compile_with_cpu_offload(self):
-        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
+        super().test_torch_compile()

     def test_torch_compile_with_group_offload_leaf(self):
-        super()._test_torch_compile_with_group_offload_leaf(
-            quantization_config=self.quantization_config, use_stream=True
-        )
+        super()._test_torch_compile_with_group_offload_leaf(use_stream=True)

@@ -838,7 +838,7 @@ class BaseBnb8bitSerializationTests(Base8bitTests):

 @require_torch_version_greater_equal("2.6.0")
 @require_bitsandbytes_version_greater("0.45.5")
-class Bnb8BitCompileTests(QuantCompileTests):
+class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
     @property
     def quantization_config(self):
         return PipelineQuantizationConfig(
@@ -849,15 +849,11 @@ class Bnb8BitCompileTests(QuantCompileTests):

     def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
-        super()._test_torch_compile(quantization_config=self.quantization_config, torch_dtype=torch.float16)
+        super()._test_torch_compile(torch_dtype=torch.float16)

     def test_torch_compile_with_cpu_offload(self):
-        super()._test_torch_compile_with_cpu_offload(
-            quantization_config=self.quantization_config, torch_dtype=torch.float16
-        )
+        super()._test_torch_compile_with_cpu_offload(torch_dtype=torch.float16)

     @pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.")
     def test_torch_compile_with_group_offload_leaf(self):
-        super()._test_torch_compile_with_group_offload_leaf(
-            quantization_config=self.quantization_config, torch_dtype=torch.float16, use_stream=True
-        )
+        super()._test_torch_compile_with_group_offload_leaf(torch_dtype=torch.float16, use_stream=True)
@@ -654,7 +654,7 @@ class WanVACEGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):


 @require_torch_version_greater("2.7.1")
-class GGUFCompileTests(QuantCompileTests):
+class GGUFCompileTests(QuantCompileTests, unittest.TestCase):
     torch_dtype = torch.bfloat16
     gguf_ckpt = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"

@@ -662,15 +662,6 @@ class GGUFCompileTests(QuantCompileTests):
     def quantization_config(self):
         return GGUFQuantizationConfig(compute_dtype=self.torch_dtype)

-    def test_torch_compile(self):
-        super()._test_torch_compile(quantization_config=self.quantization_config)
-
-    def test_torch_compile_with_cpu_offload(self):
-        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
-
-    def test_torch_compile_with_group_offload_leaf(self):
-        super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)
-
     def _init_pipeline(self, *args, **kwargs):
         transformer = FluxTransformer2DModel.from_single_file(
             self.gguf_ckpt, quantization_config=self.quantization_config, torch_dtype=self.torch_dtype
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import gc
-import unittest
+import inspect

 import torch

@@ -23,7 +23,7 @@ from diffusers.utils.testing_utils import backend_empty_cache, require_torch_gpu

 @require_torch_gpu
 @slow
-class QuantCompileTests(unittest.TestCase):
+class QuantCompileTests:
     @property
     def quantization_config(self):
         raise NotImplementedError(
@@ -50,30 +50,26 @@ class QuantCompileTests(unittest.TestCase):
         )
         return pipe

-    def _test_torch_compile(self, quantization_config, torch_dtype=torch.bfloat16):
-        pipe = self._init_pipeline(quantization_config, torch_dtype).to("cuda")
-        # import to ensure fullgraph True
+    def _test_torch_compile(self, torch_dtype=torch.bfloat16):
+        pipe = self._init_pipeline(self.quantization_config, torch_dtype).to("cuda")
+        # `fullgraph=True` ensures no graph breaks
         pipe.transformer.compile(fullgraph=True)

         for _ in range(2):
-            # small resolutions to ensure speedy execution.
-            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
+            # small resolutions to ensure speedy execution.
+            pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)

-    def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype=torch.bfloat16):
-        pipe = self._init_pipeline(quantization_config, torch_dtype)
+    def _test_torch_compile_with_cpu_offload(self, torch_dtype=torch.bfloat16):
+        pipe = self._init_pipeline(self.quantization_config, torch_dtype)
         pipe.enable_model_cpu_offload()
         pipe.transformer.compile()

         for _ in range(2):
-            # small resolutions to ensure speedy execution.
-            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
+            # small resolutions to ensure speedy execution.
+            pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)

-    def _test_torch_compile_with_group_offload_leaf(
-        self, quantization_config, torch_dtype=torch.bfloat16, *, use_stream: bool = False
-    ):
-        torch._dynamo.config.cache_size_limit = 10000
+    def _test_torch_compile_with_group_offload_leaf(self, torch_dtype=torch.bfloat16, *, use_stream: bool = False):
+        torch._dynamo.config.cache_size_limit = 1000

-        pipe = self._init_pipeline(quantization_config, torch_dtype)
+        pipe = self._init_pipeline(self.quantization_config, torch_dtype)
         group_offload_kwargs = {
             "onload_device": torch.device("cuda"),
             "offload_device": torch.device("cpu"),
@@ -87,6 +83,17 @@ class QuantCompileTests(unittest.TestCase):
             if torch.device(component.device).type == "cpu":
                 component.to("cuda")

         for _ in range(2):
-            # small resolutions to ensure speedy execution.
-            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
+            # small resolutions to ensure speedy execution.
+            pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)

+    def test_torch_compile(self):
+        self._test_torch_compile()
+
+    def test_torch_compile_with_cpu_offload(self):
+        self._test_torch_compile_with_cpu_offload()
+
+    def test_torch_compile_with_group_offload_leaf(self, use_stream=False):
+        for cls in inspect.getmro(self.__class__):
+            if "test_torch_compile_with_group_offload_leaf" in cls.__dict__ and cls is not QuantCompileTests:
+                return
+        self._test_torch_compile_with_group_offload_leaf(use_stream=use_stream)
@@ -630,7 +630,7 @@ class TorchAoSerializationTest(unittest.TestCase):


 @require_torchao_version_greater_or_equal("0.7.0")
-class TorchAoCompileTest(QuantCompileTests):
+class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
     @property
     def quantization_config(self):
         return PipelineQuantizationConfig(
@@ -639,17 +639,15 @@ class TorchAoCompileTest(QuantCompileTests):
         },
     )

-    def test_torch_compile(self):
-        super()._test_torch_compile(quantization_config=self.quantization_config)
-
     @unittest.skip(
         "Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
         "when compiling."
     )
     def test_torch_compile_with_cpu_offload(self):
         # RuntimeError: _apply(): Couldn't swap Linear.weight
-        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
+        super().test_torch_compile_with_cpu_offload()

-    @parameterized.expand([False, True])
     @unittest.skip(
         """
         For `use_stream=False`:
@@ -659,8 +657,7 @@ class TorchAoCompileTest(QuantCompileTests):
         Using non-default stream requires ability to pin tensors. AQT does not seem to support this yet in TorchAO.
         """
     )
+    @parameterized.expand([False, True])
-    def test_torch_compile_with_group_offload_leaf(self):
+    def test_torch_compile_with_group_offload_leaf(self, use_stream):
         # For use_stream=False:
         # If we run group offloading without compilation, we will see:
         # RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.
@@ -673,7 +670,7 @@ class TorchAoCompileTest(QuantCompileTests):

         # For use_stream=True:
         # NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), kwarg_types={}
-        super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)
+        super()._test_torch_compile_with_group_offload_leaf(use_stream=use_stream)

# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners