mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-07 21:14:44 +08:00
Compare commits
2 Commits
remove-exp
...
teacache
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a64118206a | ||
|
|
13d5af7649 |
227
src/diffusers/models/hooks.py
Normal file
227
src/diffusers/models/hooks.py
Normal file
@@ -0,0 +1,227 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
# Reference: https://github.com/huggingface/accelerate/blob/ba7ab93f5e688466ea56908ea3b056fae2f9a023/src/accelerate/hooks.py
|
||||
class ModelHook:
|
||||
r"""
|
||||
A hook that contains callbacks to be executed just before and after the forward method of a model.
|
||||
"""
|
||||
|
||||
_is_stateful = False
|
||||
|
||||
def init_hook(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
r"""
|
||||
Hook that is executed when a model is initialized.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module attached to this hook.
|
||||
"""
|
||||
return module
|
||||
|
||||
def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
|
||||
r"""
|
||||
Hook that is executed just before the forward method of the model.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module whose forward pass will be executed just after this event.
|
||||
args (`Tuple[Any]`):
|
||||
The positional arguments passed to the module.
|
||||
kwargs (`Dict[Str, Any]`):
|
||||
The keyword arguments passed to the module.
|
||||
Returns:
|
||||
`Tuple[Tuple[Any], Dict[Str, Any]]`:
|
||||
A tuple with the treated `args` and `kwargs`.
|
||||
"""
|
||||
return args, kwargs
|
||||
|
||||
def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
|
||||
r"""
|
||||
Hook that is executed just after the forward method of the model.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module whose forward pass been executed just before this event.
|
||||
output (`Any`):
|
||||
The output of the module.
|
||||
Returns:
|
||||
`Any`: The processed `output`.
|
||||
"""
|
||||
return output
|
||||
|
||||
def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
r"""
|
||||
Hook that is executed when the hook is detached from a module.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module detached from this hook.
|
||||
"""
|
||||
return module
|
||||
|
||||
def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
if self._is_stateful:
|
||||
raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
|
||||
return module
|
||||
|
||||
|
||||
class SequentialHook(ModelHook):
|
||||
r"""A hook that can contain several hooks and iterates through them at each event."""
|
||||
|
||||
def __init__(self, *hooks):
|
||||
self.hooks = hooks
|
||||
|
||||
def init_hook(self, module):
|
||||
for hook in self.hooks:
|
||||
module = hook.init_hook(module)
|
||||
return module
|
||||
|
||||
def pre_forward(self, module, *args, **kwargs):
|
||||
for hook in self.hooks:
|
||||
args, kwargs = hook.pre_forward(module, *args, **kwargs)
|
||||
return args, kwargs
|
||||
|
||||
def post_forward(self, module, output):
|
||||
for hook in self.hooks:
|
||||
output = hook.post_forward(module, output)
|
||||
return output
|
||||
|
||||
def detach_hook(self, module):
|
||||
for hook in self.hooks:
|
||||
module = hook.detach_hook(module)
|
||||
return module
|
||||
|
||||
def reset_state(self, module):
|
||||
for hook in self.hooks:
|
||||
if hook._is_stateful:
|
||||
hook.reset_state(module)
|
||||
return module
|
||||
|
||||
|
||||
def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False) -> torch.nn.Module:
|
||||
r"""
|
||||
Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove
|
||||
this behavior and restore the original `forward` method, use `remove_hook_from_module`.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
If the module already contains a hook, this will replace it with the new hook passed by default. To chain two hooks
|
||||
together, pass `append=True`, so it chains the current and new hook into an instance of the `SequentialHook` class.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module to attach a hook to.
|
||||
hook (`ModelHook`):
|
||||
The hook to attach.
|
||||
append (`bool`, *optional*, defaults to `False`):
|
||||
Whether the hook should be chained with an existing one (if module already contains a hook) or not.
|
||||
Returns:
|
||||
`torch.nn.Module`:
|
||||
The same module, with the hook attached (the module is modified in place, so the result can be discarded).
|
||||
"""
|
||||
original_hook = hook
|
||||
|
||||
if append and getattr(module, "_diffusers_hook", None) is not None:
|
||||
old_hook = module._diffusers_hook
|
||||
remove_hook_from_module(module)
|
||||
hook = SequentialHook(old_hook, hook)
|
||||
|
||||
if hasattr(module, "_diffusers_hook") and hasattr(module, "_old_forward"):
|
||||
# If we already put some hook on this module, we replace it with the new one.
|
||||
old_forward = module._old_forward
|
||||
else:
|
||||
old_forward = module.forward
|
||||
module._old_forward = old_forward
|
||||
|
||||
module = hook.init_hook(module)
|
||||
module._diffusers_hook = hook
|
||||
|
||||
if hasattr(original_hook, "new_forward"):
|
||||
new_forward = original_hook.new_forward
|
||||
else:
|
||||
|
||||
def new_forward(module, *args, **kwargs):
|
||||
args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
|
||||
output = module._old_forward(*args, **kwargs)
|
||||
return module._diffusers_hook.post_forward(module, output)
|
||||
|
||||
# Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
|
||||
# Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
|
||||
if "GraphModuleImpl" in str(type(module)):
|
||||
module.__class__.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)
|
||||
else:
|
||||
module.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)
|
||||
|
||||
return module
|
||||
|
||||
|
||||
def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> torch.nn.Module:
|
||||
"""
|
||||
Removes any hook attached to a module via `add_hook_to_module`.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module to attach a hook to.
|
||||
recurse (`bool`, defaults to `False`):
|
||||
Whether to remove the hooks recursively
|
||||
Returns:
|
||||
`torch.nn.Module`:
|
||||
The same module, with the hook detached (the module is modified in place, so the result can be discarded).
|
||||
"""
|
||||
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
module._diffusers_hook.detach_hook(module)
|
||||
delattr(module, "_diffusers_hook")
|
||||
|
||||
if hasattr(module, "_old_forward"):
|
||||
# Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
|
||||
# Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
|
||||
if "GraphModuleImpl" in str(type(module)):
|
||||
module.__class__.forward = module._old_forward
|
||||
else:
|
||||
module.forward = module._old_forward
|
||||
delattr(module, "_old_forward")
|
||||
|
||||
if recurse:
|
||||
for child in module.children():
|
||||
remove_hook_from_module(child, recurse)
|
||||
|
||||
return module
|
||||
|
||||
|
||||
def reset_stateful_hooks(module: torch.nn.Module, recurse: bool = False):
|
||||
"""
|
||||
Resets the state of all stateful hooks attached to a module.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module to reset the stateful hooks from.
|
||||
"""
|
||||
if hasattr(module, "_diffusers_hook") and (
|
||||
module._diffusers_hook._is_stateful or isinstance(module._diffusers_hook, SequentialHook)
|
||||
):
|
||||
module._diffusers_hook.reset_state(module)
|
||||
|
||||
if recurse:
|
||||
for child in module.children():
|
||||
reset_stateful_hooks(child, recurse)
|
||||
252
src/diffusers/pipelines/teacache_utils.py
Normal file
252
src/diffusers/pipelines/teacache_utils.py
Normal file
@@ -0,0 +1,252 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..models import (
|
||||
FluxTransformer2DModel,
|
||||
HunyuanVideoTransformer3DModel,
|
||||
LTXVideoTransformer3DModel,
|
||||
LuminaNextDiT2DModel,
|
||||
MochiTransformer3DModel,
|
||||
)
|
||||
from ..models.hooks import ModelHook, add_hook_to_module
|
||||
from ..utils import logging
|
||||
from .pipeline_utils import DiffusionPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# Source: https://github.com/ali-vilab/TeaCache
|
||||
# TODO(aryan): Implement methods to calibrate and compute polynomial coefficients on-the-fly, and export to file for re-use.
|
||||
# fmt: off
|
||||
_MODEL_TO_POLY_COEFFICIENTS = {
|
||||
FluxTransformer2DModel: [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01],
|
||||
HunyuanVideoTransformer3DModel: [7.33226126e02, -4.01131952e02, 6.75869174e01, -3.14987800e00, 9.61237896e-02],
|
||||
LTXVideoTransformer3DModel: [2.14700694e01, -1.28016453e01, 2.31279151e00, 7.92487521e-01, 9.69274326e-03],
|
||||
LuminaNextDiT2DModel: [393.76566581, -603.50993606, 209.10239044, -23.00726601, 0.86377344],
|
||||
MochiTransformer3DModel: [-3.51241319e03, 8.11675948e02, -6.09400215e01, 2.42429681e00, 3.05291719e-03],
|
||||
}
|
||||
# fmt: on
|
||||
|
||||
_MODEL_TO_1_POINT_5X_SPEEDUP_THRESHOLD = {
|
||||
FluxTransformer2DModel: 0.25,
|
||||
HunyuanVideoTransformer3DModel: 0.1,
|
||||
LTXVideoTransformer3DModel: 0.05,
|
||||
LuminaNextDiT2DModel: 0.2,
|
||||
MochiTransformer3DModel: 0.06,
|
||||
}
|
||||
|
||||
_MODEL_TO_TIMESTEP_MODULATED_LAYER_IDENTIFIER = {
|
||||
FluxTransformer2DModel: "transformer_blocks.0.norm1",
|
||||
}
|
||||
|
||||
_MODEL_TO_SKIP_END_LAYER_IDENTIFIER = {
|
||||
FluxTransformer2DModel: "norm_out",
|
||||
}
|
||||
|
||||
_DEFAULT_SKIP_LAYER_IDENTIFIERS = [
|
||||
"blocks",
|
||||
"transformer_blocks",
|
||||
"single_transformer_blocks",
|
||||
"temporal_transformer_blocks",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TeaCacheConfig:
|
||||
l1_threshold: Optional[float] = None
|
||||
|
||||
skip_layer_identifiers: List[str] = _DEFAULT_SKIP_LAYER_IDENTIFIERS
|
||||
|
||||
_polynomial_coefficients: Optional[List[float]] = None
|
||||
|
||||
|
||||
class TeaCacheDenoiserState:
|
||||
def __init__(self):
|
||||
self.iteration: int = 0
|
||||
self.accumulated_l1_difference: float = 0.0
|
||||
self.timestep_modulated_cache: torch.Tensor = None
|
||||
self.residual_cache: torch.Tensor = None
|
||||
self.should_skip_blocks: bool = False
|
||||
|
||||
def reset(self):
|
||||
self.iteration = 0
|
||||
self.accumulated_l1_difference = 0.0
|
||||
self.timestep_modulated_cache = None
|
||||
self.residual_cache = None
|
||||
|
||||
|
||||
def apply_teacache(
|
||||
pipeline: DiffusionPipeline, config: Optional[TeaCacheConfig] = None, denoiser: Optional[nn.Module] = None
|
||||
) -> None:
|
||||
r"""Applies [TeaCache](https://huggingface.co/papers/2411.19108) to a given pipeline or denoiser module.
|
||||
|
||||
Args:
|
||||
TODO
|
||||
"""
|
||||
|
||||
if config is None:
|
||||
logger.warning("No TeaCacheConfig provided. Using default configuration.")
|
||||
config = TeaCacheConfig()
|
||||
|
||||
if denoiser is None:
|
||||
denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet
|
||||
|
||||
if isinstance(denoiser, (_MODEL_TO_POLY_COEFFICIENTS.keys())):
|
||||
if config.l1_threshold is None:
|
||||
logger.info(
|
||||
f"No L1 threshold was provided for {type(denoiser)}. Using default threshold as provided in the TeaCache paper for 1.5x speedup. "
|
||||
f"For higher speedup, increase the threshold."
|
||||
)
|
||||
config.l1_threshold = _MODEL_TO_1_POINT_5X_SPEEDUP_THRESHOLD[type(denoiser)]
|
||||
if config.timestep_modulated_layer_identifier is None:
|
||||
logger.info(
|
||||
f"No timestep modulated layer identifier was provided for {type(denoiser)}. Using default identifier as provided in the TeaCache paper."
|
||||
)
|
||||
config.timestep_modulated_layer_identifier = _MODEL_TO_TIMESTEP_MODULATED_LAYER_IDENTIFIER[type(denoiser)]
|
||||
if config._polynomial_coefficients is None:
|
||||
logger.info(
|
||||
f"No polynomial coefficients were provided for {type(denoiser)}. Using default coefficients as provided in the TeaCache paper."
|
||||
)
|
||||
config._polynomial_coefficients = _MODEL_TO_POLY_COEFFICIENTS[type(denoiser)]
|
||||
else:
|
||||
if config.l1_threshold is None:
|
||||
raise ValueError(
|
||||
f"No L1 threshold was provided for {type(denoiser)}. Using TeaCache with this model is not supported "
|
||||
f"in Diffusers. Please provide the L1 threshold in the config by setting the `l1_threshold` attribute."
|
||||
)
|
||||
if config.timestep_modulated_layer_identifier is None:
|
||||
raise ValueError(
|
||||
f"No timestep modulated layer identifier was provided for {type(denoiser)}. Using TeaCache with this model is not supported "
|
||||
f"in Diffusers. Please provide the layer identifier in the config by setting the `timestep_modulated_layer_identifier` attribute."
|
||||
)
|
||||
if config._polynomial_coefficients is None:
|
||||
raise ValueError(
|
||||
f"No polynomial coefficients were provided for {type(denoiser)}. Using TeaCache with this model is not "
|
||||
f"supported in Diffusers. Please provide the polynomial coefficients in the config by setting the "
|
||||
f"`_polynomial_coefficients` attribute. Automatic calibration will be implemented in the future."
|
||||
)
|
||||
|
||||
timestep_modulated_layer_matches = list(
|
||||
{
|
||||
module
|
||||
for name, module in denoiser.named_modules()
|
||||
if re.match(config.timestep_modulated_layer_identifier, name)
|
||||
}
|
||||
)
|
||||
|
||||
if len(timestep_modulated_layer_matches) == 0:
|
||||
raise ValueError(
|
||||
f"No layer in the denoiser module matched the provided timestep modulated layer identifier: "
|
||||
f"{config.timestep_modulated_layer_identifier}. Please provide a valid layer identifier."
|
||||
)
|
||||
if len(timestep_modulated_layer_matches) > 1:
|
||||
logger.warning(
|
||||
f"Multiple layers in the denoiser module matched the provided timestep modulated layer identifier: "
|
||||
f"{config.timestep_modulated_layer_identifier}. Using the first match."
|
||||
)
|
||||
|
||||
denoiser_state = TeaCacheDenoiserState()
|
||||
|
||||
timestep_modulated_layer = timestep_modulated_layer_matches[0]
|
||||
hook = TimestepModulatedOutputCacheHook(denoiser_state, config.l1_threshold, config._polynomial_coefficients)
|
||||
add_hook_to_module(timestep_modulated_layer, hook, append=True)
|
||||
|
||||
skip_layer_identifiers = config.skip_layer_identifiers
|
||||
skip_layer_matches = list(
|
||||
{
|
||||
module
|
||||
for name, module in denoiser.named_modules()
|
||||
if any(re.match(identifier, name) for identifier in skip_layer_identifiers)
|
||||
}
|
||||
)
|
||||
|
||||
for skip_layer in skip_layer_matches:
|
||||
hook = DenoiserStateBasedSkipLayerHook(denoiser_state)
|
||||
add_hook_to_module(skip_layer, hook, append=True)
|
||||
|
||||
|
||||
class TimestepModulatedOutputCacheHook(ModelHook):
|
||||
# The denoiser hook will reset its state, so we don't have to handle it here
|
||||
_is_stateful = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
denoiser_state: TeaCacheDenoiserState,
|
||||
l1_threshold: float,
|
||||
polynomial_coefficients: List[float],
|
||||
) -> None:
|
||||
self.denoiser_state = denoiser_state
|
||||
self.l1_threshold = l1_threshold
|
||||
# TODO(aryan): implement torch equivalent
|
||||
self.rescale_fn = np.poly1d(polynomial_coefficients)
|
||||
|
||||
def post_forward(self, module, output):
|
||||
if isinstance(output, tuple):
|
||||
# This assumes that the first element of the output tuple is the timestep modulated noise output.
|
||||
# For Diffusers models, this is true. For models outside diffusers, users will have to ensure
|
||||
# that the first element of the output tuple is the timestep modulated noise output (seems to be
|
||||
# the case for most research model implementations).
|
||||
timestep_modulated_noise = output[0]
|
||||
elif torch.is_tensor(output):
|
||||
timestep_modulated_noise = output
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Expected output to be a tensor or a tuple with first element as timestep modulated noise. "
|
||||
f"Got {type(output)} instead. Please ensure that the denoiser module returns the timestep "
|
||||
f"modulated noise output as the first element."
|
||||
)
|
||||
|
||||
if self.denoiser_state.timestep_modulated_cache is not None:
|
||||
l1_diff = (timestep_modulated_noise - self.denoiser_state.timestep_modulated_cache).abs().mean()
|
||||
normalized_l1_diff = l1_diff / self.denoiser_state.timestep_modulated_cache.abs().mean()
|
||||
rescaled_l1_diff = self.rescale_fn(normalized_l1_diff)
|
||||
self.denoiser_state.accumulated_l1_difference += rescaled_l1_diff
|
||||
|
||||
if self.denoiser_state.accumulated_l1_difference >= self.l1_threshold:
|
||||
self.denoiser_state.should_skip_blocks = True
|
||||
self.denoiser_state.accumulated_l1_difference = 0.0
|
||||
else:
|
||||
self.denoiser_state.should_skip_blocks = False
|
||||
|
||||
self.denoiser_state.timestep_modulated_cache = timestep_modulated_noise
|
||||
return output
|
||||
|
||||
|
||||
class DenoiserStateBasedSkipLayerHook(ModelHook):
|
||||
_is_stateful = False
|
||||
|
||||
def __init__(self, denoiser_state: TeaCacheDenoiserState) -> None:
|
||||
self.denoiser_state = denoiser_state
|
||||
|
||||
def new_forward(self, module, *args, **kwargs):
|
||||
args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
|
||||
|
||||
if not self.denoiser_state.should_skip_blocks:
|
||||
output = module._old_forward(*args, **kwargs)
|
||||
else:
|
||||
# Diffusers models either expect one output (hidden_states) or a tuple of two outputs (hidden_states, encoder_hidden_states).
|
||||
# Returning a tuple of None values handles both cases. It is okay to do because we are not going to be using these
|
||||
# anywhere if self.denoiser_state.should_skip_blocks is True.
|
||||
output = (None, None)
|
||||
|
||||
return module._diffusers_hook.post_forward(module, output)
|
||||
Reference in New Issue
Block a user