Compare commits


16 Commits

Author SHA1 Message Date
Sayak Paul
448dc47c9a Merge branch 'main' into fix-hotswapping-ci 2025-07-16 08:20:10 +01:00
G.O.D
5c5209720e enable flux pipeline compatible with unipc and dpm-solver (#11908)
* Update pipeline_flux.py

have flux pipeline work with unipc/dpm schedulers

* clean code

* Update scheduling_dpmsolver_multistep.py

* Update scheduling_unipc_multistep.py

* Update pipeline_flux.py

* Update scheduling_deis_multistep.py

* Update scheduling_dpmsolver_singlestep.py

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
2025-07-15 17:49:57 -10:00
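As a hedged illustration of what this commit enables (the model ID, scheduler kwargs, and prompt below are assumptions, not taken from the PR):

```py
# a minimal sketch: swap Flux's default scheduler for UniPC using the new config flags
import torch
from diffusers import FluxPipeline, UniPCMultistepScheduler

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
pipe.scheduler = UniPCMultistepScheduler.from_config(
    pipe.scheduler.config,
    prediction_type="flow_prediction",  # Flux is a flow-matching model
    use_flow_sigmas=True,
    use_dynamic_shifting=True,          # config flag added in this PR
    time_shift_type="exponential",      # config flag added in this PR
)
image = pipe("a cat holding a sign", num_inference_steps=25).images[0]
```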
Álvaro Somoza
aa14f090f8 [ControlnetUnion] Propagate #11888 to img2img (#11929)
img2img fixes
2025-07-15 21:41:35 -04:00
Guoqing Zhu
c5d6e0b537 Fixed bug: Uncontrolled recursive calls that caused an infinite loop when loading certain pipelines containing Transformer2DModel (#11923)
* fix a bug about loop call

* fix a bug about loop call

* ruff format

---------

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
2025-07-15 14:58:37 -10:00
lostdisc
39831599f1 Remove forced float64 from onnx stable diffusion pipelines (#11054)
* Update pipeline_onnx_stable_diffusion.py to remove float64

init_noise_sigma was being set as float64 before multiplying with latents, which changed latents into float64 too, which caused errors with onnxruntime since the latter wanted float16.

* Update pipeline_onnx_stable_diffusion_inpaint.py to remove float64

init_noise_sigma was being set as float64 before multiplying with latents, which changed latents into float64 too, which caused errors with onnxruntime since the latter wanted float16.

* Update pipeline_onnx_stable_diffusion_upscale.py to remove float64

init_noise_sigma was being set as float64 before multiplying with latents, which changed latents into float64 too, which caused errors with onnxruntime since the latter wanted float16.

* Update pipeline_onnx_stable_diffusion.py with comment for previous commit

Added a comment on the purpose of init_noise_sigma. This comment exists in related scripts that use the same line of code, but it was missing here.

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-07-15 14:57:28 -10:00
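The dtype promotion described in these messages, in miniature (a sketch under NumPy 2.x promotion rules; the sigma value is illustrative):

```py
import numpy as np

latents = np.random.randn(1, 4, 64, 64).astype(np.float16)
sigma = np.float64(14.6146)  # illustrative init_noise_sigma

print((latents * sigma).dtype)                        # float64: the float64 scalar promotes the product
print((latents * sigma.astype(latents.dtype)).dtype)  # float16: what onnxruntime expects
```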
Aryan
b73c738392 Remove device synchronization when loading weights (#11927)
* update

* make style
2025-07-15 21:40:57 +05:30
Aryan
06fd427797 [tests] Improve Flux tests (#11919)
update
2025-07-15 10:47:41 +05:30
dependabot[bot]
48a551251d Bump aiohttp from 3.10.10 to 3.12.14 in /examples/server (#11924)
Bumps [aiohttp](https://github.com/aio-libs/aiohttp) from 3.10.10 to 3.12.14.
- [Release notes](https://github.com/aio-libs/aiohttp/releases)
- [Changelog](https://github.com/aio-libs/aiohttp/blob/master/CHANGES.rst)
- [Commits](https://github.com/aio-libs/aiohttp/compare/v3.10.10...v3.12.14)

---
updated-dependencies:
- dependency-name: aiohttp
  dependency-version: 3.12.14
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-07-15 09:15:57 +05:30
Hengyue-Bi
6398fbc391 Fix: Align VAE processing in ControlNet SD3 training with inference (#11909)
Fix: Apply vae_shift_factor in ControlNet SD3 training
2025-07-14 14:54:38 -04:00
Colle
3c8b67b371 Flux: pass joint_attention_kwargs when using gradient_checkpointing (#11814)
Flux: pass joint_attention_kwargs when gradient_checkpointing
2025-07-11 08:35:18 -10:00
Steven Liu
9feb946432 [docs] torch.compile blog post (#11837)
* add blog post

* feedback

* feedback
2025-07-11 10:29:40 -07:00
Aryan
c90352754a Speedup model loading by 4-5x (#11904)
* update

* update

* update

* pin accelerate version

* add comment explanations

* update docstring

* make style

* non_blocking does not matter for dtype cast

* _empty_cache -> clear_cache

* update

* Update src/diffusers/models/model_loading_utils.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/diffusers/models/model_loading_utils.py

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
2025-07-11 21:43:53 +05:30
Sayak Paul
7a935a0bbe [tests] Unify compilation + offloading tests in quantization (#11910)
* unify the quant compile + offloading tests.

* fix

* update
2025-07-11 17:02:29 +05:30
chenxiao
941b7fc084 Avoid creating tensor in CosmosAttnProcessor2_0 (#11761) (#11763)
* Avoid creating tensor in CosmosAttnProcessor2_0 (#11761)

* up

---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
2025-07-10 11:51:05 -10:00
Álvaro Somoza
76a62ac9cc [ControlnetUnion] Multiple Fixes (#11888)
fixes

---------

Co-authored-by: hlky <hlky@hlky.ac>
2025-07-10 14:35:28 -04:00
sayakpaul
d2cbc6365f enable hotswapping tests on our nightly CI. 2025-06-28 13:20:44 +05:30
32 changed files with 476 additions and 306 deletions


@@ -193,24 +193,28 @@ jobs:
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test,training]
python -m uv pip install peft@git+https://github.com/huggingface/peft.git
- name: Environment
run: |
python utils/print_env.py
- name: Run torch compile tests on GPU
- name: Run torch hotswap + compile tests on GPU
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
RUN_COMPILE: yes
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "hotswap" --make-reports=tests_torch_hotswap_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_compile_cuda_failures_short.txt
run: |
cat reports/tests_torch_compile_cuda_failures_short.txt
cat reports/tests_torch_hotswap_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: torch_compile_test_reports
name: torch_compile_hotswap_test_reports
path: reports
run_big_gpu_torch_tests:


@@ -189,6 +189,7 @@ jobs:
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
RUN_COMPILE: yes
RUN_SLOW: yes
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
- name: Failure short reports


@@ -174,39 +174,36 @@ Feel free to open an issue if dynamic compilation doesn't work as expected for a
### Regional compilation
[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by only compiling the *small and frequently-repeated block(s)* of a model - typically a transformer layer - and enables reusing compiled artifacts for every subsequent occurrence.
For many diffusion architectures, this delivers the same runtime speedups as full-graph compilation and reduces compile time by 8-10x.
[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by compiling **only the small, frequently-repeated block(s)** of a model, typically a Transformer layer, enabling reuse of compiled artifacts for every subsequent occurrence.
For many diffusion architectures this delivers the *same* runtime speed-ups as full-graph compilation yet cuts compile time by **8-10×**.
To make this effortless, [`ModelMixin`] exposes [`ModelMixin.compile_repeated_blocks`] API, a helper that wraps `torch.compile` around any sub-modules you designate as repeatable:
Use the [`~ModelMixin.compile_repeated_blocks`] method, a helper that wraps `torch.compile`, on any component such as the transformer model as shown below.
```py
# pip install -U diffusers
import torch
from diffusers import StableDiffusionXLPipeline
pipe = StableDiffusionXLPipeline.from_pretrained(
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
).to("cuda")
# Compile only the repeated Transformer layers inside the UNet
pipe.unet.compile_repeated_blocks(fullgraph=True)
# compile only the repeated transformer layers inside the UNet
pipeline.unet.compile_repeated_blocks(fullgraph=True)
```
To enable a new model with regional compilation, add a `_repeated_blocks` attribute to your model class containing the class names (as strings) of the blocks you want compiled:
To enable regional compilation for a new model, add a `_repeated_blocks` attribute to a model class containing the class names (as strings) of the blocks you want to compile.
```py
class MyUNet(ModelMixin):
_repeated_blocks = ("Transformer2DModel",) # ← compiled by default
```
For more examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).
**Relation to Accelerate compile_regions** There is also a separate API in [accelerate](https://huggingface.co/docs/accelerate/index) - [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78). It takes a fully automatic approach: it walks the module, picks candidate blocks, then compiles the remaining graph separately. That hands-off experience is handy for quick experiments, but it also leaves fewer knobs when you want to fine-tune which blocks are compiled or adjust compilation flags.
> [!TIP]
> For more regional compilation examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).
There is also a [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78) method in [Accelerate](https://huggingface.co/docs/accelerate/index) that automatically selects candidate blocks in a model to compile. The remaining graph is compiled separately. This hands-off approach is useful for quick experiments, but it leaves fewer options for choosing which blocks to compile or for adjusting compilation flags.
```py
# pip install -U accelerate
@@ -219,8 +216,8 @@ pipeline = StableDiffusionXLPipeline.from_pretrained(
).to("cuda")
pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph=True)
```
`compile_repeated_blocks`, by contrast, is intentionally explicit. You list the repeated blocks once (via `_repeated_blocks`) and the helper compiles exactly those, nothing more. In practice this small dose of control hits a sweet spot for diffusion models: predictable behavior, easy reasoning about cache reuse, and still a one-liner for users.
[`~ModelMixin.compile_repeated_blocks`] is intentionally explicit. List the blocks to repeat in `_repeated_blocks` and the helper only compiles those blocks. It offers predictable behavior and easy reasoning about cache reuse in one line of code.
### Graph breaks
@@ -296,3 +293,9 @@ An input is projected into three subspaces, represented by the projection matric
```py
pipeline.fuse_qkv_projections()
```
## Resources
- Read the [Presenting Flux Fast: Making Flux go brrr on H100s](https://pytorch.org/blog/presenting-flux-fast-making-flux-go-brrr-on-h100s/) blog post to learn more about how you can combine all of these optimizations with [TorchInductor](https://docs.pytorch.org/docs/stable/torch.compiler.html) and [AOTInductor](https://docs.pytorch.org/docs/stable/torch.compiler_aot_inductor.html) for a ~2.5x speedup using recipes from [flux-fast](https://github.com/huggingface/flux-fast).
These recipes support AMD hardware and [Flux.1 Kontext Dev](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev).


@@ -14,6 +14,9 @@ specific language governing permissions and limitations under the License.
Optimizing models often involves trade-offs between [inference speed](./fp16) and [memory-usage](./memory). For instance, while [caching](./cache) can boost inference speed, it also increases memory consumption since it needs to store the outputs of intermediate attention layers. A more balanced optimization strategy combines quantizing a model, [torch.compile](./fp16#torchcompile) and various [offloading methods](./memory#offloading).
> [!TIP]
> Check the [torch.compile](./fp16#torchcompile) guide to learn more about compilation and how they can be applied here. For example, regional compilation can significantly reduce compilation time without giving up any speedups.
For image generation, combining quantization and [model offloading](./memory#model-offloading) can often give the best trade-off between quality, speed, and memory. Group offloading is not as effective for image generation because it is usually not possible to *fully* overlap data transfer if the compute kernel finishes faster. This results in some communication overhead between the CPU and GPU.
For video generation, combining quantization and [group-offloading](./memory#group-offloading) tends to be better because video models are more compute-bound.
@@ -25,7 +28,7 @@ The table below provides a comparison of optimization strategy combinations and
| quantization | 32.602 | 14.9453 |
| quantization, torch.compile | 25.847 | 14.9448 |
| quantization, torch.compile, model CPU offloading | 32.312 | 12.2369 |
<small>These results are benchmarked on Flux with a RTX 4090. The transformer and text_encoder components are quantized. Refer to the <a href="https://gist.github.com/sayakpaul/0db9d8eeeb3d2a0e5ed7cf0d9ca19b7d" benchmarking script</a> if you're interested in evaluating your own model.</small>
<small>These results are benchmarked on Flux with a RTX 4090. The transformer and text_encoder components are quantized. Refer to the [benchmarking script](https://gist.github.com/sayakpaul/0db9d8eeeb3d2a0e5ed7cf0d9ca19b7d) if you're interested in evaluating your own model.</small>
This guide will show you how to compile and offload a quantized model with [bitsandbytes](../quantization/bitsandbytes#torchcompile). Make sure you are using [PyTorch nightly](https://pytorch.org/get-started/locally/) and the latest version of bitsandbytes.
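A condensed sketch of the combination this guide describes, assuming bitsandbytes is installed and a diffusers version that ships `PipelineQuantizationConfig`; exact kwargs may differ from the full guide:

```py
import torch
from diffusers import FluxPipeline
from diffusers.quantizers import PipelineQuantizationConfig

quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
    components_to_quantize=["transformer", "text_encoder_2"],
)
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", quantization_config=quant_config, torch_dtype=torch.bfloat16
)
pipeline.enable_model_cpu_offload()                           # model offloading
pipeline.transformer.compile_repeated_blocks(fullgraph=True)  # regional compilation
image = pipeline("a sunset over mountains", num_inference_steps=28).images[0]
```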


@@ -1330,7 +1330,7 @@ def main(args):
# controlnet(s) inference
controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
controlnet_image = vae.encode(controlnet_image).latent_dist.sample()
controlnet_image = controlnet_image * vae.config.scaling_factor
controlnet_image = (controlnet_image - vae.config.shift_factor) * vae.config.scaling_factor
control_block_res_samples = controlnet(
hidden_states=noisy_model_input,
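The normalization being aligned, as a self-contained toy; the shift/scale constants are the SD3 VAE's config values quoted from memory, so treat them as illustrative:

```py
import torch

# encode side applies z_norm = (z - shift_factor) * scaling_factor, as in the fixed line above;
# decode side inverts it, so the round trip is lossless up to float error
shift_factor, scaling_factor = 0.0609, 1.5305
z = torch.randn(1, 16, 64, 64)
z_norm = (z - shift_factor) * scaling_factor
z_back = z_norm / scaling_factor + shift_factor
assert torch.allclose(z, z_back, atol=1e-5)
```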


@@ -1,10 +1,10 @@
# This file was autogenerated by uv via the following command:
# uv pip compile requirements.in -o requirements.txt
aiohappyeyeballs==2.4.3
aiohappyeyeballs==2.6.1
# via aiohttp
aiohttp==3.10.10
aiohttp==3.12.14
# via -r requirements.in
aiosignal==1.3.1
aiosignal==1.4.0
# via aiohttp
annotated-types==0.7.0
# via pydantic
@@ -29,7 +29,6 @@ filelock==3.16.1
# huggingface-hub
# torch
# transformers
# triton
frozenlist==1.5.0
# via
# aiohttp
@@ -111,7 +110,9 @@ prometheus-client==0.21.0
prometheus-fastapi-instrumentator==7.0.0
# via -r requirements.in
propcache==0.2.0
# via yarl
# via
# aiohttp
# yarl
py-consul==1.5.3
# via -r requirements.in
pydantic==2.9.2
@@ -155,7 +156,9 @@ triton==3.3.0
# via torch
typing-extensions==4.12.2
# via
# aiosignal
# anyio
# exceptiongroup
# fastapi
# huggingface-hub
# multidict
@@ -168,5 +171,5 @@ urllib3==2.5.0
# via requests
uvicorn==0.32.0
# via -r requirements.in
yarl==1.16.0
yarl==1.18.3
# via aiohttp


@@ -763,4 +763,7 @@ class LegacyConfigMixin(ConfigMixin):
# resolve remapping
remapped_class = _fetch_remapped_cls_from_config(config, cls)
return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
if remapped_class is cls:
return super(LegacyConfigMixin, remapped_class).from_config(config, return_unused_kwargs, **kwargs)
else:
return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
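The recursion guard in miniature, with hypothetical classes; only the control flow mirrors the fix:

```py
class ConfigBase:
    @classmethod
    def from_config(cls, config):
        return f"{cls.__name__} built from {config['class_name']}"

class LegacyConfig(ConfigBase):
    @classmethod
    def from_config(cls, config):
        remapped_class = cls  # hypothetical: remapping resolved back to the same class
        if remapped_class is cls:
            # without this branch, remapped_class.from_config would re-enter
            # LegacyConfig.from_config and recurse forever
            return super(LegacyConfig, remapped_class).from_config(config)
        return remapped_class.from_config(config)

print(LegacyConfig.from_config({"class_name": "Transformer2DModel"}))
```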


@@ -24,6 +24,7 @@ from typing_extensions import Self
from .. import __version__
from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, logging
from ..utils.torch_utils import empty_device_cache
from .single_file_utils import (
SingleFileComponentError,
convert_animatediff_checkpoint_to_diffusers,
@@ -430,6 +431,7 @@ class FromOriginalModelMixin:
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
)
empty_device_cache()
else:
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)


@@ -46,6 +46,7 @@ from ..utils import (
)
from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
from ..utils.hub_utils import _get_model_file
from ..utils.torch_utils import empty_device_cache
if is_transformers_available():
@@ -1689,6 +1690,7 @@ def create_diffusers_clip_model_from_ldm(
if is_accelerate_available():
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
empty_device_cache()
else:
model.load_state_dict(diffusers_format_checkpoint, strict=False)
@@ -2148,6 +2150,7 @@ def create_diffusers_t5_model_from_checkpoint(
if is_accelerate_available():
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
empty_device_cache()
else:
model.load_state_dict(diffusers_format_checkpoint)


@@ -18,11 +18,8 @@ from ..models.embeddings import (
MultiIPAdapterImageProjection,
)
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import (
is_accelerate_available,
is_torch_version,
logging,
)
from ..utils import is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import empty_device_cache
if is_accelerate_available():
@@ -84,6 +81,7 @@ class FluxTransformer2DLoadersMixin:
else:
device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache()
return image_projection
@@ -158,6 +156,8 @@ class FluxTransformer2DLoadersMixin:
key_id += 1
empty_device_cache()
return attn_procs
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):


@@ -18,6 +18,7 @@ from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
from ..models.embeddings import IPAdapterTimeImageProjection
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import empty_device_cache
logger = logging.get_logger(__name__)
@@ -80,6 +81,8 @@ class SD3Transformer2DLoadersMixin:
attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
)
empty_device_cache()
return attn_procs
def _convert_ip_adapter_image_proj_to_diffusers(
@@ -147,6 +150,7 @@ class SD3Transformer2DLoadersMixin:
else:
device_map = {"": self.device}
load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache()
return image_proj


@@ -43,6 +43,7 @@ from ..utils import (
is_torch_version,
logging,
)
from ..utils.torch_utils import empty_device_cache
from .lora_base import _func_optionally_disable_offloading
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
from .utils import AttnProcsLayers
@@ -753,6 +754,7 @@ class UNet2DConditionLoadersMixin:
else:
device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache()
return image_projection
@@ -850,6 +852,8 @@ class UNet2DConditionLoadersMixin:
key_id += 2
empty_device_cache()
return attn_procs
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):


@@ -16,9 +16,10 @@
import importlib
import inspect
import math
import os
from array import array
from collections import OrderedDict
from collections import OrderedDict, defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Union
from zipfile import is_zipfile
@@ -38,6 +39,7 @@ from ..utils import (
_get_model_file,
deprecate,
is_accelerate_available,
is_accelerate_version,
is_gguf_available,
is_torch_available,
is_torch_version,
@@ -252,6 +254,10 @@ def load_model_dict_into_meta(
param = param.to(dtype)
set_module_kwargs["dtype"] = dtype
if is_accelerate_version(">", "1.8.1"):
set_module_kwargs["non_blocking"] = True
set_module_kwargs["clear_cache"] = False
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
@@ -520,3 +526,60 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights
return parsed_parameters
def _find_mismatched_keys(state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes):
mismatched_keys = []
if not ignore_mismatched_sizes:
return mismatched_keys
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
# If the checkpoint is sharded, we may not have the key here.
if checkpoint_key not in state_dict:
continue
if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
return mismatched_keys
def _expand_device_map(device_map, param_names):
"""
Expand a device map to return the correspondence parameter name to device.
"""
new_device_map = {}
for module, device in device_map.items():
new_device_map.update(
{p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
)
return new_device_map
# Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
def _caching_allocator_warmup(model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype) -> None:
"""
This function warms up the caching allocator based on the size of the model tensors that will reside on each
device. It allows one large call to malloc instead of many individual calls later when loading the model,
which is the actual loading-speed bottleneck. Calling this function cuts the model loading time by a
very large margin.
"""
# Remove disk and cpu devices, and cast to proper torch.device
accelerator_device_map = {
param: torch.device(device)
for param, device in expanded_device_map.items()
if str(device) not in ["cpu", "disk"]
}
parameter_count = defaultdict(lambda: 0)
for param_name, device in accelerator_device_map.items():
try:
param = model.get_parameter(param_name)
except AttributeError:
param = model.get_buffer(param_name)
parameter_count[device] += math.prod(param.shape)
# This will kick off the caching allocator to avoid having to call malloc afterwards
for device, param_count in parameter_count.items():
_ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
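The warmup idea as a standalone sketch (hypothetical shapes, CUDA only): one large `torch.empty` call primes the allocator pool that later per-tensor allocations draw from.

```py
import math
import time

import torch

shapes = [(2048, 2048)] * 32  # stand-in for a model's parameter shapes

# one large empty allocation primes the caching allocator, then is released
# back into its pool so the per-tensor allocations below can reuse it
warmup = torch.empty(sum(math.prod(s) for s in shapes), dtype=torch.float16, device="cuda")
del warmup

t0 = time.perf_counter()
params = [torch.empty(s, dtype=torch.float16, device="cuda") for s in shapes]
torch.cuda.synchronize()
print(f"allocated {len(params)} tensors in {time.perf_counter() - t0:.4f}s")
```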


@@ -62,10 +62,14 @@ from ..utils.hub_utils import (
load_or_create_model_card,
populate_model_card,
)
from ..utils.torch_utils import empty_device_cache
from .model_loading_utils import (
_caching_allocator_warmup,
_determine_device_map,
_expand_device_map,
_fetch_index_file,
_fetch_index_file_legacy,
_find_mismatched_keys,
_load_state_dict_into_model,
load_model_dict_into_meta,
load_state_dict,
@@ -1469,11 +1473,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
mismatched_keys = []
assign_to_params_buffers = None
error_msgs = []
# Deal with offload
if device_map is not None and "disk" in device_map.values():
if offload_folder is None:
@@ -1482,18 +1481,27 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
" for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
" offers the weights in this format."
)
if offload_folder is not None:
else:
os.makedirs(offload_folder, exist_ok=True)
if offload_state_dict is None:
offload_state_dict = True
# If a device map has been used, we can speedup the load time by warming up the device caching allocator.
# If we don't warmup, each tensor allocation on device calls to the allocator for memory (effectively, a
# lot of individual calls to device malloc). We can, however, preallocate the memory required by the
# tensors using their expected shape and not performing any initialization of the memory (empty data).
# When the actual device allocations happen, the allocator already has a pool of unused device memory
# that it can re-use for faster loading of the model.
# TODO: add support for warmup with hf_quantizer
if device_map is not None and hf_quantizer is None:
expanded_device_map = _expand_device_map(device_map, expected_keys)
_caching_allocator_warmup(model, expanded_device_map, dtype)
offload_index = {} if device_map is not None and "disk" in device_map.values() else None
state_dict_folder, state_dict_index = None, None
if offload_state_dict:
state_dict_folder = tempfile.mkdtemp()
state_dict_index = {}
else:
state_dict_folder = None
state_dict_index = None
if state_dict is not None:
# load_state_dict will manage the case where we pass a dict instead of a file
@@ -1503,38 +1511,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
if len(resolved_model_file) > 1:
resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
mismatched_keys = []
assign_to_params_buffers = None
error_msgs = []
for shard_file in resolved_model_file:
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
def _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
ignore_mismatched_sizes,
):
mismatched_keys = []
if ignore_mismatched_sizes:
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
# If the checkpoint is sharded, we may not have the key here.
if checkpoint_key not in state_dict:
continue
if (
model_key in model_state_dict
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
):
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
return mismatched_keys
mismatched_keys += _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
ignore_mismatched_sizes,
state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes
)
if low_cpu_mem_usage:
@@ -1554,9 +1538,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
else:
if assign_to_params_buffers is None:
assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
empty_device_cache()
if offload_index is not None and len(offload_index) > 0:
save_offload_index(offload_index, offload_folder)
offload_index = None
@@ -1892,4 +1877,9 @@ class LegacyModelMixin(ModelMixin):
# resolve remapping
remapped_class = _fetch_remapped_cls_from_config(config, cls)
return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
if remapped_class is cls:
return super(LegacyModelMixin, remapped_class).from_pretrained(
pretrained_model_name_or_path, **kwargs_copy
)
else:
return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)


@@ -187,9 +187,15 @@ class CosmosAttnProcessor2_0:
key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
# 4. Prepare for GQA
query_idx = torch.tensor(query.size(3), device=query.device)
key_idx = torch.tensor(key.size(3), device=key.device)
value_idx = torch.tensor(value.size(3), device=value.device)
if torch.onnx.is_in_onnx_export():
query_idx = torch.tensor(query.size(3), device=query.device)
key_idx = torch.tensor(key.size(3), device=key.device)
value_idx = torch.tensor(value.size(3), device=value.device)
else:
query_idx = query.size(3)
key_idx = key.size(3)
value_idx = value.size(3)
key = key.repeat_interleave(query_idx // key_idx, dim=3)
value = value.repeat_interleave(query_idx // value_idx, dim=3)
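What this branch computes, as a toy; the head counts are assumptions, and the head axis sits at dim 3 as in the processor above:

```py
import torch

query = torch.randn(1, 128, 64, 8)  # (batch, seq, head_dim, q_heads), heads on the last axis
key = torch.randn(1, 128, 64, 2)    # 2 KV heads shared across 8 query heads (GQA)

q_heads, kv_heads = query.size(3), key.size(3)  # plain Python ints: no tensor created on the hot path
key = key.repeat_interleave(q_heads // kv_heads, dim=3)
assert key.size(3) == q_heads
```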


@@ -490,6 +490,7 @@ class FluxTransformer2DModel(
encoder_hidden_states,
temb,
image_rotary_emb,
joint_attention_kwargs,
)
else:
@@ -521,6 +522,7 @@ class FluxTransformer2DModel(
encoder_hidden_states,
temb,
image_rotary_emb,
joint_attention_kwargs,
)
else:


@@ -19,7 +19,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
@@ -38,7 +37,13 @@ from ...loaders import (
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, ImageProjection, UNet2DConditionModel
from ...models import (
AutoencoderKL,
ControlNetUnionModel,
ImageProjection,
MultiControlNetUnionModel,
UNet2DConditionModel,
)
from ...models.attention_processor import (
AttnProcessor2_0,
XFormersAttnProcessor,
@@ -262,7 +267,9 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
tokenizer: CLIPTokenizer,
tokenizer_2: CLIPTokenizer,
unet: UNet2DConditionModel,
controlnet: ControlNetUnionModel,
controlnet: Union[
ControlNetUnionModel, List[ControlNetUnionModel], Tuple[ControlNetUnionModel], MultiControlNetUnionModel
],
scheduler: KarrasDiffusionSchedulers,
requires_aesthetics_score: bool = False,
force_zeros_for_empty_prompt: bool = True,
@@ -272,8 +279,8 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
):
super().__init__()
if not isinstance(controlnet, ControlNetUnionModel):
raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.")
if isinstance(controlnet, (list, tuple)):
controlnet = MultiControlNetUnionModel(controlnet)
self.register_modules(
vae=vae,
@@ -649,6 +656,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
controlnet_conditioning_scale=1.0,
control_guidance_start=0.0,
control_guidance_end=1.0,
control_mode=None,
callback_on_step_end_tensor_inputs=None,
):
if strength < 0 or strength > 1:
@@ -722,28 +730,44 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
# `prompt` needs more sophisticated handling when there are multiple
# conditionings.
if isinstance(self.controlnet, MultiControlNetUnionModel):
if isinstance(prompt, list):
logger.warning(
f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
" prompts. The conditionings will be fixed across the prompts."
)
# Check `image`
is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
self.controlnet, torch._dynamo.eval_frame.OptimizedModule
)
if (
isinstance(self.controlnet, ControlNetModel)
or is_compiled
and isinstance(self.controlnet._orig_mod, ControlNetModel)
):
self.check_image(image, prompt, prompt_embeds)
elif (
isinstance(self.controlnet, ControlNetUnionModel)
or is_compiled
and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
):
self.check_image(image, prompt, prompt_embeds)
else:
assert False
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
if isinstance(controlnet, ControlNetUnionModel):
for image_ in image:
self.check_image(image_, prompt, prompt_embeds)
elif isinstance(controlnet, MultiControlNetUnionModel):
if not isinstance(image, list):
raise TypeError("For multiple controlnets: `image` must be type `list`")
elif not all(isinstance(i, list) for i in image):
raise ValueError("For multiple controlnets: elements of `image` must be list of conditionings.")
elif len(image) != len(self.controlnet.nets):
raise ValueError(
f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
)
for images_ in image:
for image_ in images_:
self.check_image(image_, prompt, prompt_embeds)
if not isinstance(control_guidance_start, (tuple, list)):
control_guidance_start = [control_guidance_start]
if isinstance(controlnet, MultiControlNetUnionModel):
if len(control_guidance_start) != len(self.controlnet.nets):
raise ValueError(
f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
)
if not isinstance(control_guidance_end, (tuple, list)):
control_guidance_end = [control_guidance_end]
@@ -762,6 +786,15 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
if end > 1.0:
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
# Check `control_mode`
if isinstance(controlnet, ControlNetUnionModel):
if max(control_mode) >= controlnet.config.num_control_type:
raise ValueError(f"control_mode: must be lower than {controlnet.config.num_control_type}.")
elif isinstance(controlnet, MultiControlNetUnionModel):
for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets):
if max(_control_mode) >= _controlnet.config.num_control_type:
raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.")
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
@@ -1049,7 +1082,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
image: PipelineImageInput = None,
control_image: PipelineImageInput = None,
control_image: Union[PipelineImageInput, List[PipelineImageInput]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
strength: float = 0.8,
@@ -1074,7 +1107,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
control_mode: Optional[Union[int, List[int]]] = None,
control_mode: Optional[Union[int, List[int], List[List[int]]]] = None,
original_size: Tuple[int, int] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Tuple[int, int] = None,
@@ -1104,13 +1137,13 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
`List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The initial image will be used as the starting point for the image generation process. Can also accept
image latents as `image`, if passing latents directly, it will not be encoded again.
control_image (`PipelineImageInput`):
The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also
be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in
init, images must be passed as a list such that each element of the list can be correctly batched for
input to a single controlnet.
control_image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*):
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
as an image. The dimensions of the output image default to `image`'s dimensions. If height and/or
width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
images must be passed as a list such that each element of the list can be correctly batched for input
to a single ControlNet.
height (`int`, *optional*, defaults to the size of control_image):
The height in pixels of the generated image. Anything below 512 pixels won't work well for
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
@@ -1184,16 +1217,21 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
corresponding scale as a list.
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
the corresponding scale as a list.
guess_mode (`bool`, *optional*, defaults to `False`):
In this mode, the ControlNet encoder will try its best to recognize the content of the input image even if
you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
The percentage of total steps at which the controlnet starts applying.
The percentage of total steps at which the ControlNet starts applying.
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
The percentage of total steps at which the controlnet stops applying.
The percentage of total steps at which the ControlNet stops applying.
control_mode (`int` or `List[int]` or `List[List[int]]`, *optional*):
The control condition types for the ControlNet. See the ControlNet's model card for information on the
available control modes. If multiple ControlNets are specified in `init`, `control_mode` should be a list
where each ControlNet has its corresponding control mode list, reflecting the order of
conditions in `control_image`.
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
@@ -1273,12 +1311,6 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
if not isinstance(control_image, list):
control_image = [control_image]
else:
@@ -1287,37 +1319,56 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
if not isinstance(control_mode, list):
control_mode = [control_mode]
if len(control_image) != len(control_mode):
raise ValueError("Expected len(control_image) == len(control_type)")
if isinstance(controlnet, MultiControlNetUnionModel):
control_image = [[item] for item in control_image]
control_mode = [[item] for item in control_mode]
num_control_type = controlnet.config.num_control_type
# 1. Check inputs
control_type = [0 for _ in range(num_control_type)]
for _image, control_idx in zip(control_image, control_mode):
control_type[control_idx] = 1
self.check_inputs(
prompt,
prompt_2,
_image,
strength,
num_inference_steps,
callback_steps,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
ip_adapter_image,
ip_adapter_image_embeds,
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
callback_on_step_end_tensor_inputs,
# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
control_guidance_start, control_guidance_end = (
mult * [control_guidance_start],
mult * [control_guidance_end],
)
control_type = torch.Tensor(control_type)
if isinstance(controlnet_conditioning_scale, float):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult
# 1. Check inputs
self.check_inputs(
prompt,
prompt_2,
control_image,
strength,
num_inference_steps,
callback_steps,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
ip_adapter_image,
ip_adapter_image_embeds,
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
control_mode,
callback_on_step_end_tensor_inputs,
)
if isinstance(controlnet, ControlNetUnionModel):
control_type = torch.zeros(controlnet.config.num_control_type).scatter_(0, torch.tensor(control_mode), 1)
elif isinstance(controlnet, MultiControlNetUnionModel):
control_type = [
torch.zeros(controlnet_.config.num_control_type).scatter_(0, torch.tensor(control_mode_), 1)
for control_mode_, controlnet_ in zip(control_mode, self.controlnet.nets)
]
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
@@ -1334,7 +1385,11 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
device = self._execution_device
global_pool_conditions = controlnet.config.global_pool_conditions
global_pool_conditions = (
controlnet.config.global_pool_conditions
if isinstance(controlnet, ControlNetUnionModel)
else controlnet.nets[0].config.global_pool_conditions
)
guess_mode = guess_mode or global_pool_conditions
# 3.1. Encode input prompt
@@ -1372,22 +1427,55 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
self.do_classifier_free_guidance,
)
# 4. Prepare image and controlnet_conditioning_image
# 4.1 Prepare image
image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
for idx, _ in enumerate(control_image):
control_image[idx] = self.prepare_control_image(
image=control_image[idx],
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
height, width = control_image[idx].shape[-2:]
# 4.2 Prepare control images
if isinstance(controlnet, ControlNetUnionModel):
control_images = []
for image_ in control_image:
image_ = self.prepare_control_image(
image=image_,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
control_images.append(image_)
control_image = control_images
height, width = control_image[0].shape[-2:]
elif isinstance(controlnet, MultiControlNetUnionModel):
control_images = []
for control_image_ in control_image:
images = []
for image_ in control_image_:
image_ = self.prepare_control_image(
image=image_,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
images.append(image_)
control_images.append(images)
control_image = control_images
height, width = control_image[0][0].shape[-2:]
# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -1414,10 +1502,11 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
# 7.1 Create tensor stating which controlnets to keep
controlnet_keep = []
for i in range(len(timesteps)):
controlnet_keep.append(
1.0
- float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end)
)
keeps = [
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
for s, e in zip(control_guidance_start, control_guidance_end)
]
controlnet_keep.append(keeps)
# 7.2 Prepare added time ids & embeddings
original_size = original_size or (height, width)
@@ -1460,12 +1549,25 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
prompt_embeds = prompt_embeds.to(device)
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device)
control_type = (
control_type.reshape(1, -1)
.to(device, dtype=prompt_embeds.dtype)
.repeat(batch_size * num_images_per_prompt * 2, 1)
control_type_repeat_factor = (
batch_size * num_images_per_prompt * (2 if self.do_classifier_free_guidance else 1)
)
if isinstance(controlnet, ControlNetUnionModel):
control_type = (
control_type.reshape(1, -1)
.to(self._execution_device, dtype=prompt_embeds.dtype)
.repeat(control_type_repeat_factor, 1)
)
elif isinstance(controlnet, MultiControlNetUnionModel):
control_type = [
_control_type.reshape(1, -1)
.to(self._execution_device, dtype=prompt_embeds.dtype)
.repeat(control_type_repeat_factor, 1)
for _control_type in control_type
]
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
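A hypothetical usage sketch of the multi-ControlNetUnion path added here; the repository IDs, image URLs, and `control_mode` indices are assumptions, not taken from this PR:

```py
import torch
from diffusers import ControlNetUnionModel, StableDiffusionXLControlNetUnionImg2ImgPipeline
from diffusers.utils import load_image

controlnet = ControlNetUnionModel.from_pretrained(
    "xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16  # assumed checkpoint
)
pipe = StableDiffusionXLControlNetUnionImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=[controlnet, controlnet],  # a list/tuple is now wrapped into MultiControlNetUnionModel
    torch_dtype=torch.float16,
).to("cuda")

init_image = load_image("https://example.com/init.png")    # placeholder URL
depth_image = load_image("https://example.com/depth.png")  # placeholder URL
pose_image = load_image("https://example.com/pose.png")    # placeholder URL

image = pipe(
    prompt="a futuristic cityscape",
    image=init_image,
    control_image=[[depth_image], [pose_image]],  # one list of conditions per ControlNet
    control_mode=[[1], [0]],                      # per-net control types (indices assumed)
).images[0]
```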


@@ -840,6 +840,8 @@ class FluxPipeline(
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
sigmas = None
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,


@@ -383,7 +383,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)
latents = latents * np.float64(self.scheduler.init_noise_sigma)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.


@@ -483,7 +483,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
self.scheduler.set_timesteps(num_inference_steps)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * np.float64(self.scheduler.init_noise_sigma)
latents = latents * self.scheduler.init_noise_sigma
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.


@@ -481,7 +481,7 @@ class OnnxStableDiffusionUpscalePipeline(DiffusionPipeline):
timesteps = self.scheduler.timesteps
# Scale the initial noise by the standard deviation required by the scheduler
latents = latents * np.float64(self.scheduler.init_noise_sigma)
latents = latents * self.scheduler.init_noise_sigma
# 5. Add noise to image
noise_level = np.array([noise_level]).astype(np.int64)


@@ -153,6 +153,8 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
flow_shift: Optional[float] = 1.0,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
use_dynamic_shifting: bool = False,
time_shift_type: str = "exponential",
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -232,7 +234,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
self._begin_index = begin_index
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
def set_timesteps(
self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -242,6 +246,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
if mu is not None:
assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
self.config.flow_shift = np.exp(mu)
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
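How `mu` feeds the schedule, sketched numerically; the warp formula is the exponential flow-matching shift used elsewhere in diffusers, and its application here is an assumption:

```py
import numpy as np

mu = 1.15                # illustrative value, e.g. from calculate_shift(image_seq_len)
flow_shift = np.exp(mu)  # what set_timesteps stores when mu is passed
sigmas = np.linspace(1.0, 1 / 25, 25)
shifted = flow_shift * sigmas / (1 + (flow_shift - 1) * sigmas)
print(shifted[:3])       # the schedule is warped toward higher noise levels
```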


@@ -230,6 +230,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
timestep_spacing: str = "linspace",
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
use_dynamic_shifting: bool = False,
time_shift_type: str = "exponential",
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -330,6 +332,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self,
num_inference_steps: int = None,
device: Union[str, torch.device] = None,
mu: Optional[float] = None,
timesteps: Optional[List[int]] = None,
):
"""
@@ -345,6 +348,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
must be `None`, and `timestep_spacing` attribute will be ignored.
"""
if mu is not None:
assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
self.config.flow_shift = np.exp(mu)
if num_inference_steps is None and timesteps is None:
raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
if num_inference_steps is not None and timesteps is not None:


@@ -169,6 +169,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
use_dynamic_shifting: bool = False,
time_shift_type: str = "exponential",
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -301,6 +303,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
self,
num_inference_steps: int = None,
device: Union[str, torch.device] = None,
mu: Optional[float] = None,
timesteps: Optional[List[int]] = None,
):
"""
@@ -316,6 +319,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
timestep spacing strategy of equal spacing between timesteps schedule is used. If `timesteps` is
passed, `num_inference_steps` must be `None`.
"""
if mu is not None:
assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
self.config.flow_shift = np.exp(mu)
if num_inference_steps is None and timesteps is None:
raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
if num_inference_steps is not None and timesteps is not None:


@@ -212,6 +212,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
steps_offset: int = 0,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
rescale_betas_zero_snr: bool = False,
use_dynamic_shifting: bool = False,
time_shift_type: str = "exponential",
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -298,7 +300,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
self._begin_index = begin_index
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
def set_timesteps(
self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -309,6 +313,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
if mu is not None:
assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
self.config.flow_shift = np.exp(mu)
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)


@@ -184,5 +184,14 @@ def get_device():
def empty_device_cache(device_type: Optional[str] = None):
if device_type is None:
device_type = get_device()
if device_type in ["cpu"]:
return
device_mod = getattr(torch, device_type, torch.cuda)
device_mod.empty_cache()
def device_synchronize(device_type: Optional[str] = None):
if device_type is None:
device_type = get_device()
device_mod = getattr(torch, device_type, torch.cuda)
device_mod.synchronize()
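The helpers in use, as a minimal sketch; these are internal utilities, with the import path taken from the diffs above:

```py
from diffusers.utils.torch_utils import device_synchronize, empty_device_cache

empty_device_cache()   # no-op on CPU; dispatches to e.g. torch.cuda.empty_cache() otherwise
device_synchronize()   # blocks until pending work on the detected accelerator completes
```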


@@ -155,7 +155,7 @@ class FluxPipelineFastTests(
# Outputs should be different here
# For some reason, they don't show large differences
assert max_diff > 1e-6
self.assertGreater(max_diff, 1e-6, "Outputs should be different for different prompts.")
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -187,14 +187,17 @@ class FluxPipelineFastTests(
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
"Fusion of QKV projections shouldn't affect the outputs."
self.assertTrue(
np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3),
("Fusion of QKV projections shouldn't affect the outputs."),
)
assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
"Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
self.assertTrue(
np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3),
("Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."),
)
assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
"Original outputs should match when fused QKV projections are disabled."
self.assertTrue(
np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2),
("Original outputs should match when fused QKV projections are disabled."),
)
def test_flux_image_output_shape(self):
@@ -209,7 +212,11 @@ class FluxPipelineFastTests(
inputs.update({"height": height, "width": width})
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
assert (output_height, output_width) == (expected_height, expected_width)
self.assertEqual(
(output_height, output_width),
(expected_height, expected_width),
f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}",
)
def test_flux_true_cfg(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
@@ -220,7 +227,9 @@ class FluxPipelineFastTests(
inputs["negative_prompt"] = "bad quality"
inputs["true_cfg_scale"] = 2.0
true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
assert not np.allclose(no_true_cfg_out, true_cfg_out)
self.assertFalse(
np.allclose(no_true_cfg_out, true_cfg_out), "Outputs should be different when true_cfg_scale is set."
)
 @nightly
@@ -269,45 +278,17 @@ class FluxPipelineSlowTests(unittest.TestCase):
         image = pipe(**inputs).images[0]
         image_slice = image[0, :10, :10]
         # fmt: off
         expected_slice = np.array(
-            [
-                0.3242,
-                0.3203,
-                0.3164,
-                0.3164,
-                0.3125,
-                0.3125,
-                0.3281,
-                0.3242,
-                0.3203,
-                0.3301,
-                0.3262,
-                0.3242,
-                0.3281,
-                0.3242,
-                0.3203,
-                0.3262,
-                0.3262,
-                0.3164,
-                0.3262,
-                0.3281,
-                0.3184,
-                0.3281,
-                0.3281,
-                0.3203,
-                0.3281,
-                0.3281,
-                0.3164,
-                0.3320,
-                0.3320,
-                0.3203,
-            ],
+            [0.3242, 0.3203, 0.3164, 0.3164, 0.3125, 0.3125, 0.3281, 0.3242, 0.3203, 0.3301, 0.3262, 0.3242, 0.3281, 0.3242, 0.3203, 0.3262, 0.3262, 0.3164, 0.3262, 0.3281, 0.3184, 0.3281, 0.3281, 0.3203, 0.3281, 0.3281, 0.3164, 0.3320, 0.3320, 0.3203],
             dtype=np.float32,
         )
         # fmt: on

         max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
-        assert max_diff < 1e-4
+        self.assertLess(
+            max_diff, 1e-4, f"Image slice is different from expected slice: {image_slice} != {expected_slice}"
+        )

 @slow
@@ -377,42 +358,14 @@ class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
         image = pipe(**inputs).images[0]
         image_slice = image[0, :10, :10]
         # fmt: off
         expected_slice = np.array(
-            [
-                0.1855,
-                0.1680,
-                0.1406,
-                0.1953,
-                0.1699,
-                0.1465,
-                0.2012,
-                0.1738,
-                0.1484,
-                0.2051,
-                0.1797,
-                0.1523,
-                0.2012,
-                0.1719,
-                0.1445,
-                0.2070,
-                0.1777,
-                0.1465,
-                0.2090,
-                0.1836,
-                0.1484,
-                0.2129,
-                0.1875,
-                0.1523,
-                0.2090,
-                0.1816,
-                0.1484,
-                0.2110,
-                0.1836,
-                0.1543,
-            ],
+            [0.1855, 0.1680, 0.1406, 0.1953, 0.1699, 0.1465, 0.2012, 0.1738, 0.1484, 0.2051, 0.1797, 0.1523, 0.2012, 0.1719, 0.1445, 0.2070, 0.1777, 0.1465, 0.2090, 0.1836, 0.1484, 0.2129, 0.1875, 0.1523, 0.2090, 0.1816, 0.1484, 0.2110, 0.1836, 0.1543],
             dtype=np.float32,
         )
         # fmt: on

         max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
-        assert max_diff < 1e-4, f"{image_slice} != {expected_slice}"
+        self.assertLess(
+            max_diff, 1e-4, f"Image slice is different from expected slice: {image_slice} != {expected_slice}"
+        )
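
For reference, the numpy_cosine_similarity_distance helper used by both slow tests computes, to the best of my understanding, one minus the cosine similarity of the flattened slices (a sketch, not the verbatim diffusers implementation):

import numpy as np
from numpy.linalg import norm

def numpy_cosine_similarity_distance(a: np.ndarray, b: np.ndarray) -> float:
    # 0.0 for parallel vectors, up to 2.0 for exactly opposite ones.
    similarity = np.dot(a, b) / (norm(a) * norm(b))
    return float(1.0 - similarity)

A 1e-4 threshold on this distance tolerates small elementwise noise that a strict max-absolute-difference check would flag.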

View File

@@ -873,11 +873,11 @@ class ExtendedSerializationTest(BaseBnb4BitSerializationTests):
 @require_torch_version_greater("2.7.1")
 @require_bitsandbytes_version_greater("0.45.5")
-class Bnb4BitCompileTests(QuantCompileTests):
+class Bnb4BitCompileTests(QuantCompileTests, unittest.TestCase):
     @property
     def quantization_config(self):
         return PipelineQuantizationConfig(
-            quant_backend="bitsandbytes_8bit",
+            quant_backend="bitsandbytes_4bit",
             quant_kwargs={
                 "load_in_4bit": True,
                 "bnb_4bit_quant_type": "nf4",
@@ -888,12 +888,7 @@ class Bnb4BitCompileTests(QuantCompileTests):
     def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
-        super()._test_torch_compile(quantization_config=self.quantization_config)
-
-    def test_torch_compile_with_cpu_offload(self):
-        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
+        super().test_torch_compile()

     def test_torch_compile_with_group_offload_leaf(self):
-        super()._test_torch_compile_with_group_offload_leaf(
-            quantization_config=self.quantization_config, use_stream=True
-        )
+        super()._test_torch_compile_with_group_offload_leaf(use_stream=True)
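Pieced together from the first hunk above, the corrected 4-bit config reads roughly as follows; the compute dtype and components_to_quantize values are assumptions, since those lines are elided from the hunk:

import torch
from diffusers.quantizers import PipelineQuantizationConfig  # import path assumed

quantization_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",  # previously mislabeled "bitsandbytes_8bit"
    quant_kwargs={
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16,  # assumption: not shown in the hunk
    },
    components_to_quantize=["transformer"],  # assumption: not shown in the hunk
)

Naming the backend to match the 4-bit kwargs keeps the config self-describing.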

View File

@@ -838,7 +838,7 @@ class BaseBnb8bitSerializationTests(Base8bitTests):
 @require_torch_version_greater_equal("2.6.0")
 @require_bitsandbytes_version_greater("0.45.5")
-class Bnb8BitCompileTests(QuantCompileTests):
+class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
     @property
     def quantization_config(self):
         return PipelineQuantizationConfig(
@@ -849,15 +849,11 @@ class Bnb8BitCompileTests(QuantCompileTests):
     def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
-        super()._test_torch_compile(quantization_config=self.quantization_config, torch_dtype=torch.float16)
+        super()._test_torch_compile(torch_dtype=torch.float16)

     def test_torch_compile_with_cpu_offload(self):
-        super()._test_torch_compile_with_cpu_offload(
-            quantization_config=self.quantization_config, torch_dtype=torch.float16
-        )
+        super()._test_torch_compile_with_cpu_offload(torch_dtype=torch.float16)

     @pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.")
     def test_torch_compile_with_group_offload_leaf(self):
-        super()._test_torch_compile_with_group_offload_leaf(
-            quantization_config=self.quantization_config, torch_dtype=torch.float16, use_stream=True
-        )
+        super()._test_torch_compile_with_group_offload_leaf(torch_dtype=torch.float16, use_stream=True)
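
The xfail marker above keeps the known-bad combination visible without failing the suite: pytest reports the test as XFAIL when it fails and as XPASS if it ever starts passing. A minimal illustration with a hypothetical test, unrelated to the classes above:

import pytest

class TestOffloadExample:
    @pytest.mark.xfail(reason="Known offloading problem in Accelerate hooks.")
    def test_group_offload(self):
        raise RuntimeError("surfaced as XFAIL rather than a failure")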

View File

@@ -654,7 +654,7 @@ class WanVACEGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
 @require_torch_version_greater("2.7.1")
-class GGUFCompileTests(QuantCompileTests):
+class GGUFCompileTests(QuantCompileTests, unittest.TestCase):
     torch_dtype = torch.bfloat16
     gguf_ckpt = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
@@ -662,15 +662,6 @@ class GGUFCompileTests(QuantCompileTests):
     def quantization_config(self):
         return GGUFQuantizationConfig(compute_dtype=self.torch_dtype)

-    def test_torch_compile(self):
-        super()._test_torch_compile(quantization_config=self.quantization_config)
-
-    def test_torch_compile_with_cpu_offload(self):
-        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
-
-    def test_torch_compile_with_group_offload_leaf(self):
-        super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)
-
     def _init_pipeline(self, *args, **kwargs):
         transformer = FluxTransformer2DModel.from_single_file(
             self.gguf_ckpt, quantization_config=self.quantization_config, torch_dtype=self.torch_dtype

View File

@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import gc
-import unittest
+import inspect

 import torch
@@ -23,7 +23,7 @@ from diffusers.utils.testing_utils import backend_empty_cache, require_torch_gpu

 @require_torch_gpu
 @slow
-class QuantCompileTests(unittest.TestCase):
+class QuantCompileTests:
     @property
     def quantization_config(self):
         raise NotImplementedError(
@@ -50,30 +50,26 @@ class QuantCompileTests(unittest.TestCase):
         )
         return pipe

-    def _test_torch_compile(self, quantization_config, torch_dtype=torch.bfloat16):
-        pipe = self._init_pipeline(quantization_config, torch_dtype).to("cuda")
-        # import to ensure fullgraph True
+    def _test_torch_compile(self, torch_dtype=torch.bfloat16):
+        pipe = self._init_pipeline(self.quantization_config, torch_dtype).to("cuda")
+        # `fullgraph=True` ensures no graph breaks
         pipe.transformer.compile(fullgraph=True)

-        for _ in range(2):
-            # small resolutions to ensure speedy execution.
-            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
+        # small resolutions to ensure speedy execution.
+        pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)

-    def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype=torch.bfloat16):
-        pipe = self._init_pipeline(quantization_config, torch_dtype)
+    def _test_torch_compile_with_cpu_offload(self, torch_dtype=torch.bfloat16):
+        pipe = self._init_pipeline(self.quantization_config, torch_dtype)
         pipe.enable_model_cpu_offload()
         pipe.transformer.compile()

-        for _ in range(2):
-            # small resolutions to ensure speedy execution.
-            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
+        # small resolutions to ensure speedy execution.
+        pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)

-    def _test_torch_compile_with_group_offload_leaf(
-        self, quantization_config, torch_dtype=torch.bfloat16, *, use_stream: bool = False
-    ):
-        torch._dynamo.config.cache_size_limit = 10000
+    def _test_torch_compile_with_group_offload_leaf(self, torch_dtype=torch.bfloat16, *, use_stream: bool = False):
+        torch._dynamo.config.cache_size_limit = 1000

-        pipe = self._init_pipeline(quantization_config, torch_dtype)
+        pipe = self._init_pipeline(self.quantization_config, torch_dtype)
         group_offload_kwargs = {
             "onload_device": torch.device("cuda"),
             "offload_device": torch.device("cpu"),
@@ -87,6 +83,17 @@ class QuantCompileTests(unittest.TestCase):
             if torch.device(component.device).type == "cpu":
                 component.to("cuda")

-        for _ in range(2):
-            # small resolutions to ensure speedy execution.
-            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
+        # small resolutions to ensure speedy execution.
+        pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
+
+    def test_torch_compile(self):
+        self._test_torch_compile()
+
+    def test_torch_compile_with_cpu_offload(self):
+        self._test_torch_compile_with_cpu_offload()
+
+    def test_torch_compile_with_group_offload_leaf(self, use_stream=False):
+        for cls in inspect.getmro(self.__class__):
+            if "test_torch_compile_with_group_offload_leaf" in cls.__dict__ and cls is not QuantCompileTests:
+                return
+        self._test_torch_compile_with_group_offload_leaf(use_stream=use_stream)
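
The restructuring above turns QuantCompileTests into a plain mixin: because it no longer inherits from unittest.TestCase, test runners never collect its test_* methods on their own, and only the concrete subclasses that mix TestCase back in (as the bitsandbytes, GGUF, and TorchAO hunks do) are executed. A minimal sketch of the pattern with hypothetical names:

import unittest

class CompileTestMixin:  # never collected on its own: not a TestCase
    @property
    def quantization_config(self):
        raise NotImplementedError("Concrete subclasses must provide a config.")

    def test_compiles(self):
        self.assertIsNotNone(self.quantization_config)

class Int8CompileTests(CompileTestMixin, unittest.TestCase):  # collected and run
    @property
    def quantization_config(self):
        return {"backend": "int8"}  # placeholder config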

View File

@@ -630,7 +630,7 @@ class TorchAoSerializationTest(unittest.TestCase):

 @require_torchao_version_greater_or_equal("0.7.0")
-class TorchAoCompileTest(QuantCompileTests):
+class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
     @property
     def quantization_config(self):
         return PipelineQuantizationConfig(
@@ -639,17 +639,15 @@ class TorchAoCompileTest(QuantCompileTests):
             },
         )

-    def test_torch_compile(self):
-        super()._test_torch_compile(quantization_config=self.quantization_config)
-
     @unittest.skip(
         "Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
         "when compiling."
     )
     def test_torch_compile_with_cpu_offload(self):
         # RuntimeError: _apply(): Couldn't swap Linear.weight
-        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
+        super().test_torch_compile_with_cpu_offload()

+    @parameterized.expand([False, True])
     @unittest.skip(
         """
         For `use_stream=False`:
@@ -659,8 +657,7 @@ class TorchAoCompileTest(QuantCompileTests):
         Using non-default stream requires ability to pin tensors. AQT does not seem to support this yet in TorchAO.
         """
     )
-    @parameterized.expand([False, True])
-    def test_torch_compile_with_group_offload_leaf(self):
+    def test_torch_compile_with_group_offload_leaf(self, use_stream):
         # For use_stream=False:
         # If we run group offloading without compilation, we will see:
         # RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.
@@ -673,7 +670,7 @@ class TorchAoCompileTest(QuantCompileTests):
         # For use_stream=True:
         # NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), kwarg_types={}
-        super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)
+        super()._test_torch_compile_with_group_offload_leaf(use_stream=use_stream)

 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
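
Decorator order matters for the parameterized test above: moved above @unittest.skip, @parameterized.expand is applied last, presumably so that each generated variant carries the skip, and the test signature must accept the extra use_stream argument. A minimal sketch of how the parameterized package expands a test, with hypothetical names:

import unittest
from parameterized import parameterized

class StreamExampleTests(unittest.TestCase):
    @parameterized.expand([False, True])
    def test_offload(self, use_stream):
        # runs once per value of use_stream
        self.assertIn(use_stream, (False, True))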