Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-06 12:34:13 +08:00)
[Feat] TaylorSeer Cache (#12648)
* init taylor_seer cache
* make compatible with any tuple size returned
* use logger for printing, add warmup feature
* still update in warmup steps
* refactor, add docs
* add configurable cache, skip compute module
* allow special cache ids only
* add stop_predicts (cooldown)
* update docs
* apply ruff
* update to handle multiple calls per timestep
* refactor to use state manager
* fix format & doc
* chores: naming, remove redundancy
* add docs
* quality & style
* fix taylor precision
* Apply style fixes
* add tests
* Apply style fixes
* Remove TaylorSeerCacheTesterMixin from flux2 tests
* rename identifiers, use more expressive taylor predict loop
* torch compile compatible
* Apply style fixes
* Update src/diffusers/hooks/taylorseer_cache.py
  Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
* update docs
* make fix-copies
* fix example usage.
* remove tests on flux kontext

---------

Co-authored-by: toilaluan <toilaluan@github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
@@ -34,3 +34,9 @@ Cache methods speed up diffusion transformers by storing and reusing intermediate
[[autodoc]] FirstBlockCacheConfig

[[autodoc]] apply_first_block_cache

### TaylorSeerCacheConfig

[[autodoc]] TaylorSeerCacheConfig

[[autodoc]] apply_taylorseer_cache
@@ -66,4 +66,35 @@ config = FasterCacheConfig(
    tensor_format="BFCHW",
)
pipeline.transformer.enable_cache(config)
```

## TaylorSeer Cache

[TaylorSeer Cache](https://huggingface.co/papers/2503.06923) accelerates diffusion inference by using Taylor series expansions to approximate and cache intermediate activations across denoising steps. The method predicts future outputs from past computations and reuses them at specified intervals to reduce redundant calculations.

This caching mechanism delivers strong results with minimal additional memory overhead. For a detailed performance analysis, see [our findings here](https://github.com/huggingface/diffusers/pull/12648#issuecomment-3610615080).
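
The core idea can be sketched in a few lines. The snippet below is purely illustrative and not part of the diffusers API: it extrapolates a module output with a first-order Taylor expansion from two cached activations, which is what the cache does for each cached module between full forward passes.

```python
import torch

# Activations recorded at the last two full compute steps (made-up values)
y_prev, t_prev = torch.tensor([1.0, 2.0]), 4  # earlier full step
y_last, t_last = torch.tensor([1.2, 1.8]), 6  # most recent full step

# First-order Taylor factor: finite-difference estimate of the time derivative
dy = (y_last - y_prev) / (t_last - t_prev)

# Predict the activation two steps ahead instead of recomputing the module
t_query = 8
y_pred = y_last + dy * (t_query - t_last)
print(y_pred)  # tensor([1.4000, 1.6000])
```

Higher orders extend the same expansion with additional divided-difference terms.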

To enable TaylorSeer Cache, create a [`TaylorSeerCacheConfig`] and pass it to your pipeline's transformer. The key options are:

- `cache_interval`: number of steps to reuse cached outputs before performing a new full forward pass
- `disable_cache_before_step`: number of initial steps that run full computations to gather data for the approximations
- `max_order`: order of the Taylor approximation (in theory, higher values improve quality but they increase memory usage; we recommend setting this to `1`)

```python
import torch
from diffusers import FluxPipeline, TaylorSeerCacheConfig

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

config = TaylorSeerCacheConfig(
    cache_interval=5,
    max_order=1,
    disable_cache_before_step=10,
    taylor_factors_dtype=torch.bfloat16,
)
pipe.transformer.enable_cache(config)
```
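
The hooks can later be removed with `pipe.transformer.disable_cache()`, which restores the original forward passes.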
@@ -169,10 +169,12 @@ else:
            "LayerSkipConfig",
            "PyramidAttentionBroadcastConfig",
            "SmoothedEnergyGuidanceConfig",
            "TaylorSeerCacheConfig",
            "apply_faster_cache",
            "apply_first_block_cache",
            "apply_layer_skip",
            "apply_pyramid_attention_broadcast",
            "apply_taylorseer_cache",
        ]
    )
    _import_structure["models"].extend(
@@ -899,10 +901,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
            LayerSkipConfig,
            PyramidAttentionBroadcastConfig,
            SmoothedEnergyGuidanceConfig,
            TaylorSeerCacheConfig,
            apply_faster_cache,
            apply_first_block_cache,
            apply_layer_skip,
            apply_pyramid_attention_broadcast,
            apply_taylorseer_cache,
        )
        from .models import (
            AllegroTransformer3DModel,
@@ -25,3 +25,4 @@ if is_torch_available():
    from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
    from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
    from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
    from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache
src/diffusers/hooks/taylorseer_cache.py (new file, 346 lines)
@@ -0,0 +1,346 @@
import math
import re
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn

from ..utils import logging
from .hooks import HookRegistry, ModelHook, StateManager


logger = logging.get_logger(__name__)
_TAYLORSEER_CACHE_HOOK = "taylorseer_cache"
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
    "^blocks.*attn",
    "^transformer_blocks.*attn",
    "^single_transformer_blocks.*attn",
)
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",)
_TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
_BLOCK_IDENTIFIERS = ("^[^.]*block[^.]*\\.[^.]+$",)
_PROJ_OUT_IDENTIFIERS = ("^proj_out$",)
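# Default hook targets: the attention patterns above cover common transformer block layouts;
# _BLOCK_IDENTIFIERS and _PROJ_OUT_IDENTIFIERS are the patterns used by the lite mode
# (skip whole blocks, Taylor-cache the final projection).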


@dataclass
class TaylorSeerCacheConfig:
    """
    Configuration for TaylorSeer cache. See: https://huggingface.co/papers/2503.06923

    Attributes:
        cache_interval (`int`, defaults to `5`):
            The interval between full computation steps. After a full computation, the cached (predicted) outputs are
            reused for this many subsequent denoising steps before refreshing with a new full forward pass.

        disable_cache_before_step (`int`, defaults to `3`):
            The denoising step index before which caching is disabled, meaning full computation is performed for the
            initial steps (0 to disable_cache_before_step - 1) to gather data for Taylor series approximations. During
            these steps, Taylor factors are updated, but caching/predictions are not applied. Caching begins at this
            step.

        disable_cache_after_step (`int`, *optional*, defaults to `None`):
            The denoising step index after which caching is disabled. If set, for steps >= this value, all modules run
            full computations without predictions or state updates, ensuring accuracy in later stages if needed.

        max_order (`int`, defaults to `1`):
            The highest order in the Taylor series expansion for approximating module outputs. Higher orders provide
            better approximations but increase computation and memory usage.

        taylor_factors_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
            Data type used for storing and computing Taylor series factors. Lower precision reduces memory but may
            affect stability; higher precision improves accuracy at the cost of more memory.

        skip_predict_identifiers (`List[str]`, *optional*, defaults to `None`):
            Regex patterns (using `re.fullmatch`) for module names to place in "skip" mode. In this mode, the module
            computes fully during initial or refresh steps but returns a zero tensor (matching the recorded shape)
            during prediction steps to skip computation cheaply.

        cache_identifiers (`List[str]`, *optional*, defaults to `None`):
            Regex patterns (using `re.fullmatch`) for module names to place in Taylor-series caching mode, where
            outputs are approximated and cached for reuse.

        use_lite_mode (`bool`, *optional*, defaults to `False`):
            Enables a lightweight TaylorSeer variant that minimizes memory usage by applying predefined patterns for
            skipping and caching (e.g., skipping blocks and caching projections). This overrides any custom
            `skip_predict_identifiers` or `cache_identifiers`.

    Notes:
        - Patterns are matched using `re.fullmatch` on the module name.
        - If `skip_predict_identifiers` or `cache_identifiers` are provided, only matching modules are hooked.
        - If neither is provided, all attention-like modules are hooked by default.

    Example of inactive (skip) and active (cache) usage:

    ```py
    def forward(x):
        x = self.module1(x)  # inactive module: returns a zeros tensor based on the shape recorded during full compute
        x = self.module2(x)  # active module: caches the output here, avoiding recomputation
        return x
    ```
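
    For illustration, the following is roughly what `use_lite_mode=True` selects internally (a sketch only; the exact
    module names depend on the model architecture):

    ```py
    config = TaylorSeerCacheConfig(
        cache_interval=5,
        cache_identifiers=["^proj_out$"],  # Taylor-cache the final projection
        skip_predict_identifiers=["^[^.]*block[^.]*\\.[^.]+$"],  # skip transformer blocks on predicted steps
    )
    ```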
    """

    cache_interval: int = 5
    disable_cache_before_step: int = 3
    disable_cache_after_step: Optional[int] = None
    max_order: int = 1
    taylor_factors_dtype: Optional[torch.dtype] = torch.bfloat16
    skip_predict_identifiers: Optional[List[str]] = None
    cache_identifiers: Optional[List[str]] = None
    use_lite_mode: bool = False

    def __repr__(self) -> str:
        return (
            "TaylorSeerCacheConfig("
            f"cache_interval={self.cache_interval}, "
            f"disable_cache_before_step={self.disable_cache_before_step}, "
            f"disable_cache_after_step={self.disable_cache_after_step}, "
            f"max_order={self.max_order}, "
            f"taylor_factors_dtype={self.taylor_factors_dtype}, "
            f"skip_predict_identifiers={self.skip_predict_identifiers}, "
            f"cache_identifiers={self.cache_identifiers}, "
            f"use_lite_mode={self.use_lite_mode})"
        )


class TaylorSeerState:
    def __init__(
        self,
        taylor_factors_dtype: Optional[torch.dtype] = torch.bfloat16,
        max_order: int = 1,
        is_inactive: bool = False,
    ):
        self.taylor_factors_dtype = taylor_factors_dtype
        self.max_order = max_order
        self.is_inactive = is_inactive

        self.module_dtypes: Tuple[torch.dtype, ...] = ()
        self.last_update_step: Optional[int] = None
        self.taylor_factors: Dict[int, Dict[int, torch.Tensor]] = {}
        self.inactive_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None
        self.device: Optional[torch.device] = None
        self.current_step: int = -1

    def reset(self) -> None:
        self.current_step = -1
        self.last_update_step = None
        self.taylor_factors = {}
        self.inactive_shapes = None
        self.device = None

    def update(
        self,
        outputs: Tuple[torch.Tensor, ...],
    ) -> None:
        self.module_dtypes = tuple(output.dtype for output in outputs)
        self.device = outputs[0].device

        if self.is_inactive:
            self.inactive_shapes = tuple(output.shape for output in outputs)
        else:
            for i, features in enumerate(outputs):
                new_factors: Dict[int, torch.Tensor] = {0: features}
                is_first_update = self.last_update_step is None
                if not is_first_update:
                    delta_step = self.current_step - self.last_update_step
                    if delta_step == 0:
                        raise ValueError("Delta step cannot be zero for TaylorSeer update.")

                    # Recursive divided differences up to max_order
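                    # taylor_factors[i][k] is a finite-difference estimate of the k-th time derivative of this
                    # output; the 1/k! Taylor coefficient is applied later in predict().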
                    prev_factors = self.taylor_factors.get(i, {})
                    for j in range(self.max_order):
                        prev = prev_factors.get(j)
                        if prev is None:
                            break
                        new_factors[j + 1] = (new_factors[j] - prev.to(features.dtype)) / delta_step
                self.taylor_factors[i] = {
                    order: factor.to(self.taylor_factors_dtype) for order, factor in new_factors.items()
                }

        self.last_update_step = self.current_step

    @torch.compiler.disable
    def predict(self) -> List[torch.Tensor]:
        if self.last_update_step is None:
            raise ValueError("Cannot predict without prior initialization/update.")

        step_offset = self.current_step - self.last_update_step

        outputs = []
        if self.is_inactive:
            if self.inactive_shapes is None:
                raise ValueError("Inactive shapes not set during prediction.")
            for i in range(len(self.module_dtypes)):
                outputs.append(
                    torch.zeros(
                        self.inactive_shapes[i],
                        dtype=self.module_dtypes[i],
                        device=self.device,
                    )
                )
        else:
            if not self.taylor_factors:
                raise ValueError("Taylor factors empty during prediction.")
            num_outputs = len(self.taylor_factors)
            num_orders = len(self.taylor_factors[0])
            for i in range(num_outputs):
                output_dtype = self.module_dtypes[i]
                taylor_factors = self.taylor_factors[i]
                output = torch.zeros_like(taylor_factors[0], dtype=output_dtype)
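                # Evaluate the truncated Taylor series `step_offset` steps after the last full compute:
                # output ~= sum_k taylor_factors[k] * step_offset**k / k!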
                for order in range(num_orders):
                    coeff = (step_offset**order) / math.factorial(order)
                    factor = taylor_factors[order]
                    output = output + factor.to(output_dtype) * coeff
                outputs.append(output)
        return outputs


class TaylorSeerCacheHook(ModelHook):
    _is_stateful = True

    def __init__(
        self,
        cache_interval: int,
        disable_cache_before_step: int,
        taylor_factors_dtype: torch.dtype,
        state_manager: StateManager,
        disable_cache_after_step: Optional[int] = None,
    ):
        super().__init__()
        self.cache_interval = cache_interval
        self.disable_cache_before_step = disable_cache_before_step
        self.disable_cache_after_step = disable_cache_after_step
        self.taylor_factors_dtype = taylor_factors_dtype
        self.state_manager = state_manager

    def initialize_hook(self, module: torch.nn.Module):
        return module

    def reset_state(self, module: torch.nn.Module) -> None:
        """
        Reset state between sampling runs.
        """
        self.state_manager.reset()

    @torch.compiler.disable
    def _measure_should_compute(self) -> Tuple[bool, TaylorSeerState]:
        state: TaylorSeerState = self.state_manager.get_state()
        state.current_step += 1
        current_step = state.current_step
        is_warmup_phase = current_step < self.disable_cache_before_step
        is_compute_interval = (current_step - self.disable_cache_before_step - 1) % self.cache_interval == 0
        is_cooldown_phase = self.disable_cache_after_step is not None and current_step >= self.disable_cache_after_step
        should_compute = is_warmup_phase or is_compute_interval or is_cooldown_phase
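        # Example with the defaults (disable_cache_before_step=3, cache_interval=5): steps 0-2 run fully
        # (warmup), steps 4, 9, 14, ... refresh the cache, and the remaining steps reuse Taylor predictions.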
        return should_compute, state

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        should_compute, state = self._measure_should_compute()
        if should_compute:
            outputs = self.fn_ref.original_forward(*args, **kwargs)
            wrapped_outputs = (outputs,) if isinstance(outputs, torch.Tensor) else outputs
            state.update(wrapped_outputs)
            return outputs

        outputs_list = state.predict()
        return outputs_list[0] if len(outputs_list) == 1 else tuple(outputs_list)


def _resolve_patterns(config: TaylorSeerCacheConfig) -> Tuple[List[str], List[str]]:
    """
    Resolve effective inactive and active pattern lists from the config.
    """

    inactive_patterns = config.skip_predict_identifiers if config.skip_predict_identifiers is not None else None
    active_patterns = config.cache_identifiers if config.cache_identifiers is not None else None

    return inactive_patterns or [], active_patterns or []


def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig):
    """
    Applies the TaylorSeer cache to a given model (typically a pipeline's transformer or UNet).

    This function hooks selected modules in the model to enable caching or skipping based on the provided
    configuration, reducing redundant computations in diffusion denoising loops.

    Args:
        module (torch.nn.Module): The model subtree to apply the hooks to.
        config (TaylorSeerCacheConfig): Configuration for the cache.

    Example:
    ```python
    >>> import torch
    >>> from diffusers import FluxPipeline, TaylorSeerCacheConfig

    >>> pipe = FluxPipeline.from_pretrained(
    ...     "black-forest-labs/FLUX.1-dev",
    ...     torch_dtype=torch.bfloat16,
    ... )
    >>> pipe.to("cuda")

    >>> config = TaylorSeerCacheConfig(
    ...     cache_interval=5,
    ...     max_order=1,
    ...     disable_cache_before_step=3,
    ...     taylor_factors_dtype=torch.float32,
    ... )
    >>> pipe.transformer.enable_cache(config)
    ```
    """
    inactive_patterns, active_patterns = _resolve_patterns(config)

    active_patterns = active_patterns or _TRANSFORMER_BLOCK_IDENTIFIERS

    if config.use_lite_mode:
        logger.info("Using TaylorSeer Lite variant for cache.")
        active_patterns = _PROJ_OUT_IDENTIFIERS
        inactive_patterns = _BLOCK_IDENTIFIERS
        if config.skip_predict_identifiers or config.cache_identifiers:
            logger.warning("Lite mode overrides user patterns.")

    for name, submodule in module.named_modules():
        matches_inactive = any(re.fullmatch(pattern, name) for pattern in inactive_patterns)
        matches_active = any(re.fullmatch(pattern, name) for pattern in active_patterns)
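        # If a module matches both pattern lists, the skip (inactive) behaviour takes precedence.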
        if not (matches_inactive or matches_active):
            continue
        _apply_taylorseer_cache_hook(
            module=submodule,
            config=config,
            is_inactive=matches_inactive,
        )


def _apply_taylorseer_cache_hook(
    module: nn.Module,
    config: TaylorSeerCacheConfig,
    is_inactive: bool,
):
    """
    Registers the TaylorSeer hook on the specified nn.Module.

    Args:
        module: The nn.Module to be hooked.
        config: Cache configuration.
        is_inactive: Whether this module should operate in "inactive" mode.
    """
    state_manager = StateManager(
        TaylorSeerState,
        init_kwargs={
            "taylor_factors_dtype": config.taylor_factors_dtype,
            "max_order": config.max_order,
            "is_inactive": is_inactive,
        },
    )

    registry = HookRegistry.check_if_exists_or_initialize(module)

    hook = TaylorSeerCacheHook(
        cache_interval=config.cache_interval,
        disable_cache_before_step=config.disable_cache_before_step,
        taylor_factors_dtype=config.taylor_factors_dtype,
        disable_cache_after_step=config.disable_cache_after_step,
        state_manager=state_manager,
    )

    registry.register_hook(hook, _TAYLORSEER_CACHE_HOOK)
@@ -67,9 +67,11 @@ class CacheMixin:
            FasterCacheConfig,
            FirstBlockCacheConfig,
            PyramidAttentionBroadcastConfig,
            TaylorSeerCacheConfig,
            apply_faster_cache,
            apply_first_block_cache,
            apply_pyramid_attention_broadcast,
            apply_taylorseer_cache,
        )

        if self.is_cache_enabled:
@@ -83,16 +85,25 @@ class CacheMixin:
            apply_first_block_cache(self, config)
        elif isinstance(config, PyramidAttentionBroadcastConfig):
            apply_pyramid_attention_broadcast(self, config)
        elif isinstance(config, TaylorSeerCacheConfig):
            apply_taylorseer_cache(self, config)
        else:
            raise ValueError(f"Cache config {type(config)} is not supported.")

        self._cache_config = config

    def disable_cache(self) -> None:
        from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
        from ..hooks import (
            FasterCacheConfig,
            FirstBlockCacheConfig,
            HookRegistry,
            PyramidAttentionBroadcastConfig,
            TaylorSeerCacheConfig,
        )
        from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
        from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
        from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
        from ..hooks.taylorseer_cache import _TAYLORSEER_CACHE_HOOK

        if self._cache_config is None:
            logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
@@ -107,6 +118,8 @@ class CacheMixin:
            registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True)
        elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
            registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
        elif isinstance(self._cache_config, TaylorSeerCacheConfig):
            registry.remove_hook(_TAYLORSEER_CACHE_HOOK, recurse=True)
        else:
            raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
@@ -257,6 +257,21 @@ class SmoothedEnergyGuidanceConfig(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


class TaylorSeerCacheConfig(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


def apply_faster_cache(*args, **kwargs):
    requires_backends(apply_faster_cache, ["torch"])
@@ -273,6 +288,10 @@ def apply_pyramid_attention_broadcast(*args, **kwargs):
    requires_backends(apply_pyramid_attention_broadcast, ["torch"])


def apply_taylorseer_cache(*args, **kwargs):
    requires_backends(apply_taylorseer_cache, ["torch"])


class AllegroTransformer3DModel(metaclass=DummyObject):
    _backends = ["torch"]
@@ -29,6 +29,7 @@ from ..test_pipelines_common import (
    FluxIPAdapterTesterMixin,
    PipelineTesterMixin,
    PyramidAttentionBroadcastTesterMixin,
    TaylorSeerCacheTesterMixin,
    check_qkv_fused_layers_exist,
)
@@ -39,6 +40,7 @@ class FluxPipelineFastTests(
    PyramidAttentionBroadcastTesterMixin,
    FasterCacheTesterMixin,
    FirstBlockCacheTesterMixin,
    TaylorSeerCacheTesterMixin,
    unittest.TestCase,
):
    pipeline_class = FluxPipeline
@@ -33,6 +33,7 @@ from ..test_pipelines_common import (
    FirstBlockCacheTesterMixin,
    PipelineTesterMixin,
    PyramidAttentionBroadcastTesterMixin,
    TaylorSeerCacheTesterMixin,
    to_np,
)
@@ -45,6 +46,7 @@ class HunyuanVideoPipelineFastTests(
    PyramidAttentionBroadcastTesterMixin,
    FasterCacheTesterMixin,
    FirstBlockCacheTesterMixin,
    TaylorSeerCacheTesterMixin,
    unittest.TestCase,
):
    pipeline_class = HunyuanVideoPipeline
@@ -36,6 +36,7 @@ from diffusers.hooks import apply_group_offloading
from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook
from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
from diffusers.hooks.taylorseer_cache import TaylorSeerCacheConfig
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
from diffusers.models.attention import AttentionModuleMixin
@@ -2924,6 +2925,57 @@ class FirstBlockCacheTesterMixin:
    )


class TaylorSeerCacheTesterMixin:
    taylorseer_cache_config = TaylorSeerCacheConfig(
        cache_interval=5,
        disable_cache_before_step=10,
        max_order=1,
        taylor_factors_dtype=torch.bfloat16,
        use_lite_mode=True,
    )

    def test_taylorseer_cache_inference(self, expected_atol: float = 0.1):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator

        def create_pipe():
            torch.manual_seed(0)
            num_layers = 2
            components = self.get_dummy_components(num_layers=num_layers)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(device)
            pipe.set_progress_bar_config(disable=None)
            return pipe

        def run_forward(pipe):
            torch.manual_seed(0)
            inputs = self.get_dummy_inputs(device)
            inputs["num_inference_steps"] = 50
            return pipe(**inputs)[0]

        # Run inference without TaylorSeerCache
        pipe = create_pipe()
        output = run_forward(pipe).flatten()
        original_image_slice = np.concatenate((output[:8], output[-8:]))

        # Run inference with TaylorSeerCache enabled
        pipe = create_pipe()
        pipe.transformer.enable_cache(self.taylorseer_cache_config)
        output = run_forward(pipe).flatten()
        image_slice_cache_enabled = np.concatenate((output[:8], output[-8:]))

        # Run inference with TaylorSeerCache disabled
        pipe.transformer.disable_cache()
        output = run_forward(pipe).flatten()
        image_slice_cache_disabled = np.concatenate((output[:8], output[-8:]))

        assert np.allclose(original_image_slice, image_slice_cache_enabled, atol=expected_atol), (
            "TaylorSeerCache outputs should not differ much."
        )
        assert np.allclose(original_image_slice, image_slice_cache_disabled, atol=1e-4), (
            "Outputs from normal inference and after disabling cache should not differ."
        )


# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
# reference image.