Compare commits

...

36 Commits

Author SHA1 Message Date
sayakpaul
9a91821bd2 fix example usage. 2025-12-05 20:34:09 +07:00
sayakpaul
a9c59b7de6 make fix-copies 2025-12-05 20:31:31 +07:00
Sayak Paul
3bf9f9d281 Merge branch 'main' into feat-taylorseer 2025-12-05 21:25:35 +08:00
toilaluan
5229769a94 update docs 2025-12-05 07:31:57 +00:00
Tran Thanh Luan
d009d451c2 Update src/diffusers/hooks/taylorseer_cache.py
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-12-05 14:26:25 +07:00
github-actions[bot]
ca24569a2a Apply style fixes 2025-12-04 11:16:52 +00:00
toilaluan
76494ca098 torch compile compatible 2025-12-04 06:12:37 +00:00
toilaluan
4fb3f53b6c rename identifiers, use more expressive taylor predict loop 2025-12-03 10:43:01 +00:00
Tran Thanh Luan
475ec02d8c Remove TaylorSeerCacheTesterMixin from flux2 tests 2025-12-01 09:31:23 +07:00
github-actions[bot]
289146e73e Apply style fixes 2025-11-30 21:40:03 +00:00
toilaluan
e2dae7e432 add tests 2025-11-29 07:21:01 +00:00
github-actions[bot]
716dfe1468 Apply style fixes 2025-11-28 12:52:07 +00:00
Tran Thanh Luan
ddc6164d18 Merge branch 'main' into feat-taylorseer 2025-11-28 19:29:04 +07:00
toilaluan
d06c6bc6c2 fix taylor precision 2025-11-28 08:14:41 +00:00
toilaluan
309ce72140 quality & style 2025-11-28 07:28:44 +00:00
toilaluan
83b62531f8 add docs 2025-11-28 07:23:06 +00:00
toilaluan
24267c76de chores: naming, remove redundancy 2025-11-28 07:23:01 +00:00
Tran Thanh Luan
656c7bc501 Merge branch 'main' into feat-taylorseer 2025-11-26 11:12:25 +07:00
Tran Thanh Luan
a644417835 Merge branch 'huggingface:main' into feat-taylorseer 2025-11-26 10:27:22 +07:00
toilaluan
2be31f856e fix format & doc 2025-11-25 06:02:13 +00:00
Tran Thanh Luan
b3217139f5 Merge branch 'main' into feat-taylorseer 2025-11-25 12:31:03 +07:00
toilaluan
a8ea383044 refractor to use state manager 2025-11-25 05:28:00 +00:00
toilaluan
9083e1eba5 update to handle multple calls per timestep 2025-11-20 09:54:29 +00:00
Tran Thanh Luan
05f61a9cc3 Merge branch 'main' into feat-taylorseer 2025-11-20 14:21:11 +07:00
toilaluan
d929ab28a7 apply ruff 2025-11-17 13:24:20 +07:00
Tran Thanh Luan
9290b5895f Merge branch 'main' into feat-taylorseer 2025-11-17 13:21:41 +07:00
toilaluan
acfebfa3f3 update docs 2025-11-17 13:21:01 +07:00
toilaluan
7238d40dd9 add stop_predicts (cooldown) 2025-11-16 05:09:44 +00:00
toilaluan
51b4318a3e allow special cache ids only 2025-11-15 05:13:33 +00:00
toilaluan
7b4ad2de63 add configurable cache, skip compute module 2025-11-14 09:09:46 +00:00
toilaluan
1099e493e6 refractor, add docs 2025-11-14 07:00:12 +00:00
toilaluan
0602044da7 still update in warmup steps 2025-11-13 17:03:35 +00:00
toilaluan
8f80072618 use logger for printing, add warmup feature 2025-11-13 13:11:29 +00:00
toilaluan
8f495b607f make compatible with any tuple size returned 2025-11-13 11:37:54 +00:00
Tran Thanh Luan
fe20f9798f Merge branch 'main' into feat-taylorseer 2025-11-13 16:46:49 +07:00
toilaluan
a4bfa451fe init taylor_seer cache 2025-11-13 15:06:36 +07:00
12 changed files with 481 additions and 1 deletions

View File

@@ -34,3 +34,9 @@ Cache methods speedup diffusion transformers by storing and reusing intermediate
[[autodoc]] FirstBlockCacheConfig
[[autodoc]] apply_first_block_cache
### TaylorSeerCacheConfig
[[autodoc]] TaylorSeerCacheConfig
[[autodoc]] apply_taylorseer_cache

View File

@@ -66,4 +66,35 @@ config = FasterCacheConfig(
tensor_format="BFCHW",
)
pipeline.transformer.enable_cache(config)
```
## TaylorSeer Cache
[TaylorSeer Cache](https://huggingface.co/papers/2403.06923) accelerates diffusion inference by using Taylor series expansions to approximate and cache intermediate activations across denoising steps. The method predicts future outputs based on past computations, reusing them at specified intervals to reduce redundant calculations.
This caching mechanism delivers strong results with minimal additional memory overhead. For detailed performance analysis, see [our findings here](https://github.com/huggingface/diffusers/pull/12648#issuecomment-3610615080).
To enable TaylorSeer Cache, create a [`TaylorSeerCacheConfig`] and pass it to your pipeline's transformer:
- `cache_interval`: Number of steps to reuse cached outputs before performing a full forward pass
- `disable_cache_before_step`: Initial steps that use full computations to gather data for approximations
- `max_order`: Approximation accuracy (in theory, higher values improve quality but increase memory usage but we recommend it should be set to `1`)
```python
import torch
from diffusers import FluxPipeline, TaylorSeerCacheConfig
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.bfloat16,
)
pipe.to("cuda")
config = TaylorSeerCacheConfig(
cache_interval=5,
max_order=1,
disable_cache_before_step=10,
taylor_factors_dtype=torch.bfloat16,
)
pipe.transformer.enable_cache(config)
```

View File

@@ -169,10 +169,12 @@ else:
"LayerSkipConfig",
"PyramidAttentionBroadcastConfig",
"SmoothedEnergyGuidanceConfig",
"TaylorSeerCacheConfig",
"apply_faster_cache",
"apply_first_block_cache",
"apply_layer_skip",
"apply_pyramid_attention_broadcast",
"apply_taylorseer_cache",
]
)
_import_structure["models"].extend(
@@ -899,10 +901,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LayerSkipConfig,
PyramidAttentionBroadcastConfig,
SmoothedEnergyGuidanceConfig,
TaylorSeerCacheConfig,
apply_faster_cache,
apply_first_block_cache,
apply_layer_skip,
apply_pyramid_attention_broadcast,
apply_taylorseer_cache,
)
from .models import (
AllegroTransformer3DModel,

View File

@@ -25,3 +25,4 @@ if is_torch_available():
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache

View File

@@ -0,0 +1,346 @@
import math
import re
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from ..utils import logging
from .hooks import HookRegistry, ModelHook, StateManager
logger = logging.get_logger(__name__)
_TAYLORSEER_CACHE_HOOK = "taylorseer_cache"
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
"^blocks.*attn",
"^transformer_blocks.*attn",
"^single_transformer_blocks.*attn",
)
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",)
_TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
_BLOCK_IDENTIFIERS = ("^[^.]*block[^.]*\\.[^.]+$",)
_PROJ_OUT_IDENTIFIERS = ("^proj_out$",)
@dataclass
class TaylorSeerCacheConfig:
"""
Configuration for TaylorSeer cache. See: https://huggingface.co/papers/2503.06923
Attributes:
cache_interval (`int`, defaults to `5`):
The interval between full computation steps. After a full computation, the cached (predicted) outputs are
reused for this many subsequent denoising steps before refreshing with a new full forward pass.
disable_cache_before_step (`int`, defaults to `3`):
The denoising step index before which caching is disabled, meaning full computation is performed for the
initial steps (0 to disable_cache_before_step - 1) to gather data for Taylor series approximations. During
these steps, Taylor factors are updated, but caching/predictions are not applied. Caching begins at this
step.
disable_cache_after_step (`int`, *optional*, defaults to `None`):
The denoising step index after which caching is disabled. If set, for steps >= this value, all modules run
full computations without predictions or state updates, ensuring accuracy in later stages if needed.
max_order (`int`, defaults to `1`):
The highest order in the Taylor series expansion for approximating module outputs. Higher orders provide
better approximations but increase computation and memory usage.
taylor_factors_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
Data type used for storing and computing Taylor series factors. Lower precision reduces memory but may
affect stability; higher precision improves accuracy at the cost of more memory.
skip_predict_identifiers (`List[str]`, *optional*, defaults to `None`):
Regex patterns (using `re.fullmatch`) for module names to place as "skip" in "cache" mode. In this mode,
the module computes fully during initial or refresh steps but returns a zero tensor (matching recorded
shape) during prediction steps to skip computation cheaply.
cache_identifiers (`List[str]`, *optional*, defaults to `None`):
Regex patterns (using `re.fullmatch`) for module names to place in Taylor-series caching mode, where
outputs are approximated and cached for reuse.
use_lite_mode (`bool`, *optional*, defaults to `False`):
Enables a lightweight TaylorSeer variant that minimizes memory usage by applying predefined patterns for
skipping and caching (e.g., skipping blocks and caching projections). This overrides any custom
`inactive_identifiers` or `active_identifiers`.
Notes:
- Patterns are matched using `re.fullmatch` on the module name.
- If `skip_predict_identifiers` or `cache_identifiers` are provided, only matching modules are hooked.
- If neither is provided, all attention-like modules are hooked by default.
Example of inactive and active usage:
```py
def forward(x):
x = self.module1(x) # inactive module: returns zeros tensor based on shape recorded during full compute
x = self.module2(x) # active module: caches output here, avoiding recomputation of prior steps
return x
```
"""
cache_interval: int = 5
disable_cache_before_step: int = 3
disable_cache_after_step: Optional[int] = None
max_order: int = 1
taylor_factors_dtype: Optional[torch.dtype] = torch.bfloat16
skip_predict_identifiers: Optional[List[str]] = None
cache_identifiers: Optional[List[str]] = None
use_lite_mode: bool = False
def __repr__(self) -> str:
return (
"TaylorSeerCacheConfig("
f"cache_interval={self.cache_interval}, "
f"disable_cache_before_step={self.disable_cache_before_step}, "
f"disable_cache_after_step={self.disable_cache_after_step}, "
f"max_order={self.max_order}, "
f"taylor_factors_dtype={self.taylor_factors_dtype}, "
f"skip_predict_identifiers={self.skip_predict_identifiers}, "
f"cache_identifiers={self.cache_identifiers}, "
f"use_lite_mode={self.use_lite_mode})"
)
class TaylorSeerState:
def __init__(
self,
taylor_factors_dtype: Optional[torch.dtype] = torch.bfloat16,
max_order: int = 1,
is_inactive: bool = False,
):
self.taylor_factors_dtype = taylor_factors_dtype
self.max_order = max_order
self.is_inactive = is_inactive
self.module_dtypes: Tuple[torch.dtype, ...] = ()
self.last_update_step: Optional[int] = None
self.taylor_factors: Dict[int, Dict[int, torch.Tensor]] = {}
self.inactive_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None
self.device: Optional[torch.device] = None
self.current_step: int = -1
def reset(self) -> None:
self.current_step = -1
self.last_update_step = None
self.taylor_factors = {}
self.inactive_shapes = None
self.device = None
def update(
self,
outputs: Tuple[torch.Tensor, ...],
) -> None:
self.module_dtypes = tuple(output.dtype for output in outputs)
self.device = outputs[0].device
if self.is_inactive:
self.inactive_shapes = tuple(output.shape for output in outputs)
else:
for i, features in enumerate(outputs):
new_factors: Dict[int, torch.Tensor] = {0: features}
is_first_update = self.last_update_step is None
if not is_first_update:
delta_step = self.current_step - self.last_update_step
if delta_step == 0:
raise ValueError("Delta step cannot be zero for TaylorSeer update.")
# Recursive divided differences up to max_order
prev_factors = self.taylor_factors.get(i, {})
for j in range(self.max_order):
prev = prev_factors.get(j)
if prev is None:
break
new_factors[j + 1] = (new_factors[j] - prev.to(features.dtype)) / delta_step
self.taylor_factors[i] = {
order: factor.to(self.taylor_factors_dtype) for order, factor in new_factors.items()
}
self.last_update_step = self.current_step
@torch.compiler.disable
def predict(self) -> List[torch.Tensor]:
if self.last_update_step is None:
raise ValueError("Cannot predict without prior initialization/update.")
step_offset = self.current_step - self.last_update_step
outputs = []
if self.is_inactive:
if self.inactive_shapes is None:
raise ValueError("Inactive shapes not set during prediction.")
for i in range(len(self.module_dtypes)):
outputs.append(
torch.zeros(
self.inactive_shapes[i],
dtype=self.module_dtypes[i],
device=self.device,
)
)
else:
if not self.taylor_factors:
raise ValueError("Taylor factors empty during prediction.")
num_outputs = len(self.taylor_factors)
num_orders = len(self.taylor_factors[0])
for i in range(num_outputs):
output_dtype = self.module_dtypes[i]
taylor_factors = self.taylor_factors[i]
output = torch.zeros_like(taylor_factors[0], dtype=output_dtype)
for order in range(num_orders):
coeff = (step_offset**order) / math.factorial(order)
factor = taylor_factors[order]
output = output + factor.to(output_dtype) * coeff
outputs.append(output)
return outputs
class TaylorSeerCacheHook(ModelHook):
_is_stateful = True
def __init__(
self,
cache_interval: int,
disable_cache_before_step: int,
taylor_factors_dtype: torch.dtype,
state_manager: StateManager,
disable_cache_after_step: Optional[int] = None,
):
super().__init__()
self.cache_interval = cache_interval
self.disable_cache_before_step = disable_cache_before_step
self.disable_cache_after_step = disable_cache_after_step
self.taylor_factors_dtype = taylor_factors_dtype
self.state_manager = state_manager
def initialize_hook(self, module: torch.nn.Module):
return module
def reset_state(self, module: torch.nn.Module) -> None:
"""
Reset state between sampling runs.
"""
self.state_manager.reset()
@torch.compiler.disable
def _measure_should_compute(self) -> bool:
state: TaylorSeerState = self.state_manager.get_state()
state.current_step += 1
current_step = state.current_step
is_warmup_phase = current_step < self.disable_cache_before_step
is_compute_interval = (current_step - self.disable_cache_before_step - 1) % self.cache_interval == 0
is_cooldown_phase = self.disable_cache_after_step is not None and current_step >= self.disable_cache_after_step
should_compute = is_warmup_phase or is_compute_interval or is_cooldown_phase
return should_compute, state
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
should_compute, state = self._measure_should_compute()
if should_compute:
outputs = self.fn_ref.original_forward(*args, **kwargs)
wrapped_outputs = (outputs,) if isinstance(outputs, torch.Tensor) else outputs
state.update(wrapped_outputs)
return outputs
outputs_list = state.predict()
return outputs_list[0] if len(outputs_list) == 1 else tuple(outputs_list)
def _resolve_patterns(config: TaylorSeerCacheConfig) -> Tuple[List[str], List[str]]:
"""
Resolve effective inactive and active pattern lists from config + templates.
"""
inactive_patterns = config.skip_predict_identifiers if config.skip_predict_identifiers is not None else None
active_patterns = config.cache_identifiers if config.cache_identifiers is not None else None
return inactive_patterns or [], active_patterns or []
def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig):
"""
Applies the TaylorSeer cache to a given pipeline (typically the transformer / UNet).
This function hooks selected modules in the model to enable caching or skipping based on the provided
configuration, reducing redundant computations in diffusion denoising loops.
Args:
module (torch.nn.Module): The model subtree to apply the hooks to.
config (TaylorSeerCacheConfig): Configuration for the cache.
Example:
```python
>>> import torch
>>> from diffusers import FluxPipeline, TaylorSeerCacheConfig
>>> pipe = FluxPipeline.from_pretrained(
... "black-forest-labs/FLUX.1-dev",
... torch_dtype=torch.bfloat16,
... )
>>> pipe.to("cuda")
>>> config = TaylorSeerCacheConfig(
... cache_interval=5,
... max_order=1,
... disable_cache_before_step=3,
... taylor_factors_dtype=torch.float32,
... )
>>> pipe.transformer.enable_cache(config)
```
"""
inactive_patterns, active_patterns = _resolve_patterns(config)
active_patterns = active_patterns or _TRANSFORMER_BLOCK_IDENTIFIERS
if config.use_lite_mode:
logger.info("Using TaylorSeer Lite variant for cache.")
active_patterns = _PROJ_OUT_IDENTIFIERS
inactive_patterns = _BLOCK_IDENTIFIERS
if config.skip_predict_identifiers or config.cache_identifiers:
logger.warning("Lite mode overrides user patterns.")
for name, submodule in module.named_modules():
matches_inactive = any(re.fullmatch(pattern, name) for pattern in inactive_patterns)
matches_active = any(re.fullmatch(pattern, name) for pattern in active_patterns)
if not (matches_inactive or matches_active):
continue
_apply_taylorseer_cache_hook(
module=submodule,
config=config,
is_inactive=matches_inactive,
)
def _apply_taylorseer_cache_hook(
module: nn.Module,
config: TaylorSeerCacheConfig,
is_inactive: bool,
):
"""
Registers the TaylorSeer hook on the specified nn.Module.
Args:
name: Name of the module.
module: The nn.Module to be hooked.
config: Cache configuration.
is_inactive: Whether this module should operate in "inactive" mode.
"""
state_manager = StateManager(
TaylorSeerState,
init_kwargs={
"taylor_factors_dtype": config.taylor_factors_dtype,
"max_order": config.max_order,
"is_inactive": is_inactive,
},
)
registry = HookRegistry.check_if_exists_or_initialize(module)
hook = TaylorSeerCacheHook(
cache_interval=config.cache_interval,
disable_cache_before_step=config.disable_cache_before_step,
taylor_factors_dtype=config.taylor_factors_dtype,
disable_cache_after_step=config.disable_cache_after_step,
state_manager=state_manager,
)
registry.register_hook(hook, _TAYLORSEER_CACHE_HOOK)

View File

@@ -67,9 +67,11 @@ class CacheMixin:
FasterCacheConfig,
FirstBlockCacheConfig,
PyramidAttentionBroadcastConfig,
TaylorSeerCacheConfig,
apply_faster_cache,
apply_first_block_cache,
apply_pyramid_attention_broadcast,
apply_taylorseer_cache,
)
if self.is_cache_enabled:
@@ -83,16 +85,25 @@ class CacheMixin:
apply_first_block_cache(self, config)
elif isinstance(config, PyramidAttentionBroadcastConfig):
apply_pyramid_attention_broadcast(self, config)
elif isinstance(config, TaylorSeerCacheConfig):
apply_taylorseer_cache(self, config)
else:
raise ValueError(f"Cache config {type(config)} is not supported.")
self._cache_config = config
def disable_cache(self) -> None:
from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
from ..hooks import (
FasterCacheConfig,
FirstBlockCacheConfig,
HookRegistry,
PyramidAttentionBroadcastConfig,
TaylorSeerCacheConfig,
)
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
from ..hooks.taylorseer_cache import _TAYLORSEER_CACHE_HOOK
if self._cache_config is None:
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
@@ -107,6 +118,8 @@ class CacheMixin:
registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True)
elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
elif isinstance(self._cache_config, TaylorSeerCacheConfig):
registry.remove_hook(_TAYLORSEER_CACHE_HOOK, recurse=True)
else:
raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")

View File

@@ -257,6 +257,21 @@ class SmoothedEnergyGuidanceConfig(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class TaylorSeerCacheConfig(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
def apply_faster_cache(*args, **kwargs):
requires_backends(apply_faster_cache, ["torch"])
@@ -273,6 +288,10 @@ def apply_pyramid_attention_broadcast(*args, **kwargs):
requires_backends(apply_pyramid_attention_broadcast, ["torch"])
def apply_taylorseer_cache(*args, **kwargs):
requires_backends(apply_taylorseer_cache, ["torch"])
class AllegroTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -29,6 +29,7 @@ from ..test_pipelines_common import (
FluxIPAdapterTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
TaylorSeerCacheTesterMixin,
check_qkv_fused_layers_exist,
)
@@ -39,6 +40,7 @@ class FluxPipelineFastTests(
PyramidAttentionBroadcastTesterMixin,
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
TaylorSeerCacheTesterMixin,
unittest.TestCase,
):
pipeline_class = FluxPipeline

View File

@@ -19,6 +19,7 @@ from ..test_pipelines_common import (
FluxIPAdapterTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
TaylorSeerCacheTesterMixin,
)
@@ -28,6 +29,7 @@ class FluxKontextPipelineFastTests(
FluxIPAdapterTesterMixin,
PyramidAttentionBroadcastTesterMixin,
FasterCacheTesterMixin,
TaylorSeerCacheTesterMixin,
):
pipeline_class = FluxKontextPipeline
params = frozenset(

View File

@@ -19,6 +19,7 @@ from ..test_pipelines_common import (
FluxIPAdapterTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
TaylorSeerCacheTesterMixin,
)
@@ -28,6 +29,7 @@ class FluxKontextInpaintPipelineFastTests(
FluxIPAdapterTesterMixin,
PyramidAttentionBroadcastTesterMixin,
FasterCacheTesterMixin,
TaylorSeerCacheTesterMixin,
):
pipeline_class = FluxKontextInpaintPipeline
params = frozenset(

View File

@@ -33,6 +33,7 @@ from ..test_pipelines_common import (
FirstBlockCacheTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
TaylorSeerCacheTesterMixin,
to_np,
)
@@ -45,6 +46,7 @@ class HunyuanVideoPipelineFastTests(
PyramidAttentionBroadcastTesterMixin,
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
TaylorSeerCacheTesterMixin,
unittest.TestCase,
):
pipeline_class = HunyuanVideoPipeline

View File

@@ -36,6 +36,7 @@ from diffusers.hooks import apply_group_offloading
from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook
from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
from diffusers.hooks.taylorseer_cache import TaylorSeerCacheConfig
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
from diffusers.models.attention import AttentionModuleMixin
@@ -2924,6 +2925,57 @@ class FirstBlockCacheTesterMixin:
)
class TaylorSeerCacheTesterMixin:
taylorseer_cache_config = TaylorSeerCacheConfig(
cache_interval=5,
disable_cache_before_step=10,
max_order=1,
taylor_factors_dtype=torch.bfloat16,
use_lite_mode=True,
)
def test_taylorseer_cache_inference(self, expected_atol: float = 0.1):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
def create_pipe():
torch.manual_seed(0)
num_layers = 2
components = self.get_dummy_components(num_layers=num_layers)
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
return pipe
def run_forward(pipe):
torch.manual_seed(0)
inputs = self.get_dummy_inputs(device)
inputs["num_inference_steps"] = 50
return pipe(**inputs)[0]
# Run inference without TaylorSeerCache
pipe = create_pipe()
output = run_forward(pipe).flatten()
original_image_slice = np.concatenate((output[:8], output[-8:]))
# Run inference with TaylorSeerCache enabled
pipe = create_pipe()
pipe.transformer.enable_cache(self.taylorseer_cache_config)
output = run_forward(pipe).flatten()
image_slice_fbc_enabled = np.concatenate((output[:8], output[-8:]))
# Run inference with TaylorSeerCache disabled
pipe.transformer.disable_cache()
output = run_forward(pipe).flatten()
image_slice_fbc_disabled = np.concatenate((output[:8], output[-8:]))
assert np.allclose(original_image_slice, image_slice_fbc_enabled, atol=expected_atol), (
"TaylorSeerCache outputs should not differ much."
)
assert np.allclose(original_image_slice, image_slice_fbc_disabled, atol=1e-4), (
"Outputs from normal inference and after disabling cache should not differ."
)
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
# reference image.