Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-28 14:35:00 +08:00)

Compare commits: modular-cu...dependabot (1 commit, 36f4085772)
@@ -174,36 +174,39 @@ Feel free to open an issue if dynamic compilation doesn't work as expected for a
### Regional compilation
[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by only compiling the *small and frequently-repeated block(s)* of a model - typically a transformer layer - and enables reusing compiled artifacts for every subsequent occurrence.

For many diffusion architectures, this delivers the same runtime speedups as full-graph compilation and reduces compile time by 8–10x.

Use the [`~ModelMixin.compile_repeated_blocks`] method, a helper that wraps `torch.compile`, on any component such as the transformer model as shown below.
```py
# pip install -U diffusers
import torch
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")

# compile only the repeated transformer layers inside the UNet
pipeline.unet.compile_repeated_blocks(fullgraph=True)
```
To enable regional compilation for a new model, add a `_repeated_blocks` attribute to the model class containing the class names (as strings) of the blocks you want to compile.

```py
class MyUNet(ModelMixin):
    _repeated_blocks = ("Transformer2DModel",)  # ← compiled by default
```
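Once the attribute is defined, any instance of that class can be compiled with the same one-liner. The snippet below is a minimal sketch; `MyUNet` and its checkpoint path are hypothetical stand-ins for your own model.

```py
# Minimal sketch: `MyUNet` and the checkpoint path are hypothetical.
model = MyUNet.from_pretrained("path/to/my-unet", torch_dtype=torch.float16)

# only the blocks listed in `_repeated_blocks` are compiled
model.compile_repeated_blocks(fullgraph=True)
```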
> [!TIP]
> For more regional compilation examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).

There is also a [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78) method in [Accelerate](https://huggingface.co/docs/accelerate/index) that automatically walks the model, selects candidate blocks to compile, and compiles the remaining graph separately. This hands-off approach is convenient for quick experiments, but it offers fewer options for choosing which blocks are compiled or for adjusting compilation flags.
```py
# pip install -U accelerate
import torch
from accelerate.utils import compile_regions
from diffusers import StableDiffusionXLPipeline

# same SDXL pipeline setup as in the regional compilation example above
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")

pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph=True)
```
[`~ModelMixin.compile_repeated_blocks`], by contrast, is intentionally explicit: you list the repeated blocks once in `_repeated_blocks` and the helper compiles exactly those blocks, nothing more. This gives predictable behavior and easy reasoning about cache reuse in a single line of code.
### Graph breaks
@@ -293,9 +296,3 @@ An input is projected into three subspaces, represented by the projection matric
```py
pipeline.fuse_qkv_projections()
```
## Resources
- Read the [Presenting Flux Fast: Making Flux go brrr on H100s](https://pytorch.org/blog/presenting-flux-fast-making-flux-go-brrr-on-h100s/) blog post to learn more about how you can combine all of these optimizations with [TorchInductor](https://docs.pytorch.org/docs/stable/torch.compiler.html) and [AOTInductor](https://docs.pytorch.org/docs/stable/torch.compiler_aot_inductor.html) for a ~2.5x speedup using recipes from [flux-fast](https://github.com/huggingface/flux-fast). These recipes support AMD hardware and [Flux.1 Kontext Dev](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev).
@@ -14,9 +14,6 @@ specific language governing permissions and limitations under the License.
Optimizing models often involves trade-offs between [inference speed](./fp16) and [memory-usage](./memory). For instance, while [caching](./cache) can boost inference speed, it also increases memory consumption since it needs to store the outputs of intermediate attention layers. A more balanced optimization strategy combines quantizing a model, [torch.compile](./fp16#torchcompile) and various [offloading methods](./memory#offloading).
> [!TIP]
> Check the [torch.compile](./fp16#torchcompile) guide to learn more about compilation and how it can be applied here. For example, regional compilation can significantly reduce compilation time without giving up any speedups.
For image generation, combining quantization and [model offloading](./memory#model-offloading) can often give the best trade-off between quality, speed, and memory. Group offloading is not as effective for image generation because it is usually not possible to *fully* overlap data transfer if the compute kernel finishes faster. This results in some communication overhead between the CPU and GPU.
For video generation, combining quantization and [group-offloading](./memory#group-offloading) tends to be better because video models are more compute-bound.
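As a rough sketch of the offloading half of that recipe, the snippet below applies group offloading to a video transformer with a CUDA stream so data transfer overlaps with compute. The `enable_group_offload` call and the Wan checkpoint are assumptions based on the offloading guide linked above, not code from this diff; adjust both to the diffusers version and model you actually use.

```py
# Hedged sketch: assumes a diffusers release with group offloading support and
# an example video checkpoint; swap in whichever video pipeline you use.
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",  # example checkpoint
    torch_dtype=torch.bfloat16,
)

# offload groups of transformer blocks and overlap transfers with compute
pipeline.transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
    use_stream=True,
)
```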
@@ -28,7 +25,7 @@ The table below provides a comparison of optimization strategy combinations and
| combination | latency (in seconds) | memory-usage (in GB) |
|---|---|---|
| quantization | 32.602 | 14.9453 |
| quantization, torch.compile | 25.847 | 14.9448 |
| quantization, torch.compile, model CPU offloading | 32.312 | 12.2369 |

<small>These results are benchmarked on Flux with a RTX 4090. The transformer and text_encoder components are quantized. Refer to the [benchmarking script](https://gist.github.com/sayakpaul/0db9d8eeeb3d2a0e5ed7cf0d9ca19b7d) if you're interested in evaluating your own model.</small>
This guide will show you how to compile and offload a quantized model with [bitsandbytes](../quantization/bitsandbytes#torchcompile). Make sure you are using [PyTorch nightly](https://pytorch.org/get-started/locally/) and the latest version of bitsandbytes.
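As a preview of that workflow, the hedged sketch below quantizes the heavy components with bitsandbytes, enables model CPU offloading, and compiles the repeated transformer blocks. The `PipelineQuantizationConfig` usage and the Flux checkpoint are assumptions drawn from the quantization docs rather than this guide's exact code.

```py
# Hedged sketch, not the guide's exact code: assumes PyTorch nightly, a recent
# bitsandbytes, and the example Flux checkpoint below.
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig

quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={"load_in_4bit": True, "bnb_4bit_compute_dtype": torch.bfloat16},
    components_to_quantize=["transformer", "text_encoder_2"],
)

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
pipeline.enable_model_cpu_offload()                           # keep only the active component on the GPU
pipeline.transformer.compile_repeated_blocks(fullgraph=True)  # regional compilation

image = pipeline("an astronaut riding a horse on mars", num_inference_steps=28).images[0]
```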
@@ -1330,7 +1330,7 @@ def main(args):
|
||||
# controlnet(s) inference
|
||||
controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
|
||||
controlnet_image = vae.encode(controlnet_image).latent_dist.sample()
|
||||
controlnet_image = (controlnet_image - vae.config.shift_factor) * vae.config.scaling_factor
|
||||
controlnet_image = controlnet_image * vae.config.scaling_factor
|
||||
|
||||
control_block_res_samples = controlnet(
|
||||
hidden_states=noisy_model_input,
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile requirements.in -o requirements.txt
|
||||
aiohappyeyeballs==2.6.1
|
||||
aiohappyeyeballs==2.4.3
|
||||
# via aiohttp
|
||||
aiohttp==3.12.14
|
||||
aiohttp==3.10.10
|
||||
# via -r requirements.in
|
||||
aiosignal==1.4.0
|
||||
aiosignal==1.3.1
|
||||
# via aiohttp
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
@@ -110,9 +110,7 @@ prometheus-client==0.21.0
|
||||
prometheus-fastapi-instrumentator==7.0.0
|
||||
# via -r requirements.in
|
||||
propcache==0.2.0
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
# via yarl
|
||||
py-consul==1.5.3
|
||||
# via -r requirements.in
|
||||
pydantic==2.9.2
|
||||
@@ -125,7 +123,7 @@ pyyaml==6.0.2
|
||||
# transformers
|
||||
regex==2024.9.11
|
||||
# via transformers
|
||||
requests==2.32.3
|
||||
requests==2.32.4
|
||||
# via
|
||||
# huggingface-hub
|
||||
# py-consul
|
||||
@@ -156,7 +154,6 @@ triton==3.3.0
|
||||
# via torch
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# aiosignal
|
||||
# anyio
|
||||
# exceptiongroup
|
||||
# fastapi
|
||||
@@ -171,5 +168,5 @@ urllib3==2.5.0
|
||||
# via requests
|
||||
uvicorn==0.32.0
|
||||
# via -r requirements.in
|
||||
yarl==1.18.3
|
||||
yarl==1.16.0
|
||||
# via aiohttp
|
||||
|
||||
@@ -470,7 +470,7 @@ def _func_optionally_disable_offloading(_pipeline):
|
||||
for _, component in _pipeline.components.items():
|
||||
if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"):
|
||||
continue
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)
|
||||
|
||||
|
||||
@@ -24,7 +24,6 @@ from typing_extensions import Self
|
||||
from .. import __version__
|
||||
from ..quantizers import DiffusersAutoQuantizer
|
||||
from ..utils import deprecate, is_accelerate_available, logging
|
||||
from ..utils.torch_utils import device_synchronize, empty_device_cache
|
||||
from .single_file_utils import (
|
||||
SingleFileComponentError,
|
||||
convert_animatediff_checkpoint_to_diffusers,
|
||||
@@ -431,10 +430,6 @@ class FromOriginalModelMixin:
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
unexpected_keys=unexpected_keys,
|
||||
)
|
||||
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
|
||||
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
|
||||
empty_device_cache()
|
||||
device_synchronize()
|
||||
else:
|
||||
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
|
||||
|
||||
|
||||
@@ -46,7 +46,6 @@ from ..utils import (
|
||||
)
|
||||
from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
|
||||
from ..utils.hub_utils import _get_model_file
|
||||
from ..utils.torch_utils import device_synchronize, empty_device_cache
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
@@ -1690,10 +1689,6 @@ def create_diffusers_clip_model_from_ldm(
|
||||
|
||||
if is_accelerate_available():
|
||||
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
|
||||
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
|
||||
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
|
||||
empty_device_cache()
|
||||
device_synchronize()
|
||||
else:
|
||||
model.load_state_dict(diffusers_format_checkpoint, strict=False)
|
||||
|
||||
@@ -2153,10 +2148,6 @@ def create_diffusers_t5_model_from_checkpoint(
|
||||
|
||||
if is_accelerate_available():
|
||||
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
|
||||
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
|
||||
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
|
||||
empty_device_cache()
|
||||
device_synchronize()
|
||||
else:
|
||||
model.load_state_dict(diffusers_format_checkpoint)
|
||||
|
||||
|
||||
@@ -18,8 +18,11 @@ from ..models.embeddings import (
|
||||
MultiIPAdapterImageProjection,
|
||||
)
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
||||
from ..utils import is_accelerate_available, is_torch_version, logging
|
||||
from ..utils.torch_utils import device_synchronize, empty_device_cache
|
||||
from ..utils import (
|
||||
is_accelerate_available,
|
||||
is_torch_version,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
@@ -81,8 +84,6 @@ class FluxTransformer2DLoadersMixin:
|
||||
else:
|
||||
device_map = {"": self.device}
|
||||
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
|
||||
empty_device_cache()
|
||||
device_synchronize()
|
||||
|
||||
return image_projection
|
||||
|
||||
@@ -157,9 +158,6 @@ class FluxTransformer2DLoadersMixin:
|
||||
|
||||
key_id += 1
|
||||
|
||||
empty_device_cache()
|
||||
device_synchronize()
|
||||
|
||||
return attn_procs
|
||||
|
||||
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
|
||||
|
||||
@@ -18,7 +18,6 @@ from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
|
||||
from ..models.embeddings import IPAdapterTimeImageProjection
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
||||
from ..utils import is_accelerate_available, is_torch_version, logging
|
||||
from ..utils.torch_utils import device_synchronize, empty_device_cache
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -81,9 +80,6 @@ class SD3Transformer2DLoadersMixin:
|
||||
attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
|
||||
)
|
||||
|
||||
empty_device_cache()
|
||||
device_synchronize()
|
||||
|
||||
return attn_procs
|
||||
|
||||
def _convert_ip_adapter_image_proj_to_diffusers(
|
||||
@@ -151,8 +147,6 @@ class SD3Transformer2DLoadersMixin:
|
||||
else:
|
||||
device_map = {"": self.device}
|
||||
load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
|
||||
empty_device_cache()
|
||||
device_synchronize()
|
||||
|
||||
return image_proj
|
||||
|
||||
|
||||
@@ -43,7 +43,6 @@ from ..utils import (
|
||||
is_torch_version,
|
||||
logging,
|
||||
)
|
||||
from ..utils.torch_utils import device_synchronize, empty_device_cache
|
||||
from .lora_base import _func_optionally_disable_offloading
|
||||
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
|
||||
from .utils import AttnProcsLayers
|
||||
@@ -754,8 +753,6 @@ class UNet2DConditionLoadersMixin:
|
||||
else:
|
||||
device_map = {"": self.device}
|
||||
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
|
||||
empty_device_cache()
|
||||
device_synchronize()
|
||||
|
||||
return image_projection
|
||||
|
||||
@@ -853,9 +850,6 @@ class UNet2DConditionLoadersMixin:
|
||||
|
||||
key_id += 2
|
||||
|
||||
empty_device_cache()
|
||||
device_synchronize()
|
||||
|
||||
return attn_procs
|
||||
|
||||
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
|
||||
|
||||
@@ -752,7 +752,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
condition = self.controlnet_cond_embedding(cond)
|
||||
feat_seq = torch.mean(condition, dim=(2, 3))
|
||||
feat_seq = feat_seq + self.task_embedding[control_idx]
|
||||
if from_multi or len(control_type_idx) == 1:
|
||||
if from_multi:
|
||||
inputs.append(feat_seq.unsqueeze(1))
|
||||
condition_list.append(condition)
|
||||
else:
|
||||
@@ -772,7 +772,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale):
|
||||
alpha = self.spatial_ch_projs(x[:, idx])
|
||||
alpha = alpha.unsqueeze(-1).unsqueeze(-1)
|
||||
if from_multi or len(control_type_idx) == 1:
|
||||
if from_multi:
|
||||
controlnet_cond_fuser += condition + alpha
|
||||
else:
|
||||
controlnet_cond_fuser += condition + alpha * scale
|
||||
@@ -819,11 +819,11 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
# 6. scaling
|
||||
if guess_mode and not self.config.global_pool_conditions:
|
||||
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
|
||||
if from_multi or len(control_type_idx) == 1:
|
||||
if from_multi:
|
||||
scales = scales * conditioning_scale[0]
|
||||
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
|
||||
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
|
||||
elif from_multi or len(control_type_idx) == 1:
|
||||
elif from_multi:
|
||||
down_block_res_samples = [sample * conditioning_scale[0] for sample in down_block_res_samples]
|
||||
mid_block_res_sample = mid_block_res_sample * conditioning_scale[0]
|
||||
|
||||
|
||||
@@ -16,10 +16,9 @@
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import math
|
||||
import os
|
||||
from array import array
|
||||
from collections import OrderedDict, defaultdict
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
from zipfile import is_zipfile
|
||||
@@ -39,7 +38,6 @@ from ..utils import (
|
||||
_get_model_file,
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_gguf_available,
|
||||
is_torch_available,
|
||||
is_torch_version,
|
||||
@@ -254,10 +252,6 @@ def load_model_dict_into_meta(
|
||||
param = param.to(dtype)
|
||||
set_module_kwargs["dtype"] = dtype
|
||||
|
||||
if is_accelerate_version(">", "1.8.1"):
|
||||
set_module_kwargs["non_blocking"] = True
|
||||
set_module_kwargs["clear_cache"] = False
|
||||
|
||||
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
|
||||
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
|
||||
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
|
||||
@@ -526,60 +520,3 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
|
||||
parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights
|
||||
|
||||
return parsed_parameters
|
||||
|
||||
|
||||
def _find_mismatched_keys(state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes):
|
||||
mismatched_keys = []
|
||||
if not ignore_mismatched_sizes:
|
||||
return mismatched_keys
|
||||
for checkpoint_key in loaded_keys:
|
||||
model_key = checkpoint_key
|
||||
# If the checkpoint is sharded, we may not have the key here.
|
||||
if checkpoint_key not in state_dict:
|
||||
continue
|
||||
|
||||
if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
|
||||
mismatched_keys.append(
|
||||
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
||||
)
|
||||
del state_dict[checkpoint_key]
|
||||
return mismatched_keys
|
||||
|
||||
|
||||
def _expand_device_map(device_map, param_names):
|
||||
"""
|
||||
Expand a device map to return the correspondence parameter name to device.
|
||||
"""
|
||||
new_device_map = {}
|
||||
for module, device in device_map.items():
|
||||
new_device_map.update(
|
||||
{p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
|
||||
)
|
||||
return new_device_map
|
||||
|
||||
|
||||
# Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
|
||||
def _caching_allocator_warmup(model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype) -> None:
|
||||
"""
|
||||
This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
|
||||
device. It allows to have one large call to Malloc, instead of recursively calling it later when loading the model,
|
||||
which is actually the loading speed bottleneck. Calling this function allows to cut the model loading time by a
|
||||
very large margin.
|
||||
"""
|
||||
# Remove disk and cpu devices, and cast to proper torch.device
|
||||
accelerator_device_map = {
|
||||
param: torch.device(device)
|
||||
for param, device in expanded_device_map.items()
|
||||
if str(device) not in ["cpu", "disk"]
|
||||
}
|
||||
parameter_count = defaultdict(lambda: 0)
|
||||
for param_name, device in accelerator_device_map.items():
|
||||
try:
|
||||
param = model.get_parameter(param_name)
|
||||
except AttributeError:
|
||||
param = model.get_buffer(param_name)
|
||||
parameter_count[device] += math.prod(param.shape)
|
||||
|
||||
# This will kick off the caching allocator to avoid having to Malloc afterwards
|
||||
for device, param_count in parameter_count.items():
|
||||
_ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
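As an illustrative aside (not part of the diff), the same effect can be reproduced by hand: one large empty allocation primes PyTorch's caching allocator, so the many smaller allocations made later during weight loading reuse the cached pool instead of each calling into the device allocator.

```py
# Illustrative aside with an assumed CUDA device; sizes are arbitrary.
import torch

warmup = torch.empty(1_000_000_000, dtype=torch.float16, device="cuda", requires_grad=False)
del warmup  # the freed block stays in the allocator's pool and is reused by later allocations
```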
@@ -62,14 +62,10 @@ from ..utils.hub_utils import (
|
||||
load_or_create_model_card,
|
||||
populate_model_card,
|
||||
)
|
||||
from ..utils.torch_utils import device_synchronize, empty_device_cache
|
||||
from .model_loading_utils import (
|
||||
_caching_allocator_warmup,
|
||||
_determine_device_map,
|
||||
_expand_device_map,
|
||||
_fetch_index_file,
|
||||
_fetch_index_file_legacy,
|
||||
_find_mismatched_keys,
|
||||
_load_state_dict_into_model,
|
||||
load_model_dict_into_meta,
|
||||
load_state_dict,
|
||||
@@ -1473,6 +1469,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
for pat in cls._keys_to_ignore_on_load_unexpected:
|
||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||
|
||||
mismatched_keys = []
|
||||
|
||||
assign_to_params_buffers = None
|
||||
error_msgs = []
|
||||
|
||||
# Deal with offload
|
||||
if device_map is not None and "disk" in device_map.values():
|
||||
if offload_folder is None:
|
||||
@@ -1481,27 +1482,18 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
" for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
|
||||
" offers the weights in this format."
|
||||
)
|
||||
else:
|
||||
if offload_folder is not None:
|
||||
os.makedirs(offload_folder, exist_ok=True)
|
||||
if offload_state_dict is None:
|
||||
offload_state_dict = True
|
||||
|
||||
# If a device map has been used, we can speedup the load time by warming up the device caching allocator.
|
||||
# If we don't warmup, each tensor allocation on device calls to the allocator for memory (effectively, a
|
||||
# lot of individual calls to device malloc). We can, however, preallocate the memory required by the
|
||||
# tensors using their expected shape and not performing any initialization of the memory (empty data).
|
||||
# When the actual device allocations happen, the allocator already has a pool of unused device memory
|
||||
# that it can re-use for faster loading of the model.
|
||||
# TODO: add support for warmup with hf_quantizer
|
||||
if device_map is not None and hf_quantizer is None:
|
||||
expanded_device_map = _expand_device_map(device_map, expected_keys)
|
||||
_caching_allocator_warmup(model, expanded_device_map, dtype)
|
||||
|
||||
offload_index = {} if device_map is not None and "disk" in device_map.values() else None
|
||||
state_dict_folder, state_dict_index = None, None
|
||||
if offload_state_dict:
|
||||
state_dict_folder = tempfile.mkdtemp()
|
||||
state_dict_index = {}
|
||||
else:
|
||||
state_dict_folder = None
|
||||
state_dict_index = None
|
||||
|
||||
if state_dict is not None:
|
||||
# load_state_dict will manage the case where we pass a dict instead of a file
|
||||
@@ -1511,14 +1503,38 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
if len(resolved_model_file) > 1:
|
||||
resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
|
||||
|
||||
mismatched_keys = []
|
||||
assign_to_params_buffers = None
|
||||
error_msgs = []
|
||||
|
||||
for shard_file in resolved_model_file:
|
||||
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
|
||||
|
||||
def _find_mismatched_keys(
|
||||
state_dict,
|
||||
model_state_dict,
|
||||
loaded_keys,
|
||||
ignore_mismatched_sizes,
|
||||
):
|
||||
mismatched_keys = []
|
||||
if ignore_mismatched_sizes:
|
||||
for checkpoint_key in loaded_keys:
|
||||
model_key = checkpoint_key
|
||||
# If the checkpoint is sharded, we may not have the key here.
|
||||
if checkpoint_key not in state_dict:
|
||||
continue
|
||||
|
||||
if (
|
||||
model_key in model_state_dict
|
||||
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
|
||||
):
|
||||
mismatched_keys.append(
|
||||
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
||||
)
|
||||
del state_dict[checkpoint_key]
|
||||
return mismatched_keys
|
||||
|
||||
mismatched_keys += _find_mismatched_keys(
|
||||
state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes
|
||||
state_dict,
|
||||
model_state_dict,
|
||||
loaded_keys,
|
||||
ignore_mismatched_sizes,
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
@@ -1538,12 +1554,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
else:
|
||||
if assign_to_params_buffers is None:
|
||||
assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
|
||||
error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
|
||||
|
||||
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
|
||||
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
|
||||
empty_device_cache()
|
||||
device_synchronize()
|
||||
error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
|
||||
|
||||
if offload_index is not None and len(offload_index) > 0:
|
||||
save_offload_index(offload_index, offload_folder)
|
||||
|
||||
@@ -187,15 +187,9 @@ class CosmosAttnProcessor2_0:
|
||||
key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
|
||||
|
||||
# 4. Prepare for GQA
|
||||
if torch.onnx.is_in_onnx_export():
|
||||
query_idx = torch.tensor(query.size(3), device=query.device)
|
||||
key_idx = torch.tensor(key.size(3), device=key.device)
|
||||
value_idx = torch.tensor(value.size(3), device=value.device)
|
||||
|
||||
else:
|
||||
query_idx = query.size(3)
|
||||
key_idx = key.size(3)
|
||||
value_idx = value.size(3)
|
||||
query_idx = torch.tensor(query.size(3), device=query.device)
|
||||
key_idx = torch.tensor(key.size(3), device=key.device)
|
||||
value_idx = torch.tensor(value.size(3), device=value.device)
|
||||
key = key.repeat_interleave(query_idx // key_idx, dim=3)
|
||||
value = value.repeat_interleave(query_idx // value_idx, dim=3)
|
@@ -490,7 +490,6 @@ class FluxTransformer2DModel(
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
joint_attention_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -522,7 +521,6 @@ class FluxTransformer2DModel(
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
joint_attention_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
@@ -323,7 +323,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
"""
|
||||
|
||||
config_name = "config.json"
|
||||
model_name = None
|
||||
|
||||
@classmethod
|
||||
def _get_signature_keys(cls, obj):
|
||||
@@ -334,14 +333,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
|
||||
return expected_modules, optional_parameters
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def expected_configs(self) -> List[ConfigSpec]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
@@ -367,9 +358,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
trust_remote_code, pretrained_model_name_or_path, has_remote_code
|
||||
)
|
||||
if not (has_remote_code and trust_remote_code):
|
||||
raise ValueError(
|
||||
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
|
||||
)
|
||||
raise ValueError("TODO")
|
||||
|
||||
class_ref = config["auto_map"][cls.__name__]
|
||||
module_file, class_name = class_ref.split(".")
|
||||
@@ -378,6 +367,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
pretrained_model_name_or_path,
|
||||
module_file=module_file,
|
||||
class_name=class_name,
|
||||
is_modular=True,
|
||||
**hub_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -93,7 +93,7 @@ class ComponentSpec:
|
||||
config: Optional[FrozenDict] = None
|
||||
# YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name
|
||||
repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
|
||||
subfolder: Optional[str] = field(default="", metadata={"loading": True})
|
||||
subfolder: Optional[str] = field(default=None, metadata={"loading": True})
|
||||
variant: Optional[str] = field(default=None, metadata={"loading": True})
|
||||
revision: Optional[str] = field(default=None, metadata={"loading": True})
|
||||
default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained"
|
||||
|
||||
@@ -18,6 +18,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
CLIPTextModel,
|
||||
@@ -34,13 +35,7 @@ from ...loaders import (
|
||||
StableDiffusionXLLoraLoaderMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
)
|
||||
from ...models import (
|
||||
AutoencoderKL,
|
||||
ControlNetUnionModel,
|
||||
ImageProjection,
|
||||
MultiControlNetUnionModel,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, ImageProjection, UNet2DConditionModel
|
||||
from ...models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
@@ -235,9 +230,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
|
||||
tokenizer: CLIPTokenizer,
|
||||
tokenizer_2: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
controlnet: Union[
|
||||
ControlNetUnionModel, List[ControlNetUnionModel], Tuple[ControlNetUnionModel], MultiControlNetUnionModel
|
||||
],
|
||||
controlnet: ControlNetUnionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
requires_aesthetics_score: bool = False,
|
||||
force_zeros_for_empty_prompt: bool = True,
|
||||
@@ -247,8 +240,8 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if isinstance(controlnet, (list, tuple)):
|
||||
controlnet = MultiControlNetUnionModel(controlnet)
|
||||
if not isinstance(controlnet, ControlNetUnionModel):
|
||||
raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.")
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
@@ -667,7 +660,6 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
|
||||
controlnet_conditioning_scale=1.0,
|
||||
control_guidance_start=0.0,
|
||||
control_guidance_end=1.0,
|
||||
control_mode=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
padding_mask_crop=None,
|
||||
):
|
||||
@@ -755,34 +747,25 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
|
||||
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
# `prompt` needs more sophisticated handling when there are multiple
|
||||
# conditionings.
|
||||
if isinstance(self.controlnet, MultiControlNetUnionModel):
|
||||
if isinstance(prompt, list):
|
||||
logger.warning(
|
||||
f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
|
||||
" prompts. The conditionings will be fixed across the prompts."
|
||||
)
|
||||
|
||||
# Check `image`
|
||||
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
||||
is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
|
||||
self.controlnet, torch._dynamo.eval_frame.OptimizedModule
|
||||
)
|
||||
if (
|
||||
isinstance(self.controlnet, ControlNetModel)
|
||||
or is_compiled
|
||||
and isinstance(self.controlnet._orig_mod, ControlNetModel)
|
||||
):
|
||||
self.check_image(image, prompt, prompt_embeds)
|
||||
elif (
|
||||
isinstance(self.controlnet, ControlNetUnionModel)
|
||||
or is_compiled
|
||||
and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
|
||||
):
|
||||
self.check_image(image, prompt, prompt_embeds)
|
||||
|
||||
if isinstance(controlnet, ControlNetUnionModel):
|
||||
for image_ in image:
|
||||
self.check_image(image_, prompt, prompt_embeds)
|
||||
elif isinstance(controlnet, MultiControlNetUnionModel):
|
||||
if not isinstance(image, list):
|
||||
raise TypeError("For multiple controlnets: `image` must be type `list`")
|
||||
elif not all(isinstance(i, list) for i in image):
|
||||
raise ValueError("For multiple controlnets: elements of `image` must be list of conditionings.")
|
||||
elif len(image) != len(self.controlnet.nets):
|
||||
raise ValueError(
|
||||
f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
|
||||
)
|
||||
|
||||
for images_ in image:
|
||||
for image_ in images_:
|
||||
self.check_image(image_, prompt, prompt_embeds)
|
||||
else:
|
||||
assert False
|
||||
|
||||
if not isinstance(control_guidance_start, (tuple, list)):
|
||||
control_guidance_start = [control_guidance_start]
|
||||
@@ -795,12 +778,6 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
|
||||
f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
|
||||
)
|
||||
|
||||
if isinstance(controlnet, MultiControlNetUnionModel):
|
||||
if len(control_guidance_start) != len(self.controlnet.nets):
|
||||
raise ValueError(
|
||||
f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
|
||||
)
|
||||
|
||||
for start, end in zip(control_guidance_start, control_guidance_end):
|
||||
if start >= end:
|
||||
raise ValueError(
|
||||
@@ -811,28 +788,6 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
|
||||
if end > 1.0:
|
||||
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
|
||||
|
||||
# Check `control_mode`
|
||||
if isinstance(controlnet, ControlNetUnionModel):
|
||||
if max(control_mode) >= controlnet.config.num_control_type:
|
||||
raise ValueError(f"control_mode: must be lower than {controlnet.config.num_control_type}.")
|
||||
elif isinstance(controlnet, MultiControlNetUnionModel):
|
||||
for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets):
|
||||
if max(_control_mode) >= _controlnet.config.num_control_type:
|
||||
raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.")
|
||||
|
||||
# Equal number of `image` and `control_mode` elements
|
||||
if isinstance(controlnet, ControlNetUnionModel):
|
||||
if len(image) != len(control_mode):
|
||||
raise ValueError("Expected len(control_image) == len(control_mode)")
|
||||
elif isinstance(controlnet, MultiControlNetUnionModel):
|
||||
if not all(isinstance(i, list) for i in control_mode):
|
||||
raise ValueError(
|
||||
"For multiple controlnets: elements of control_mode must be lists representing conditioning mode."
|
||||
)
|
||||
|
||||
elif sum(len(x) for x in image) != sum(len(x) for x in control_mode):
|
||||
raise ValueError("Expected len(control_image) == len(control_mode)")
|
||||
|
||||
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
|
||||
raise ValueError(
|
||||
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
|
||||
@@ -1162,7 +1117,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
image: PipelineImageInput = None,
|
||||
mask_image: PipelineImageInput = None,
|
||||
control_image: Union[PipelineImageInput, List[PipelineImageInput]] = None,
|
||||
control_image: PipelineImageInput = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
padding_mask_crop: Optional[int] = None,
|
||||
@@ -1190,7 +1145,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
|
||||
guess_mode: bool = False,
|
||||
control_guidance_start: Union[float, List[float]] = 0.0,
|
||||
control_guidance_end: Union[float, List[float]] = 1.0,
|
||||
control_mode: Optional[Union[int, List[int], List[List[int]]]] = None,
|
||||
control_mode: Optional[Union[int, List[int]]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
original_size: Tuple[int, int] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
@@ -1222,13 +1177,6 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
|
||||
repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
|
||||
to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
|
||||
instead of 3, so the expected shape would be `(B, H, W, 1)`.
|
||||
control_image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*):
|
||||
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
|
||||
specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
|
||||
as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
|
||||
width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
|
||||
images must be passed as a list such that each element of the list can be correctly batched for input
|
||||
to a single ControlNet.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
@@ -1321,22 +1269,6 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
|
||||
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
|
||||
to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
|
||||
the corresponding scale as a list.
|
||||
guess_mode (`bool`, *optional*, defaults to `False`):
|
||||
The ControlNet encoder tries to recognize the content of the input image even if you remove all
|
||||
prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
|
||||
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
|
||||
The percentage of total steps at which the ControlNet starts applying.
|
||||
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
|
||||
The percentage of total steps at which the ControlNet stops applying.
|
||||
control_mode (`int` or `List[int]` or `List[List[int]]`, *optional*):
The control condition types for the ControlNet. See the ControlNet's model card for information on the
|
||||
available control modes. If multiple ControlNets are specified in `init`, control_mode should be a list
|
||||
where each ControlNet should have its corresponding control mode list. Should reflect the order of
|
||||
conditions in control_image.
|
||||
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
||||
`original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
|
||||
@@ -1401,6 +1333,22 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
|
||||
|
||||
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
||||
|
||||
# align format for control guidance
|
||||
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
||||
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
||||
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
||||
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
||||
|
||||
# # 0.0 Default height and width to unet
|
||||
# height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
# width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
# 0.1 align format for control guidance
|
||||
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
||||
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
||||
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
||||
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
||||
|
||||
if not isinstance(control_image, list):
|
||||
control_image = [control_image]
|
||||
else:
|
||||
@@ -1409,59 +1357,40 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
|
||||
if not isinstance(control_mode, list):
|
||||
control_mode = [control_mode]
|
||||
|
||||
if isinstance(controlnet, MultiControlNetUnionModel):
|
||||
control_image = [[item] for item in control_image]
|
||||
control_mode = [[item] for item in control_mode]
|
||||
if len(control_image) != len(control_mode):
|
||||
raise ValueError("Expected len(control_image) == len(control_type)")
|
||||
|
||||
# align format for control guidance
|
||||
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
||||
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
||||
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
||||
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
||||
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
||||
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
|
||||
control_guidance_start, control_guidance_end = (
|
||||
mult * [control_guidance_start],
|
||||
mult * [control_guidance_end],
|
||||
)
|
||||
|
||||
if isinstance(controlnet_conditioning_scale, float):
|
||||
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
|
||||
controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult
|
||||
num_control_type = controlnet.config.num_control_type
|
||||
|
||||
# 1. Check inputs
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
prompt_2,
|
||||
control_image,
|
||||
mask_image,
|
||||
strength,
|
||||
num_inference_steps,
|
||||
callback_steps,
|
||||
output_type,
|
||||
negative_prompt,
|
||||
negative_prompt_2,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
ip_adapter_image,
|
||||
ip_adapter_image_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
controlnet_conditioning_scale,
|
||||
control_guidance_start,
|
||||
control_guidance_end,
|
||||
control_mode,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
padding_mask_crop,
|
||||
)
|
||||
control_type = [0 for _ in range(num_control_type)]
|
||||
for _image, control_idx in zip(control_image, control_mode):
|
||||
control_type[control_idx] = 1
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
prompt_2,
|
||||
_image,
|
||||
mask_image,
|
||||
strength,
|
||||
num_inference_steps,
|
||||
callback_steps,
|
||||
output_type,
|
||||
negative_prompt,
|
||||
negative_prompt_2,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
ip_adapter_image,
|
||||
ip_adapter_image_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
controlnet_conditioning_scale,
|
||||
control_guidance_start,
|
||||
control_guidance_end,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
padding_mask_crop,
|
||||
)
|
||||
|
||||
if isinstance(controlnet, ControlNetUnionModel):
|
||||
control_type = torch.zeros(controlnet.config.num_control_type).scatter_(0, torch.tensor(control_mode), 1)
|
||||
elif isinstance(controlnet, MultiControlNetUnionModel):
|
||||
control_type = [
|
||||
torch.zeros(controlnet_.config.num_control_type).scatter_(0, torch.tensor(control_mode_), 1)
|
||||
for control_mode_, controlnet_ in zip(control_mode, self.controlnet.nets)
|
||||
]
|
||||
control_type = torch.Tensor(control_type)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
@@ -1554,55 +1483,21 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
|
||||
init_image = init_image.to(dtype=torch.float32)
|
||||
|
||||
# 5.2 Prepare control images
|
||||
if isinstance(controlnet, ControlNetUnionModel):
|
||||
control_images = []
|
||||
|
||||
for image_ in control_image:
|
||||
image_ = self.prepare_control_image(
|
||||
image=image_,
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=controlnet.dtype,
|
||||
crops_coords=crops_coords,
|
||||
resize_mode=resize_mode,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
)
|
||||
|
||||
control_images.append(image_)
|
||||
|
||||
control_image = control_images
|
||||
height, width = control_image[0].shape[-2:]
|
||||
|
||||
elif isinstance(controlnet, MultiControlNetUnionModel):
|
||||
control_images = []
|
||||
|
||||
for control_image_ in control_image:
|
||||
images = []
|
||||
|
||||
for image_ in control_image_:
|
||||
image_ = self.prepare_control_image(
|
||||
image=image_,
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=controlnet.dtype,
|
||||
crops_coords=crops_coords,
|
||||
resize_mode=resize_mode,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
)
|
||||
|
||||
images.append(image_)
|
||||
control_images.append(images)
|
||||
|
||||
control_image = control_images
|
||||
height, width = control_image[0][0].shape[-2:]
|
||||
for idx, _ in enumerate(control_image):
|
||||
control_image[idx] = self.prepare_control_image(
|
||||
image=control_image[idx],
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=controlnet.dtype,
|
||||
crops_coords=crops_coords,
|
||||
resize_mode=resize_mode,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
)
|
||||
height, width = control_image[idx].shape[-2:]
|
||||
|
||||
# 5.3 Prepare mask
|
||||
mask = self.mask_processor.preprocess(
|
||||
@@ -1664,11 +1559,10 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
|
||||
# 8.2 Create tensor stating which controlnets to keep
|
||||
controlnet_keep = []
|
||||
for i in range(len(timesteps)):
|
||||
keeps = [
|
||||
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
||||
for s, e in zip(control_guidance_start, control_guidance_end)
|
||||
]
|
||||
controlnet_keep.append(keeps)
|
||||
controlnet_keep.append(
|
||||
1.0
|
||||
- float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end)
|
||||
)
|
||||
|
||||
# 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
height, width = latents.shape[-2:]
|
||||
@@ -1733,24 +1627,11 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
|
||||
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
|
||||
timesteps = timesteps[:num_inference_steps]
|
||||
|
||||
control_type_repeat_factor = (
|
||||
batch_size * num_images_per_prompt * (2 if self.do_classifier_free_guidance else 1)
|
||||
control_type = (
|
||||
control_type.reshape(1, -1)
|
||||
.to(device, dtype=prompt_embeds.dtype)
|
||||
.repeat(batch_size * num_images_per_prompt * 2, 1)
|
||||
)
|
||||
|
||||
if isinstance(controlnet, ControlNetUnionModel):
|
||||
control_type = (
|
||||
control_type.reshape(1, -1)
|
||||
.to(self._execution_device, dtype=prompt_embeds.dtype)
|
||||
.repeat(control_type_repeat_factor, 1)
|
||||
)
|
||||
elif isinstance(controlnet, MultiControlNetUnionModel):
|
||||
control_type = [
|
||||
_control_type.reshape(1, -1)
|
||||
.to(self._execution_device, dtype=prompt_embeds.dtype)
|
||||
.repeat(control_type_repeat_factor, 1)
|
||||
for _control_type in control_type
|
||||
]
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
|
||||
@@ -1452,21 +1452,17 @@ class StableDiffusionXLControlNetUnionPipeline(
|
||||
is_controlnet_compiled = is_compiled_module(self.controlnet)
|
||||
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
|
||||
|
||||
control_type_repeat_factor = (
|
||||
batch_size * num_images_per_prompt * (2 if self.do_classifier_free_guidance else 1)
|
||||
)
|
||||
|
||||
if isinstance(controlnet, ControlNetUnionModel):
|
||||
control_type = (
|
||||
control_type.reshape(1, -1)
|
||||
.to(self._execution_device, dtype=prompt_embeds.dtype)
|
||||
.repeat(control_type_repeat_factor, 1)
|
||||
.repeat(batch_size * num_images_per_prompt * 2, 1)
|
||||
)
|
||||
elif isinstance(controlnet, MultiControlNetUnionModel):
|
||||
if isinstance(controlnet, MultiControlNetUnionModel):
|
||||
control_type = [
|
||||
_control_type.reshape(1, -1)
|
||||
.to(self._execution_device, dtype=prompt_embeds.dtype)
|
||||
.repeat(control_type_repeat_factor, 1)
|
||||
.repeat(batch_size * num_images_per_prompt * 2, 1)
|
||||
for _control_type in control_type
|
||||
]
|
||||
|
||||
|
||||
@@ -175,8 +175,6 @@ def get_device():
|
||||
return "npu"
|
||||
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
return "xpu"
|
||||
elif torch.backends.mps.is_available():
|
||||
return "mps"
|
||||
else:
|
||||
return "cpu"
|
||||
|
||||
@@ -184,14 +182,5 @@ def get_device():
|
||||
def empty_device_cache(device_type: Optional[str] = None):
|
||||
if device_type is None:
|
||||
device_type = get_device()
|
||||
if device_type in ["cpu"]:
|
||||
return
|
||||
device_mod = getattr(torch, device_type, torch.cuda)
|
||||
device_mod.empty_cache()
|
||||
|
||||
|
||||
def device_synchronize(device_type: Optional[str] = None):
|
||||
if device_type is None:
|
||||
device_type = get_device()
|
||||
device_mod = getattr(torch, device_type, torch.cuda)
|
||||
device_mod.synchronize()
|
@@ -2510,34 +2510,3 @@ class PeftLoraLoaderMixinTests:
|
||||
# materializes the test methods on invocation which cannot be overridden.
|
||||
return
|
||||
self._test_group_offloading_inference_denoiser(offload_type, use_stream)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_lora_loading_model_cpu_offload(self):
|
||||
components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
denoiser.add_adapter(denoiser_lora_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
|
||||
|
||||
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(
|
||||
save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
|
||||
)
|
||||
# reinitialize the pipeline to mimic the inference workflow.
|
||||
components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.enable_model_cpu_offload(device=torch_device)
|
||||
pipe.load_lora_weights(tmpdirname)
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
|
||||
|
||||
output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(np.allclose(output_lora, output_lora_loaded, atol=1e-3, rtol=1e-3))
|
||||
|
||||
@@ -155,7 +155,7 @@ class FluxPipelineFastTests(
|
||||
|
||||
# Outputs should be different here
|
||||
# For some reason, they don't show large differences
|
||||
self.assertGreater(max_diff, 1e-6, "Outputs should be different for different prompts.")
|
||||
assert max_diff > 1e-6
|
||||
|
||||
def test_fused_qkv_projections(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
@@ -187,17 +187,14 @@ class FluxPipelineFastTests(
        image = pipe(**inputs).images
        image_slice_disabled = image[0, -3:, -3:, -1]

        self.assertTrue(
            np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3),
            ("Fusion of QKV projections shouldn't affect the outputs."),
        assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
            "Fusion of QKV projections shouldn't affect the outputs."
        )
        self.assertTrue(
            np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3),
            ("Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."),
        assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
            "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
        )
        self.assertTrue(
            np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2),
            ("Original outputs should match when fused QKV projections are disabled."),
        assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
            "Original outputs should match when fused QKV projections are disabled."
        )

    def test_flux_image_output_shape(self):
@@ -212,11 +209,7 @@ class FluxPipelineFastTests(
            inputs.update({"height": height, "width": width})
            image = pipe(**inputs).images[0]
            output_height, output_width, _ = image.shape
            self.assertEqual(
                (output_height, output_width),
                (expected_height, expected_width),
                f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}",
            )
            assert (output_height, output_width) == (expected_height, expected_width)

    def test_flux_true_cfg(self):
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
@@ -227,9 +220,7 @@ class FluxPipelineFastTests(
        inputs["negative_prompt"] = "bad quality"
        inputs["true_cfg_scale"] = 2.0
        true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
        self.assertFalse(
            np.allclose(no_true_cfg_out, true_cfg_out), "Outputs should be different when true_cfg_scale is set."
        )
        assert not np.allclose(no_true_cfg_out, true_cfg_out)


@nightly
@@ -278,17 +269,45 @@ class FluxPipelineSlowTests(unittest.TestCase):

        image = pipe(**inputs).images[0]
        image_slice = image[0, :10, :10]
        # fmt: off
        expected_slice = np.array(
            [0.3242, 0.3203, 0.3164, 0.3164, 0.3125, 0.3125, 0.3281, 0.3242, 0.3203, 0.3301, 0.3262, 0.3242, 0.3281, 0.3242, 0.3203, 0.3262, 0.3262, 0.3164, 0.3262, 0.3281, 0.3184, 0.3281, 0.3281, 0.3203, 0.3281, 0.3281, 0.3164, 0.3320, 0.3320, 0.3203],
            dtype=np.float32,
        )
        # fmt: on

        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
        self.assertLess(
            max_diff, 1e-4, f"Image slice is different from expected slice: {image_slice} != {expected_slice}"
        )

        assert max_diff < 1e-4


@slow
@@ -358,14 +377,42 @@ class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
        image = pipe(**inputs).images[0]
        image_slice = image[0, :10, :10]

        # fmt: off
        expected_slice = np.array(
            [0.1855, 0.1680, 0.1406, 0.1953, 0.1699, 0.1465, 0.2012, 0.1738, 0.1484, 0.2051, 0.1797, 0.1523, 0.2012, 0.1719, 0.1445, 0.2070, 0.1777, 0.1465, 0.2090, 0.1836, 0.1484, 0.2129, 0.1875, 0.1523, 0.2090, 0.1816, 0.1484, 0.2110, 0.1836, 0.1543],
            dtype=np.float32,
        )
        # fmt: on

        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
        self.assertLess(
            max_diff, 1e-4, f"Image slice is different from expected slice: {image_slice} != {expected_slice}"
        )

        assert max_diff < 1e-4, f"{image_slice} != {expected_slice}"
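Both slow-test hunks compare a 10×10 slice of the generated image against a reference slice with `numpy_cosine_similarity_distance` from the diffusers testing utilities. A rough, self-contained sketch of that style of check (an approximation, not the library's exact implementation):

```py
import numpy as np

def cosine_similarity_distance(a: np.ndarray, b: np.ndarray) -> float:
    # 0.0 means the flattened slices point in the same direction; larger values mean bigger differences.
    a = a.flatten().astype(np.float64)
    b = b.flatten().astype(np.float64)
    similarity = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return float(1.0 - similarity)

expected = np.random.rand(10, 10).astype(np.float32)  # stands in for expected_slice
observed = expected + 1e-5                            # stands in for image_slice
assert cosine_similarity_distance(expected, observed) < 1e-4
```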
@@ -873,11 +873,11 @@ class ExtendedSerializationTest(BaseBnb4BitSerializationTests):

@require_torch_version_greater("2.7.1")
@require_bitsandbytes_version_greater("0.45.5")
class Bnb4BitCompileTests(QuantCompileTests, unittest.TestCase):
class Bnb4BitCompileTests(QuantCompileTests):
    @property
    def quantization_config(self):
        return PipelineQuantizationConfig(
            quant_backend="bitsandbytes_4bit",
            quant_backend="bitsandbytes_8bit",
            quant_kwargs={
                "load_in_4bit": True,
                "bnb_4bit_quant_type": "nf4",
@@ -888,7 +888,12 @@ class Bnb4BitCompileTests(QuantCompileTests, unittest.TestCase):

    def test_torch_compile(self):
        torch._dynamo.config.capture_dynamic_output_shape_ops = True
        super().test_torch_compile()
        super()._test_torch_compile(quantization_config=self.quantization_config)

    def test_torch_compile_with_cpu_offload(self):
        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)

    def test_torch_compile_with_group_offload_leaf(self):
        super()._test_torch_compile_with_group_offload_leaf(use_stream=True)
        super()._test_torch_compile_with_group_offload_leaf(
            quantization_config=self.quantization_config, use_stream=True
        )
@@ -838,7 +838,7 @@ class BaseBnb8bitSerializationTests(Base8bitTests):

@require_torch_version_greater_equal("2.6.0")
@require_bitsandbytes_version_greater("0.45.5")
class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
class Bnb8BitCompileTests(QuantCompileTests):
    @property
    def quantization_config(self):
        return PipelineQuantizationConfig(
@@ -849,11 +849,15 @@ class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):

    def test_torch_compile(self):
        torch._dynamo.config.capture_dynamic_output_shape_ops = True
        super()._test_torch_compile(torch_dtype=torch.float16)
        super()._test_torch_compile(quantization_config=self.quantization_config, torch_dtype=torch.float16)

    def test_torch_compile_with_cpu_offload(self):
        super()._test_torch_compile_with_cpu_offload(torch_dtype=torch.float16)
        super()._test_torch_compile_with_cpu_offload(
            quantization_config=self.quantization_config, torch_dtype=torch.float16
        )

    @pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.")
    def test_torch_compile_with_group_offload_leaf(self):
        super()._test_torch_compile_with_group_offload_leaf(torch_dtype=torch.float16, use_stream=True)
        super()._test_torch_compile_with_group_offload_leaf(
            quantization_config=self.quantization_config, torch_dtype=torch.float16, use_stream=True
        )
@@ -654,7 +654,7 @@ class WanVACEGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):


@require_torch_version_greater("2.7.1")
class GGUFCompileTests(QuantCompileTests, unittest.TestCase):
class GGUFCompileTests(QuantCompileTests):
    torch_dtype = torch.bfloat16
    gguf_ckpt = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"

@@ -662,6 +662,15 @@ class GGUFCompileTests(QuantCompileTests, unittest.TestCase):
    def quantization_config(self):
        return GGUFQuantizationConfig(compute_dtype=self.torch_dtype)

    def test_torch_compile(self):
        super()._test_torch_compile(quantization_config=self.quantization_config)

    def test_torch_compile_with_cpu_offload(self):
        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)

    def test_torch_compile_with_group_offload_leaf(self):
        super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)

    def _init_pipeline(self, *args, **kwargs):
        transformer = FluxTransformer2DModel.from_single_file(
            self.gguf_ckpt, quantization_config=self.quantization_config, torch_dtype=self.torch_dtype
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import inspect
import unittest

import torch

@@ -23,7 +23,7 @@ from diffusers.utils.testing_utils import backend_empty_cache, require_torch_gpu

@require_torch_gpu
@slow
class QuantCompileTests:
class QuantCompileTests(unittest.TestCase):
    @property
    def quantization_config(self):
        raise NotImplementedError(
@@ -50,26 +50,30 @@ class QuantCompileTests:
        )
        return pipe

    def _test_torch_compile(self, torch_dtype=torch.bfloat16):
        pipe = self._init_pipeline(self.quantization_config, torch_dtype).to("cuda")
        # `fullgraph=True` ensures no graph breaks
    def _test_torch_compile(self, quantization_config, torch_dtype=torch.bfloat16):
        pipe = self._init_pipeline(quantization_config, torch_dtype).to("cuda")
        # import to ensure fullgraph True
        pipe.transformer.compile(fullgraph=True)

        # small resolutions to ensure speedy execution.
        pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
        for _ in range(2):
            # small resolutions to ensure speedy execution.
            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)

    def _test_torch_compile_with_cpu_offload(self, torch_dtype=torch.bfloat16):
        pipe = self._init_pipeline(self.quantization_config, torch_dtype)
    def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype=torch.bfloat16):
        pipe = self._init_pipeline(quantization_config, torch_dtype)
        pipe.enable_model_cpu_offload()
        pipe.transformer.compile()

        # small resolutions to ensure speedy execution.
        pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
        for _ in range(2):
            # small resolutions to ensure speedy execution.
            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)

    def _test_torch_compile_with_group_offload_leaf(self, torch_dtype=torch.bfloat16, *, use_stream: bool = False):
        torch._dynamo.config.cache_size_limit = 1000
    def _test_torch_compile_with_group_offload_leaf(
        self, quantization_config, torch_dtype=torch.bfloat16, *, use_stream: bool = False
    ):
        torch._dynamo.config.cache_size_limit = 10000

        pipe = self._init_pipeline(self.quantization_config, torch_dtype)
        pipe = self._init_pipeline(quantization_config, torch_dtype)
        group_offload_kwargs = {
            "onload_device": torch.device("cuda"),
            "offload_device": torch.device("cpu"),
@@ -83,17 +87,6 @@ class QuantCompileTests:
            if torch.device(component.device).type == "cpu":
                component.to("cuda")

        # small resolutions to ensure speedy execution.
        pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)

    def test_torch_compile(self):
        self._test_torch_compile()

    def test_torch_compile_with_cpu_offload(self):
        self._test_torch_compile_with_cpu_offload()

    def test_torch_compile_with_group_offload_leaf(self, use_stream=False):
        for cls in inspect.getmro(self.__class__):
            if "test_torch_compile_with_group_offload_leaf" in cls.__dict__ and cls is not QuantCompileTests:
                return
        self._test_torch_compile_with_group_offload_leaf(use_stream=use_stream)
        for _ in range(2):
            # small resolutions to ensure speedy execution.
            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
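With the helpers now taking `quantization_config` explicitly, each backend-specific test class passes its own config, as the bitsandbytes and GGUF hunks above do. A minimal sketch of that pattern for a hypothetical backend class (the class name is illustrative; the config values mirror the 4-bit bitsandbytes setup):

```py
# Hypothetical subclass shown for illustration; not part of the diff.
class MyBackendCompileTests(QuantCompileTests):
    @property
    def quantization_config(self):
        return PipelineQuantizationConfig(
            quant_backend="bitsandbytes_4bit",
            quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4"},
            components_to_quantize=["transformer"],
        )

    def test_torch_compile(self):
        super()._test_torch_compile(quantization_config=self.quantization_config)

    def test_torch_compile_with_cpu_offload(self):
        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)

    def test_torch_compile_with_group_offload_leaf(self):
        super()._test_torch_compile_with_group_offload_leaf(
            quantization_config=self.quantization_config, use_stream=True
        )
```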
@@ -630,7 +630,7 @@ class TorchAoSerializationTest(unittest.TestCase):


@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
class TorchAoCompileTest(QuantCompileTests):
    @property
    def quantization_config(self):
        return PipelineQuantizationConfig(
@@ -639,15 +639,17 @@ class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
            },
        )

    def test_torch_compile(self):
        super()._test_torch_compile(quantization_config=self.quantization_config)

    @unittest.skip(
        "Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
        "when compiling."
    )
    def test_torch_compile_with_cpu_offload(self):
        # RuntimeError: _apply(): Couldn't swap Linear.weight
        super().test_torch_compile_with_cpu_offload()
        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)

    @parameterized.expand([False, True])
    @unittest.skip(
        """
        For `use_stream=False`:
@@ -657,7 +659,8 @@ class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
        Using non-default stream requires ability to pin tensors. AQT does not seem to support this yet in TorchAO.
        """
    )
    def test_torch_compile_with_group_offload_leaf(self, use_stream):
    @parameterized.expand([False, True])
    def test_torch_compile_with_group_offload_leaf(self):
        # For use_stream=False:
        # If we run group offloading without compilation, we will see:
        # RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.
@@ -670,7 +673,7 @@ class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):

        # For use_stream=True:
        # NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), kwarg_types={}
        super()._test_torch_compile_with_group_offload_leaf(use_stream=use_stream)
        super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)


# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners