Compare commits

...

31 Commits

Author SHA1 Message Date
Aryan
80a7854540 refactor 2025-05-20 23:14:58 +02:00
Aryan
5ac3d644d2 Merge branch 'main' into integrations/first-block-cache-2 2025-05-20 21:36:16 +02:00
Aryan
b8317da20f remove central registry based on review 2025-05-16 12:53:36 +02:00
Aryan
a5fe2bd4fd remove unused function 2025-05-16 12:37:01 +02:00
Aryan
153cf0c393 fix 2025-05-16 12:33:52 +02:00
Aryan
495fddb8ae support mochi 2025-05-15 21:42:19 +02:00
Aryan
367fdef96d support cogvideox 2025-05-15 21:41:45 +02:00
Aryan
82fa9df1c7 Merge branch 'main' into integrations/first-block-cache-2 2025-05-15 21:38:57 +02:00
Aryan
fb229b54bb metadata registration with decorators instead of centralized 2025-05-14 14:19:48 +02:00
Aryan
0a44380a36 address review comments 2025-05-14 12:14:24 +02:00
Aryan
2ed59c178d Merge branch 'main' into integrations/first-block-cache-2 2025-05-14 11:48:12 +02:00
Aryan
1f3e02f4da Merge branch 'main' into integrations/first-block-cache-2 2025-04-30 04:00:24 +05:30
Aryan
38a603939e Merge branch 'main' into integrations/first-block-cache-2 2025-04-22 07:00:49 +05:30
Aryan
169bb0df9c cache context refacotr; address review pt. 3 2025-04-16 14:24:08 +02:00
Aryan
f731664773 address review comments pt. 2 2025-04-16 14:03:01 +02:00
Aryan
3dde07a647 address review comments pt. 1 2025-04-16 13:36:35 +02:00
Aryan
701cf86654 Merge branch 'main' into integrations/first-block-cache-2 2025-04-16 13:34:26 +02:00
Aryan
ca715a9771 Update src/diffusers/hooks/hooks.py
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-04-14 18:04:13 +05:30
Aryan
aa8e328328 Merge branch 'main' into integrations/first-block-cache-2 2025-04-05 00:24:21 +02:00
Aryan
ff5f2ee505 Update docs/source/en/api/cache.md 2025-04-05 03:53:34 +05:30
Aryan
46619ea717 update 2025-04-05 00:23:10 +02:00
Aryan
c76e1cc17e update 2025-04-02 21:52:33 +02:00
Aryan
315e357a18 Merge branch 'main' into integrations/first-block-cache-2 2025-04-02 01:21:22 +02:00
Aryan
1f33ca276d support flux, ltx i2v, ltx condition 2025-04-02 01:21:09 +02:00
Aryan
41b0c473d2 fix controlnet flux 2025-04-02 01:20:53 +02:00
Aryan
0e232ac8c0 fix hs residual bug for single return outputs; support ltx 2025-04-02 00:38:11 +02:00
Aryan
2557238b4d cache context for different batches of data 2025-04-01 19:40:23 +02:00
Aryan
d71fe55895 update 2025-04-01 17:06:45 +02:00
Aryan
7ab424a15a remove debug logs 2025-04-01 01:39:00 +02:00
Aryan
dd69b41834 modify flux single blocks to make compatible with cache techniques (without too much model-specific intrusion code) 2025-04-01 01:28:09 +02:00
Aryan
406b1062f8 update 2025-03-31 04:27:35 +02:00
37 changed files with 829 additions and 211 deletions

View File

@@ -11,33 +11,6 @@ specific language governing permissions and limitations under the License. -->
# Caching methods
## Pyramid Attention Broadcast
[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. The attention states are not very different between successive inference steps. The most prominent difference is in the spatial attention blocks, not as much in the temporal attention blocks, and finally the least in the cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by the temporal and spatial attention blocks. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request.
```python
import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of
# `spatial_attention_timestep_skip_range[1]` will decrease the interval in which pyramid attention
# broadcast is active, leader to slower inference speeds. However, large intervals can lead to
# poorer quality of generated videos.
config = PyramidAttentionBroadcastConfig(
spatial_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(100, 800),
current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)
```
## Faster Cache
[FasterCache](https://huggingface.co/papers/2410.19355) from Zhengyao Lv, Chenyang Si, Junhao Song, Zhenyu Yang, Yu Qiao, Ziwei Liu, Kwan-Yee K. Wong.
@@ -65,18 +38,68 @@ config = FasterCacheConfig(
pipe.transformer.enable_cache(config)
```
## First Block Cache
[First Block Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching) is a method that builds upon the ideas of [TeaCache](https://huggingface.co/papers/2411.19108) to speed up inference in diffusion transformers. The generation quality is superior with greatly reduced inference time. This method always computes the output of the first transformer block and computes the differences between past and current outputs of the first transformer block. If the difference is smaller than a predefined threshold, the computation of remaining transformer blocks is skipped, and otherwise the computation is performed as usual.
```python
import torch
from diffusers import CogVideoXPipeline, FirstBlockCacheConfig
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Increasing the threshold may lead to faster inference speeds, but may also lead to poorer quality of generated videos.
# Smaller values between 0.02-0.20 are recommended based on the model being used. The default value is 0.05.
config = FirstBlockCacheConfig(threshold=0.07)
pipe.transformer.enable_cache(config)
```
## Pyramid Attention Broadcast
[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. The attention states are not very different between successive inference steps. The most prominent difference is in the spatial attention blocks, not as much in the temporal attention blocks, and finally the least in the cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by the temporal and spatial attention blocks. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request.
```python
import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of
# `spatial_attention_timestep_skip_range[1]` will decrease the interval in which pyramid attention
# broadcast is active, leader to slower inference speeds. However, large intervals can lead to
# poorer quality of generated videos.
config = PyramidAttentionBroadcastConfig(
spatial_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(100, 800),
current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)
```
### CacheMixin
[[autodoc]] CacheMixin
### PyramidAttentionBroadcastConfig
[[autodoc]] PyramidAttentionBroadcastConfig
[[autodoc]] apply_pyramid_attention_broadcast
### FasterCacheConfig
[[autodoc]] FasterCacheConfig
[[autodoc]] apply_faster_cache
### FirstBlockCacheConfig
[[autodoc]] FirstBlockCacheConfig
[[autodoc]] apply_first_block_cache
### PyramidAttentionBroadcastConfig
[[autodoc]] PyramidAttentionBroadcastConfig
[[autodoc]] apply_pyramid_attention_broadcast

View File

@@ -133,9 +133,11 @@ else:
_import_structure["hooks"].extend(
[
"FasterCacheConfig",
"FirstBlockCacheConfig",
"HookRegistry",
"PyramidAttentionBroadcastConfig",
"apply_faster_cache",
"apply_first_block_cache",
"apply_pyramid_attention_broadcast",
]
)
@@ -740,9 +742,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .hooks import (
FasterCacheConfig,
FirstBlockCacheConfig,
HookRegistry,
PyramidAttentionBroadcastConfig,
apply_faster_cache,
apply_first_block_cache,
apply_pyramid_attention_broadcast,
)
from .models import (

View File

@@ -1,8 +1,23 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import is_torch_available
if is_torch_available():
from .faster_cache import FasterCacheConfig, apply_faster_cache
from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
from .group_offloading import apply_group_offloading
from .hooks import HookRegistry, ModelHook
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook

View File

@@ -0,0 +1,30 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..models.attention_processor import Attention, MochiAttention
_ATTENTION_CLASSES = (Attention, MochiAttention)
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
{
*_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
*_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
*_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
}
)

View File

@@ -0,0 +1,221 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
import torch
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
from .hooks import BaseState, HookRegistry, ModelHook, StateManager
logger = get_logger(__name__) # pylint: disable=invalid-name
_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook"
_FBC_BLOCK_HOOK = "fbc_block_hook"
@dataclass
class FirstBlockCacheConfig:
r"""
Configuration for [First Block
Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching).
Args:
threshold (`float`, defaults to `0.05`):
The threshold to determine whether or not a forward pass through all layers of the model is required. A
higher threshold usually results in a forward pass through a lower number of layers and faster inference,
but might lead to poorer generation quality. A lower threshold may not result in significant generation
speedup. The threshold is compared against the absmean difference of the residuals between the current and
cached outputs from the first transformer block. If the difference is below the threshold, the forward pass
is skipped.
"""
threshold: float = 0.05
class FBCSharedBlockState(BaseState):
def __init__(self) -> None:
super().__init__()
self.head_block_output_hidden_states: torch.Tensor = None
self.head_block_output_encoder_hidden_states: torch.Tensor = None
self.head_block_residual_hidden_states: torch.Tensor = None
self.tail_block_residual_hidden_states: torch.Tensor = None
self.tail_block_residual_encoder_hidden_states: torch.Tensor = None
self.should_compute: bool = True
def reset(self):
self.head_block_output_hidden_states = None
self.head_block_output_encoder_hidden_states = None
self.head_block_residual_hidden_states = None
self.tail_block_residual_hidden_states = None
self.tail_block_residual_encoder_hidden_states = None
self.should_compute = True
class FBCHeadBlockHook(ModelHook):
_is_stateful = True
def __init__(self, state_manager: StateManager, threshold: float):
self.state_manager = state_manager
self.threshold = threshold
self._metadata = None
def initialize_hook(self, module):
unwrapped_module = unwrap_module(module)
if not hasattr(unwrapped_module, "_diffusers_transformer_block_metadata"):
raise ValueError(
f"Module {unwrapped_module} does not have any registered metadata. "
"Make sure to register the metadata using `diffusers.models.metadata.register_transformer_block`."
)
self._metadata = unwrapped_module._diffusers_transformer_block_metadata
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
output = self.fn_ref.original_forward(*args, **kwargs)
hidden_states_residual = output.hidden_states - original_hidden_states
shared_state: FBCSharedBlockState = self.state_manager.get_state()
hidden_states = encoder_hidden_states = None
should_compute = self._should_compute_remaining_blocks(hidden_states_residual)
shared_state.should_compute = should_compute
if not should_compute:
# Apply caching
return_output = output.__class__()
hidden_states = shared_state.tail_block_residual_hidden_states + output.hidden_states
return_output = return_output._replace(hidden_states=hidden_states)
if hasattr(output, "encoder_hidden_states"):
encoder_hidden_states = (
shared_state.tail_block_residual_encoder_hidden_states + output.encoder_hidden_states
)
return_output = return_output._replace(encoder_hidden_states=encoder_hidden_states)
else:
return_output = output
shared_state.head_block_output_hidden_states = output.hidden_states
if hasattr(output, "encoder_hidden_states"):
shared_state.head_block_output_encoder_hidden_states = output.encoder_hidden_states
shared_state.head_block_residual_hidden_states = hidden_states_residual
return return_output
def reset_state(self, module):
self.state_manager.reset()
return module
@torch.compiler.disable
def _should_compute_remaining_blocks(self, hidden_states_residual: torch.Tensor) -> bool:
shared_state = self.state_manager.get_state()
if shared_state.head_block_residual_hidden_states is None:
return True
prev_hidden_states_residual = shared_state.head_block_residual_hidden_states
absmean = (hidden_states_residual - prev_hidden_states_residual).abs().mean()
prev_hidden_states_absmean = prev_hidden_states_residual.abs().mean()
diff = (absmean / prev_hidden_states_absmean).item()
return diff > self.threshold
class FBCBlockHook(ModelHook):
def __init__(self, state_manager: StateManager, is_tail: bool = False):
super().__init__()
self.state_manager = state_manager
self.is_tail = is_tail
self._metadata = None
self._output_cls = None
def initialize_hook(self, module):
unwrapped_module = unwrap_module(module)
if not hasattr(unwrapped_module, "_diffusers_transformer_block_metadata"):
raise ValueError(
f"Module {unwrapped_module} does not have any registered metadata. "
"Make sure to register the metadata using `diffusers.models.metadata.register_transformer_block`."
)
self._metadata = unwrapped_module._diffusers_transformer_block_metadata
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
original_encoder_hidden_states = None
try:
original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
"encoder_hidden_states", args, kwargs
)
except ValueError:
# This is expected for models that don't have use encoder_hidden_states in their forward definition
pass
shared_state = self.state_manager.get_state()
if shared_state.should_compute:
output = self.fn_ref.original_forward(*args, **kwargs)
if self._output_cls is None:
self._output_cls = output.__class__
if self.is_tail:
hidden_states_residual = output.hidden_states - shared_state.head_block_output_hidden_states
if hasattr(output, "encoder_hidden_states"):
encoder_hidden_states_residual = (
output.encoder_hidden_states - shared_state.head_block_output_encoder_hidden_states
)
shared_state.tail_block_residual_hidden_states = hidden_states_residual
shared_state.tail_block_residual_encoder_hidden_states = encoder_hidden_states_residual
return output
assert self._output_cls is not None
return_output = self._output_cls()
return_output = return_output._replace(hidden_states=original_hidden_states)
if hasattr(return_output, "encoder_hidden_states"):
return_output = return_output._replace(encoder_hidden_states=original_encoder_hidden_states)
return return_output
def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
state_manager = StateManager(FBCSharedBlockState, (), {})
remaining_blocks = []
for name, submodule in module.named_children():
if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
continue
for index, block in enumerate(submodule):
remaining_blocks.append((f"{name}.{index}", block))
head_block_name, head_block = remaining_blocks.pop(0)
tail_block_name, tail_block = remaining_blocks.pop(-1)
logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'")
_apply_fbc_head_block_hook(head_block, state_manager, config.threshold)
for name, block in remaining_blocks:
logger.debug(f"Applying FBCBlockHook to '{name}'")
_apply_fbc_block_hook(block, state_manager)
logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'")
_apply_fbc_block_hook(tail_block, state_manager, is_tail=True)
def _apply_fbc_head_block_hook(block: torch.nn.Module, state_manager: StateManager, threshold: float) -> None:
registry = HookRegistry.check_if_exists_or_initialize(block)
hook = FBCHeadBlockHook(state_manager, threshold)
registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK)
def _apply_fbc_block_hook(block: torch.nn.Module, state_manager: StateManager, is_tail: bool = False) -> None:
registry = HookRegistry.check_if_exists_or_initialize(block)
hook = FBCBlockHook(state_manager, is_tail)
registry.register_hook(hook, _FBC_BLOCK_HOOK)

View File

@@ -18,11 +18,44 @@ from typing import Any, Dict, Optional, Tuple
import torch
from ..utils.logging import get_logger
from ..utils.torch_utils import unwrap_module
logger = get_logger(__name__) # pylint: disable=invalid-name
class BaseState:
def reset(self, *args, **kwargs) -> None:
raise NotImplementedError(
"BaseState::reset is not implemented. Please implement this method in the derived class."
)
class StateManager:
def __init__(self, state_cls: BaseState, init_args=None, init_kwargs=None):
self._state_cls = state_cls
self._init_args = init_args if init_args is not None else ()
self._init_kwargs = init_kwargs if init_kwargs is not None else {}
self._state_cache = {}
self._current_context = None
def get_state(self):
if self._current_context is None:
raise ValueError("No context is set. Please set a context before retrieving the state.")
if self._current_context not in self._state_cache.keys():
self._state_cache[self._current_context] = self._state_cls(*self._init_args, **self._init_kwargs)
return self._state_cache[self._current_context]
def set_context(self, name: str) -> None:
self._current_context = name
def reset(self, *args, **kwargs) -> None:
for name, state in list(self._state_cache.items()):
state.reset(*args, **kwargs)
self._state_cache.pop(name)
self._current_context = None
class ModelHook:
r"""
A hook that contains callbacks to be executed just before and after the forward method of a model.
@@ -99,6 +132,14 @@ class ModelHook:
raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
return module
def _set_context(self, module: torch.nn.Module, name: str) -> None:
# Iterate over all attributes of the hook to see if any of them have the type `StateManager`. If so, call `set_context` on them.
for attr_name in dir(self):
attr = getattr(self, attr_name)
if isinstance(attr, StateManager):
attr.set_context(name)
return module
class HookFunctionReference:
def __init__(self) -> None:
@@ -211,9 +252,10 @@ class HookRegistry:
hook.reset_state(self._module_ref)
if recurse:
for module_name, module in self._module_ref.named_modules():
for module_name, module in unwrap_module(self._module_ref).named_modules():
if module_name == "":
continue
module = unwrap_module(module)
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook.reset_stateful_hooks(recurse=False)
@@ -223,6 +265,19 @@ class HookRegistry:
module._diffusers_hook = cls(module)
return module._diffusers_hook
def _set_context(self, name: Optional[str] = None) -> None:
for hook_name in reversed(self._hook_order):
hook = self.hooks[hook_name]
if hook._is_stateful:
hook._set_context(self._module_ref, name)
for module_name, module in unwrap_module(self._module_ref).named_modules():
if module_name == "":
continue
module = unwrap_module(module)
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook._set_context(name)
def __repr__(self) -> str:
registry_repr = ""
for i, hook_name in enumerate(self._hook_order):

View File

@@ -22,6 +22,7 @@ from ..utils.torch_utils import maybe_allow_in_graph
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
from .attention_processor import Attention, JointAttnProcessor2_0
from .embeddings import SinusoidalPositionalEmbedding
from .metadata import TransformerBlockMetadata, register_transformer_block
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
@@ -258,6 +259,12 @@ class JointTransformerBlock(nn.Module):
@maybe_allow_in_graph
@register_transformer_block(
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
)
)
class BasicTransformerBlock(nn.Module):
r"""
A basic Transformer block.

View File

@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from ..utils.logging import get_logger
@@ -25,6 +27,7 @@ class CacheMixin:
Supported caching techniques:
- [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
- [FasterCache](https://huggingface.co/papers/2410.19355)
- [FirstBlockCache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching)
"""
_cache_config = None
@@ -62,8 +65,10 @@ class CacheMixin:
from ..hooks import (
FasterCacheConfig,
FirstBlockCacheConfig,
PyramidAttentionBroadcastConfig,
apply_faster_cache,
apply_first_block_cache,
apply_pyramid_attention_broadcast,
)
@@ -72,31 +77,36 @@ class CacheMixin:
f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first."
)
if isinstance(config, PyramidAttentionBroadcastConfig):
apply_pyramid_attention_broadcast(self, config)
elif isinstance(config, FasterCacheConfig):
if isinstance(config, FasterCacheConfig):
apply_faster_cache(self, config)
elif isinstance(config, FirstBlockCacheConfig):
apply_first_block_cache(self, config)
elif isinstance(config, PyramidAttentionBroadcastConfig):
apply_pyramid_attention_broadcast(self, config)
else:
raise ValueError(f"Cache config {type(config)} is not supported.")
self._cache_config = config
def disable_cache(self) -> None:
from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
if self._cache_config is None:
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
return
if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
registry = HookRegistry.check_if_exists_or_initialize(self)
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
elif isinstance(self._cache_config, FasterCacheConfig):
registry = HookRegistry.check_if_exists_or_initialize(self)
registry = HookRegistry.check_if_exists_or_initialize(self)
if isinstance(self._cache_config, FasterCacheConfig):
registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True)
registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True)
elif isinstance(self._cache_config, FirstBlockCacheConfig):
registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True)
registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True)
elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
else:
raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
@@ -106,3 +116,15 @@ class CacheMixin:
from ..hooks import HookRegistry
HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse)
@contextmanager
def cache_context(self, name: str):
r"""Context manager that provides additional methods for cache management."""
from ..hooks import HookRegistry
registry = HookRegistry.check_if_exists_or_initialize(self)
registry._set_context(name)
yield
registry._set_context(None)

View File

@@ -343,25 +343,25 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
)
block_samples = block_samples + (hidden_states,)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
single_block_samples = ()
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
)
else:
hidden_states = block(
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
single_block_samples = single_block_samples + (hidden_states,)
# controlnet block
controlnet_block_samples = ()

View File

@@ -0,0 +1,53 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from dataclasses import dataclass
from typing import Dict, Type
@dataclass
class TransformerBlockMetadata:
return_hidden_states_index: int = None
return_encoder_hidden_states_index: int = None
_cls: Type = None
_cached_parameter_indices: Dict[str, int] = None
def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
kwargs = kwargs or {}
if identifier in kwargs:
return kwargs[identifier]
if self._cached_parameter_indices is not None:
return args[self._cached_parameter_indices[identifier]]
if self._cls is None:
raise ValueError("Model class is not set for metadata.")
parameters = list(inspect.signature(self._cls.forward).parameters.keys())
parameters = parameters[1:] # skip `self`
self._cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
if identifier not in self._cached_parameter_indices:
raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
index = self._cached_parameter_indices[identifier]
if index >= len(args):
raise ValueError(f"Expected {index} arguments but got {len(args)}.")
return args[index]
def register_transformer_block(metadata: TransformerBlockMetadata):
def inner(model_class: Type):
metadata._cls = model_class
model_class._diffusers_transformer_block_metadata = metadata
return model_class
return inner

View File

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
import torch
from torch import nn
@@ -34,6 +34,11 @@ from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class CogVideoXBlockOutput(NamedTuple):
hidden_states: torch.Tensor = None
encoder_hidden_states: torch.Tensor = None
@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
r"""
@@ -122,7 +127,7 @@ class CogVideoXBlock(nn.Module):
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
) -> CogVideoXBlockOutput:
text_seq_length = encoder_hidden_states.size(1)
attention_kwargs = attention_kwargs or {}
@@ -154,7 +159,7 @@ class CogVideoXBlock(nn.Module):
hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
return hidden_states, encoder_hidden_states
return CogVideoXBlockOutput(hidden_states, encoder_hidden_states)
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):

View File

@@ -21,10 +21,12 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
from ..metadata import TransformerBlockMetadata, register_transformer_block
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous
@@ -453,6 +455,13 @@ class CogView4TrainingAttnProcessor:
return hidden_states, encoder_hidden_states
@maybe_allow_in_graph
@register_transformer_block(
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
)
)
class CogView4TransformerBlock(nn.Module):
def __init__(
self,

View File

@@ -34,6 +34,7 @@ from ..attention_processor import (
)
from ..cache_utils import CacheMixin
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from ..metadata import TransformerBlockMetadata, register_transformer_block
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
@@ -43,6 +44,12 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@maybe_allow_in_graph
@register_transformer_block(
metadata=TransformerBlockMetadata(
return_hidden_states_index=1,
return_encoder_hidden_states_index=0,
)
)
class FluxSingleTransformerBlock(nn.Module):
def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
super().__init__()
@@ -79,10 +86,14 @@ class FluxSingleTransformerBlock(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
text_seq_len = encoder_hidden_states.shape[1]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
@@ -100,10 +111,17 @@ class FluxSingleTransformerBlock(nn.Module):
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
return hidden_states
encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
return encoder_hidden_states, hidden_states
@maybe_allow_in_graph
@register_transformer_block(
metadata=TransformerBlockMetadata(
return_hidden_states_index=1,
return_encoder_hidden_states_index=0,
)
)
class FluxTransformerBlock(nn.Module):
def __init__(
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
@@ -508,20 +526,21 @@ class FluxTransformer2DModel(
)
else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
)
else:
hidden_states = block(
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
@@ -531,12 +550,7 @@ class FluxTransformer2DModel(
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ controlnet_single_block_samples[index_block // interval_control]
)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)

View File

@@ -33,6 +33,7 @@ from ..embeddings import (
Timesteps,
get_1d_rotary_pos_embed,
)
from ..metadata import TransformerBlockMetadata, register_transformer_block
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, FP32LayerNorm
@@ -310,6 +311,12 @@ class HunyuanVideoConditionEmbedding(nn.Module):
return conditioning, token_replace_emb
@register_transformer_block(
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
)
)
class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
def __init__(
self,
@@ -489,6 +496,12 @@ class HunyuanVideoRotaryPosEmbed(nn.Module):
return freqs_cos, freqs_sin
@register_transformer_block(
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
)
)
class HunyuanVideoSingleTransformerBlock(nn.Module):
def __init__(
self,
@@ -565,6 +578,12 @@ class HunyuanVideoSingleTransformerBlock(nn.Module):
return hidden_states, encoder_hidden_states
@register_transformer_block(
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
)
)
class HunyuanVideoTransformerBlock(nn.Module):
def __init__(
self,
@@ -644,6 +663,12 @@ class HunyuanVideoTransformerBlock(nn.Module):
return hidden_states, encoder_hidden_states
@register_transformer_block(
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
)
)
class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
def __init__(
self,
@@ -724,6 +749,12 @@ class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
return hidden_states, encoder_hidden_states
@register_transformer_block(
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
)
)
class HunyuanVideoTokenReplaceTransformerBlock(nn.Module):
def __init__(
self,

View File

@@ -28,6 +28,7 @@ from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..embeddings import PixArtAlphaTextProjection
from ..metadata import TransformerBlockMetadata, register_transformer_block
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle, RMSNorm
@@ -196,6 +197,12 @@ class LTXVideoRotaryPosEmbed(nn.Module):
@maybe_allow_in_graph
@register_transformer_block(
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
)
)
class LTXVideoTransformerBlock(nn.Module):
r"""
Transformer block used in [LTX](https://huggingface.co/Lightricks/LTX-Video).

View File

@@ -27,6 +27,7 @@ from ..attention import FeedForward
from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
from ..cache_utils import CacheMixin
from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
from ..metadata import TransformerBlockMetadata, register_transformer_block
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, RMSNorm
@@ -116,6 +117,12 @@ class MochiRMSNormZero(nn.Module):
@maybe_allow_in_graph
@register_transformer_block(
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
)
)
class MochiTransformerBlock(nn.Module):
r"""
Transformer block used in [Mochi](https://huggingface.co/genmo/mochi-1-preview).

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -22,10 +22,12 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
from ..metadata import TransformerBlockMetadata, register_transformer_block
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import FP32LayerNorm
@@ -34,6 +36,11 @@ from ..normalization import FP32LayerNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class WanTransformerBlockOutput(NamedTuple):
hidden_states: torch.Tensor = None
encoder_hidden_states: torch.Tensor = None
class WanAttnProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
@@ -219,6 +226,8 @@ class WanRotaryPosEmbed(nn.Module):
return freqs
@maybe_allow_in_graph
@register_transformer_block(TransformerBlockMetadata())
class WanTransformerBlock(nn.Module):
def __init__(
self,
@@ -276,7 +285,7 @@ class WanTransformerBlock(nn.Module):
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
rotary_emb: torch.Tensor,
) -> torch.Tensor:
) -> WanTransformerBlockOutput:
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
self.scale_shift_table + temb.float()
).chunk(6, dim=1)
@@ -298,7 +307,7 @@ class WanTransformerBlock(nn.Module):
ff_output = self.ffn(norm_hidden_states)
hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
return hidden_states
return WanTransformerBlockOutput(hidden_states, encoder_hidden_states)
class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
@@ -447,12 +456,14 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
# 4. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing:
for block in self.blocks:
hidden_states = self._gradient_checkpointing_func(
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
)
else:
for block in self.blocks:
hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
)
# 5. Output norm, projection & unpatchify
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)

View File

@@ -718,14 +718,15 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
with self.transformer.cache_context("cond_uncond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
# perform guidance

View File

@@ -784,14 +784,15 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
with self.transformer.cache_context("cond_uncond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
# perform guidance

View File

@@ -831,15 +831,16 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
ofs=ofs_emb,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
with self.transformer.cache_context("cond_uncond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
ofs=ofs_emb,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
# perform guidance

View File

@@ -799,14 +799,15 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
with self.transformer.cache_context("cond_uncond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
# perform guidance

View File

@@ -619,22 +619,10 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0])
noise_pred_cond = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
original_size=original_size,
target_size=target_size,
crop_coords=crops_coords_top_left,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond = self.transformer(
with self.transformer.cache_context("cond"):
noise_pred_cond = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=negative_prompt_embeds,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
original_size=original_size,
target_size=target_size,
@@ -643,6 +631,19 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
return_dict=False,
)[0]
# perform guidance
if self.do_classifier_free_guidance:
with self.transformer.cache_context("uncond"):
noise_pred_uncond = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=negative_prompt_embeds,
timestep=timestep,
original_size=original_size,
target_size=target_size,
crop_coords=crops_coords_top_left,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
else:
noise_pred = noise_pred_cond

View File

@@ -909,32 +909,35 @@ class FluxPipeline(
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
if do_true_cfg:
if negative_image_embeds is not None:
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
neg_noise_pred = self.transformer(
with self.transformer.cache_context("cond"):
noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=negative_pooled_prompt_embeds,
encoder_hidden_states=negative_prompt_embeds,
txt_ids=negative_text_ids,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
if do_true_cfg:
if negative_image_embeds is not None:
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
with self.transformer.cache_context("uncond"):
neg_noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=negative_pooled_prompt_embeds,
encoder_hidden_states=negative_prompt_embeds,
txt_ids=negative_text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
# compute the previous noisy sample x_t -> x_t-1

View File

@@ -693,28 +693,30 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
pooled_projections=pooled_prompt_embeds,
guidance=guidance,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if do_true_cfg:
neg_noise_pred = self.transformer(
with self.transformer.cache_context("cond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
encoder_attention_mask=negative_prompt_attention_mask,
pooled_projections=negative_pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
pooled_projections=pooled_prompt_embeds,
guidance=guidance,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if do_true_cfg:
with self.transformer.cache_context("uncond"):
neg_noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
encoder_attention_mask=negative_prompt_attention_mask,
pooled_projections=negative_pooled_prompt_embeds,
guidance=guidance,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
# compute the previous noisy sample x_t -> x_t-1

View File

@@ -757,18 +757,19 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
encoder_attention_mask=prompt_attention_mask,
num_frames=latent_num_frames,
height=latent_height,
width=latent_width,
rope_interpolation_scale=rope_interpolation_scale,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
with self.transformer.cache_context("cond_uncond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
encoder_attention_mask=prompt_attention_mask,
num_frames=latent_num_frames,
height=latent_height,
width=latent_width,
rope_interpolation_scale=rope_interpolation_scale,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
if self.do_classifier_free_guidance:

View File

@@ -1177,15 +1177,16 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
if is_conditioning_image_or_video:
timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
encoder_attention_mask=prompt_attention_mask,
video_coords=video_coords,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
with self.transformer.cache_context("cond_uncond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
encoder_attention_mask=prompt_attention_mask,
video_coords=video_coords,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

View File

@@ -830,18 +830,19 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
timestep = t.expand(latent_model_input.shape[0])
timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
encoder_attention_mask=prompt_attention_mask,
num_frames=latent_num_frames,
height=latent_height,
width=latent_width,
rope_interpolation_scale=rope_interpolation_scale,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
with self.transformer.cache_context("cond_uncond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
encoder_attention_mask=prompt_attention_mask,
num_frames=latent_num_frames,
height=latent_height,
width=latent_width,
rope_interpolation_scale=rope_interpolation_scale,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
if self.do_classifier_free_guidance:

View File

@@ -671,14 +671,15 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
encoder_attention_mask=prompt_attention_mask,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
with self.transformer.cache_context("cond_uncond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
encoder_attention_mask=prompt_attention_mask,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
# Mochi CFG + Sampling runs in FP32
noise_pred = noise_pred.to(torch.float32)

View File

@@ -530,22 +530,24 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
latent_model_input = latents.to(transformer_dtype)
timestep = t.expand(latents.shape[0])
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
noise_uncond = self.transformer(
with self.transformer.cache_context("cond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
encoder_hidden_states=prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
with self.transformer.cache_context("uncond"):
noise_uncond = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
# compute the previous noisy sample x_t -> x_t-1

View File

@@ -17,6 +17,21 @@ class FasterCacheConfig(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class FirstBlockCacheConfig(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class HookRegistry(metaclass=DummyObject):
_backends = ["torch"]
@@ -51,6 +66,10 @@ def apply_faster_cache(*args, **kwargs):
requires_backends(apply_faster_cache, ["torch"])
def apply_first_block_cache(*args, **kwargs):
requires_backends(apply_first_block_cache, ["torch"])
def apply_pyramid_attention_broadcast(*args, **kwargs):
requires_backends(apply_pyramid_attention_broadcast, ["torch"])

View File

@@ -90,6 +90,11 @@ def is_compiled_module(module) -> bool:
return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
def unwrap_module(module):
"""Unwraps a module if it was compiled with torch.compile()"""
return module._orig_mod if is_compiled_module(module) else module
def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
"""Fourier filter as introduced in FreeU (https://huggingface.co/papers/2309.11497).

View File

@@ -32,6 +32,7 @@ from diffusers.utils.testing_utils import (
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
@@ -44,7 +45,11 @@ enable_full_determinism()
class CogVideoXPipelineFastTests(
PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
unittest.TestCase,
):
pipeline_class = CogVideoXPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}

View File

@@ -25,6 +25,7 @@ from diffusers.utils.testing_utils import (
from ..test_pipelines_common import (
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
FluxIPAdapterTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
@@ -34,11 +35,12 @@ from ..test_pipelines_common import (
class FluxPipelineFastTests(
unittest.TestCase,
PipelineTesterMixin,
FluxIPAdapterTesterMixin,
PyramidAttentionBroadcastTesterMixin,
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
unittest.TestCase,
):
pipeline_class = FluxPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])

View File

@@ -33,6 +33,7 @@ from diffusers.utils.testing_utils import (
from ..test_pipelines_common import (
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
to_np,
@@ -43,7 +44,11 @@ enable_full_determinism()
class HunyuanVideoPipelineFastTests(
PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
unittest.TestCase,
):
pipeline_class = HunyuanVideoPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])

View File

@@ -23,13 +23,13 @@ from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LT
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
from ..test_pipelines_common import FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np
enable_full_determinism()
class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
class LTXPipelineFastTests(PipelineTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase):
pipeline_class = LTXPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
@@ -49,7 +49,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self):
def get_dummy_components(self, num_layers: int = 1):
torch.manual_seed(0)
transformer = LTXVideoTransformer3DModel(
in_channels=8,
@@ -59,7 +59,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
num_attention_heads=4,
attention_head_dim=8,
cross_attention_dim=32,
num_layers=1,
num_layers=num_layers,
caption_channels=32,
)

View File

@@ -33,13 +33,15 @@ from diffusers.utils.testing_utils import (
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import FasterCacheTesterMixin, PipelineTesterMixin, to_np
from ..test_pipelines_common import FasterCacheTesterMixin, FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np
enable_full_determinism()
class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unittest.TestCase):
class MochiPipelineFastTests(
PipelineTesterMixin, FasterCacheTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase
):
pipeline_class = MochiPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS

View File

@@ -33,6 +33,7 @@ from diffusers import (
)
from diffusers.hooks import apply_group_offloading
from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook
from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
@@ -2632,7 +2633,7 @@ class FasterCacheTesterMixin:
self.faster_cache_config.current_timestep_callback = lambda: pipe.current_timestep
pipe = create_pipe()
pipe.transformer.enable_cache(self.faster_cache_config)
output = run_forward(pipe).flatten().flatten()
output = run_forward(pipe).flatten()
image_slice_faster_cache_enabled = np.concatenate((output[:8], output[-8:]))
# Run inference with FasterCache disabled
@@ -2739,6 +2740,55 @@ class FasterCacheTesterMixin:
self.assertTrue(state.cache is None, "Cache should be reset to None.")
# TODO(aryan, dhruv): the cache tester mixins should probably be rewritten so that more models can be tested out
# of the box once there is better cache support/implementation
class FirstBlockCacheTesterMixin:
# threshold is intentionally set higher than usual values since we're testing with random unconverged models
# that will not satisfy the expected properties of the denoiser for caching to be effective
first_block_cache_config = FirstBlockCacheConfig(threshold=0.8)
def test_first_block_cache_inference(self, expected_atol: float = 0.1):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
def create_pipe():
torch.manual_seed(0)
num_layers = 2
components = self.get_dummy_components(num_layers=num_layers)
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
return pipe
def run_forward(pipe):
torch.manual_seed(0)
inputs = self.get_dummy_inputs(device)
inputs["num_inference_steps"] = 4
return pipe(**inputs)[0]
# Run inference without FirstBlockCache
pipe = create_pipe()
output = run_forward(pipe).flatten()
original_image_slice = np.concatenate((output[:8], output[-8:]))
# Run inference with FirstBlockCache enabled
pipe = create_pipe()
pipe.transformer.enable_cache(self.first_block_cache_config)
output = run_forward(pipe).flatten()
image_slice_fbc_enabled = np.concatenate((output[:8], output[-8:]))
# Run inference with FirstBlockCache disabled
pipe.transformer.disable_cache()
output = run_forward(pipe).flatten()
image_slice_fbc_disabled = np.concatenate((output[:8], output[-8:]))
assert np.allclose(original_image_slice, image_slice_fbc_enabled, atol=expected_atol), (
"FirstBlockCache outputs should not differ much."
)
assert np.allclose(original_image_slice, image_slice_fbc_disabled, atol=1e-4), (
"Outputs from normal inference and after disabling cache should not differ."
)
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
# reference image.