Compare commits

...

2 Commits

Author  SHA1        Message  Date
Aryan   a64118206a  update   2025-01-02 15:35:58 +01:00
Aryan   13d5af7649  init     2025-01-02 13:11:44 +01:00
2 changed files with 479 additions and 0 deletions

View File

@@ -0,0 +1,227 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from typing import Any, Dict, Tuple
import torch
# Reference: https://github.com/huggingface/accelerate/blob/ba7ab93f5e688466ea56908ea3b056fae2f9a023/src/accelerate/hooks.py
class ModelHook:
r"""
A hook that contains callbacks to be executed just before and after the forward method of a model.
"""
_is_stateful = False
def init_hook(self, module: torch.nn.Module) -> torch.nn.Module:
r"""
Hook that is executed when a model is initialized.
Args:
module (`torch.nn.Module`):
The module attached to this hook.
"""
return module
def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
r"""
Hook that is executed just before the forward method of the model.
Args:
module (`torch.nn.Module`):
The module whose forward pass will be executed just after this event.
args (`Tuple[Any]`):
The positional arguments passed to the module.
kwargs (`Dict[str, Any]`):
The keyword arguments passed to the module.
Returns:
`Tuple[Tuple[Any], Dict[str, Any]]`:
A tuple with the treated `args` and `kwargs`.
"""
return args, kwargs
def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
r"""
Hook that is executed just after the forward method of the model.
Args:
module (`torch.nn.Module`):
The module whose forward pass has been executed just before this event.
output (`Any`):
The output of the module.
Returns:
`Any`: The processed `output`.
"""
return output
def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
r"""
Hook that is executed when the hook is detached from a module.
Args:
module (`torch.nn.Module`):
The module detached from this hook.
"""
return module
def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
if self._is_stateful:
raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
return module
class SequentialHook(ModelHook):
r"""A hook that can contain several hooks and iterates through them at each event."""
def __init__(self, *hooks):
self.hooks = hooks
def init_hook(self, module):
for hook in self.hooks:
module = hook.init_hook(module)
return module
def pre_forward(self, module, *args, **kwargs):
for hook in self.hooks:
args, kwargs = hook.pre_forward(module, *args, **kwargs)
return args, kwargs
def post_forward(self, module, output):
for hook in self.hooks:
output = hook.post_forward(module, output)
return output
def detach_hook(self, module):
for hook in self.hooks:
module = hook.detach_hook(module)
return module
def reset_state(self, module):
for hook in self.hooks:
if hook._is_stateful:
hook.reset_state(module)
return module
def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False) -> torch.nn.Module:
r"""
Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook. To remove
this behavior and restore the original `forward` method, use `remove_hook_from_module`.
<Tip warning={true}>
By default, if the module already contains a hook, the new hook replaces it. To chain the existing and new hooks
instead, pass `append=True`; they are then wrapped together in an instance of the `SequentialHook` class.
</Tip>
Args:
module (`torch.nn.Module`):
The module to attach a hook to.
hook (`ModelHook`):
The hook to attach.
append (`bool`, *optional*, defaults to `False`):
Whether the hook should be chained with an existing one (if the module already contains a hook) or not.
Returns:
`torch.nn.Module`:
The same module, with the hook attached (the module is modified in place, so the result can be discarded).
"""
original_hook = hook
if append and getattr(module, "_diffusers_hook", None) is not None:
old_hook = module._diffusers_hook
remove_hook_from_module(module)
hook = SequentialHook(old_hook, hook)
if hasattr(module, "_diffusers_hook") and hasattr(module, "_old_forward"):
# If we already put some hook on this module, we replace it with the new one.
old_forward = module._old_forward
else:
old_forward = module.forward
module._old_forward = old_forward
module = hook.init_hook(module)
module._diffusers_hook = hook
if hasattr(original_hook, "new_forward"):
new_forward = original_hook.new_forward
else:
def new_forward(module, *args, **kwargs):
args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
output = module._old_forward(*args, **kwargs)
return module._diffusers_hook.post_forward(module, output)
# Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
# Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
if "GraphModuleImpl" in str(type(module)):
module.__class__.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)
else:
module.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)
return module
def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> torch.nn.Module:
"""
Removes any hook attached to a module via `add_hook_to_module`.
Args:
module (`torch.nn.Module`):
The module from which to remove the hook.
recurse (`bool`, *optional*, defaults to `False`):
Whether to remove the hook from child modules recursively.
Returns:
`torch.nn.Module`:
The same module, with the hook detached (the module is modified in place, so the result can be discarded).
"""
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook.detach_hook(module)
delattr(module, "_diffusers_hook")
if hasattr(module, "_old_forward"):
# Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
# Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
if "GraphModuleImpl" in str(type(module)):
module.__class__.forward = module._old_forward
else:
module.forward = module._old_forward
delattr(module, "_old_forward")
if recurse:
for child in module.children():
remove_hook_from_module(child, recurse)
return module
def reset_stateful_hooks(module: torch.nn.Module, recurse: bool = False):
"""
Resets the state of all stateful hooks attached to a module.
Args:
module (`torch.nn.Module`):
The module to reset the stateful hooks from.
recurse (`bool`, *optional*, defaults to `False`):
Whether to reset the stateful hooks of child modules recursively.
"""
if hasattr(module, "_diffusers_hook") and (
module._diffusers_hook._is_stateful or isinstance(module._diffusers_hook, SequentialHook)
):
module._diffusers_hook.reset_state(module)
if recurse:
for child in module.children():
reset_stateful_hooks(child, recurse)
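The hook API above mirrors the referenced `accelerate` hooks: `add_hook_to_module` swaps a module's `forward` for a wrapper that runs `pre_forward`, the original forward, and `post_forward`. A minimal usage sketch, assuming this file is importable as `diffusers.models.hooks` (inferred from the import in the second file below); `ShapeLoggingHook` is a hypothetical example, not part of this change:

import torch

from diffusers.models.hooks import ModelHook, add_hook_to_module, remove_hook_from_module  # assumed import path


class ShapeLoggingHook(ModelHook):
    # Hypothetical example hook: logs tensor shapes around the wrapped module's forward pass.
    def pre_forward(self, module, *args, **kwargs):
        print(f"{type(module).__name__} input: {tuple(args[0].shape)}")
        return args, kwargs

    def post_forward(self, module, output):
        print(f"{type(module).__name__} output: {tuple(output.shape)}")
        return output


linear = torch.nn.Linear(8, 4)
add_hook_to_module(linear, ShapeLoggingHook())
_ = linear(torch.randn(2, 8))  # the wrapper runs pre_forward, the original forward, then post_forward
remove_hook_from_module(linear)  # restores the original forward

Passing `append=True` to `add_hook_to_module` would instead chain the existing and new hooks inside a `SequentialHook`.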

View File

@@ -0,0 +1,252 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from dataclasses import dataclass, field
from typing import List, Optional
import numpy as np
import torch
import torch.nn as nn
from ..models import (
FluxTransformer2DModel,
HunyuanVideoTransformer3DModel,
LTXVideoTransformer3DModel,
LuminaNextDiT2DModel,
MochiTransformer3DModel,
)
from ..models.hooks import ModelHook, add_hook_to_module
from ..utils import logging
from .pipeline_utils import DiffusionPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Source: https://github.com/ali-vilab/TeaCache
# TODO(aryan): Implement methods to calibrate and compute polynomial coefficients on-the-fly, and export to file for re-use.
# fmt: off
_MODEL_TO_POLY_COEFFICIENTS = {
FluxTransformer2DModel: [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01],
HunyuanVideoTransformer3DModel: [7.33226126e02, -4.01131952e02, 6.75869174e01, -3.14987800e00, 9.61237896e-02],
LTXVideoTransformer3DModel: [2.14700694e01, -1.28016453e01, 2.31279151e00, 7.92487521e-01, 9.69274326e-03],
LuminaNextDiT2DModel: [393.76566581, -603.50993606, 209.10239044, -23.00726601, 0.86377344],
MochiTransformer3DModel: [-3.51241319e03, 8.11675948e02, -6.09400215e01, 2.42429681e00, 3.05291719e-03],
}
# fmt: on
_MODEL_TO_1_POINT_5X_SPEEDUP_THRESHOLD = {
FluxTransformer2DModel: 0.25,
HunyuanVideoTransformer3DModel: 0.1,
LTXVideoTransformer3DModel: 0.05,
LuminaNextDiT2DModel: 0.2,
MochiTransformer3DModel: 0.06,
}
_MODEL_TO_TIMESTEP_MODULATED_LAYER_IDENTIFIER = {
FluxTransformer2DModel: "transformer_blocks.0.norm1",
}
_MODEL_TO_SKIP_END_LAYER_IDENTIFIER = {
FluxTransformer2DModel: "norm_out",
}
_DEFAULT_SKIP_LAYER_IDENTIFIERS = [
"blocks",
"transformer_blocks",
"single_transformer_blocks",
"temporal_transformer_blocks",
]
@dataclass
class TeaCacheConfig:
l1_threshold: Optional[float] = None
timestep_modulated_layer_identifier: Optional[str] = None
# A mutable default must be provided via `field(default_factory=...)` in a dataclass.
skip_layer_identifiers: List[str] = field(default_factory=lambda: list(_DEFAULT_SKIP_LAYER_IDENTIFIERS))
_polynomial_coefficients: Optional[List[float]] = None
class TeaCacheDenoiserState:
def __init__(self):
self.iteration: int = 0
self.accumulated_l1_difference: float = 0.0
self.timestep_modulated_cache: Optional[torch.Tensor] = None
self.residual_cache: Optional[torch.Tensor] = None
self.should_skip_blocks: bool = False
def reset(self):
self.iteration = 0
self.accumulated_l1_difference = 0.0
self.timestep_modulated_cache = None
self.residual_cache = None
def apply_teacache(
pipeline: DiffusionPipeline, config: Optional[TeaCacheConfig] = None, denoiser: Optional[nn.Module] = None
) -> None:
r"""Applies [TeaCache](https://huggingface.co/papers/2411.19108) to a given pipeline or denoiser module.
Args:
pipeline (`DiffusionPipeline`):
The pipeline whose denoiser TeaCache should be applied to.
config (`TeaCacheConfig`, *optional*):
The TeaCache configuration. If not provided, a default `TeaCacheConfig` is used.
denoiser (`nn.Module`, *optional*):
The denoiser module. If not provided, `pipeline.transformer` (or `pipeline.unet`) is used.
"""
if config is None:
logger.warning("No TeaCacheConfig provided. Using default configuration.")
config = TeaCacheConfig()
if denoiser is None:
denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet
if isinstance(denoiser, tuple(_MODEL_TO_POLY_COEFFICIENTS.keys())):
if config.l1_threshold is None:
logger.info(
f"No L1 threshold was provided for {type(denoiser)}. Using default threshold as provided in the TeaCache paper for 1.5x speedup. "
f"For higher speedup, increase the threshold."
)
config.l1_threshold = _MODEL_TO_1_POINT_5X_SPEEDUP_THRESHOLD[type(denoiser)]
if config.timestep_modulated_layer_identifier is None:
logger.info(
f"No timestep modulated layer identifier was provided for {type(denoiser)}. Using default identifier as provided in the TeaCache paper."
)
config.timestep_modulated_layer_identifier = _MODEL_TO_TIMESTEP_MODULATED_LAYER_IDENTIFIER[type(denoiser)]
if config._polynomial_coefficients is None:
logger.info(
f"No polynomial coefficients were provided for {type(denoiser)}. Using default coefficients as provided in the TeaCache paper."
)
config._polynomial_coefficients = _MODEL_TO_POLY_COEFFICIENTS[type(denoiser)]
else:
if config.l1_threshold is None:
raise ValueError(
f"No L1 threshold was provided for {type(denoiser)}. Using TeaCache with this model is not supported "
f"in Diffusers. Please provide the L1 threshold in the config by setting the `l1_threshold` attribute."
)
if config.timestep_modulated_layer_identifier is None:
raise ValueError(
f"No timestep modulated layer identifier was provided for {type(denoiser)}. Using TeaCache with this model is not supported "
f"in Diffusers. Please provide the layer identifier in the config by setting the `timestep_modulated_layer_identifier` attribute."
)
if config._polynomial_coefficients is None:
raise ValueError(
f"No polynomial coefficients were provided for {type(denoiser)}. Using TeaCache with this model is not "
f"supported in Diffusers. Please provide the polynomial coefficients in the config by setting the "
f"`_polynomial_coefficients` attribute. Automatic calibration will be implemented in the future."
)
timestep_modulated_layer_matches = [
module
for name, module in denoiser.named_modules()
if re.match(config.timestep_modulated_layer_identifier, name)
]
if len(timestep_modulated_layer_matches) == 0:
raise ValueError(
f"No layer in the denoiser module matched the provided timestep modulated layer identifier: "
f"{config.timestep_modulated_layer_identifier}. Please provide a valid layer identifier."
)
if len(timestep_modulated_layer_matches) > 1:
logger.warning(
f"Multiple layers in the denoiser module matched the provided timestep modulated layer identifier: "
f"{config.timestep_modulated_layer_identifier}. Using the first match."
)
denoiser_state = TeaCacheDenoiserState()
timestep_modulated_layer = timestep_modulated_layer_matches[0]
hook = TimestepModulatedOutputCacheHook(denoiser_state, config.l1_threshold, config._polynomial_coefficients)
add_hook_to_module(timestep_modulated_layer, hook, append=True)
skip_layer_identifiers = config.skip_layer_identifiers
skip_layer_matches = list(
{
module
for name, module in denoiser.named_modules()
if any(re.match(identifier, name) for identifier in skip_layer_identifiers)
}
)
for skip_layer in skip_layer_matches:
hook = DenoiserStateBasedSkipLayerHook(denoiser_state)
add_hook_to_module(skip_layer, hook, append=True)
class TimestepModulatedOutputCacheHook(ModelHook):
# The denoiser hook will reset its state, so we don't have to handle it here
_is_stateful = False
def __init__(
self,
denoiser_state: TeaCacheDenoiserState,
l1_threshold: float,
polynomial_coefficients: List[float],
) -> None:
self.denoiser_state = denoiser_state
self.l1_threshold = l1_threshold
# TODO(aryan): implement torch equivalent
self.rescale_fn = np.poly1d(polynomial_coefficients)
def post_forward(self, module, output):
if isinstance(output, tuple):
# This assumes that the first element of the output tuple is the timestep modulated noise output.
# For Diffusers models, this is true. For models outside diffusers, users will have to ensure
# that the first element of the output tuple is the timestep modulated noise output (seems to be
# the case for most research model implementations).
timestep_modulated_noise = output[0]
elif torch.is_tensor(output):
timestep_modulated_noise = output
else:
raise ValueError(
f"Expected output to be a tensor or a tuple with first element as timestep modulated noise. "
f"Got {type(output)} instead. Please ensure that the denoiser module returns the timestep "
f"modulated noise output as the first element."
)
if self.denoiser_state.timestep_modulated_cache is not None:
l1_diff = (timestep_modulated_noise - self.denoiser_state.timestep_modulated_cache).abs().mean()
# Reduce to a Python float so the numpy polynomial can be evaluated regardless of the tensor's device.
normalized_l1_diff = (l1_diff / self.denoiser_state.timestep_modulated_cache.abs().mean()).item()
rescaled_l1_diff = self.rescale_fn(normalized_l1_diff)
self.denoiser_state.accumulated_l1_difference += rescaled_l1_diff
if self.denoiser_state.accumulated_l1_difference >= self.l1_threshold:
self.denoiser_state.should_skip_blocks = True
self.denoiser_state.accumulated_l1_difference = 0.0
else:
self.denoiser_state.should_skip_blocks = False
self.denoiser_state.timestep_modulated_cache = timestep_modulated_noise
return output
class DenoiserStateBasedSkipLayerHook(ModelHook):
_is_stateful = False
def __init__(self, denoiser_state: TeaCacheDenoiserState) -> None:
self.denoiser_state = denoiser_state
def new_forward(self, module, *args, **kwargs):
args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
if not self.denoiser_state.should_skip_blocks:
output = module._old_forward(*args, **kwargs)
else:
# Diffusers models either expect one output (hidden_states) or a tuple of two outputs (hidden_states, encoder_hidden_states).
# Returning a tuple of None values handles both cases. This is okay because these values are never used downstream
# when self.denoiser_state.should_skip_blocks is True.
output = (None, None)
return module._diffusers_hook.post_forward(module, output)
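A usage sketch for the second file: `apply_teacache` is meant to be called on a pipeline before inference so the caching hooks get attached to its denoiser. The snippet below assumes a Flux pipeline; the import path is a placeholder (the diff does not show where this file lives), and since the skip path above still returns placeholder outputs, this illustrates the intended API rather than a working recipe:

import torch

from diffusers import FluxPipeline
# Placeholder import path for the new module introduced in this diff.
from diffusers.pipelines.teacache import TeaCacheConfig, apply_teacache

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

# With an empty config, the Flux defaults from this file are used: l1_threshold=0.25 (~1.5x speedup),
# the "transformer_blocks.0.norm1" probe layer, and the Flux polynomial coefficients.
apply_teacache(pipe, TeaCacheConfig())

image = pipe("a photo of a cat", num_inference_steps=28).images[0]

At each denoising step, the hook on the probed norm layer computes the relative L1 change of its output versus the previous step, rescales it with `np.poly1d(coefficients)`, and accumulates the result; once the accumulated value reaches `l1_threshold`, `should_skip_blocks` is set and the matched transformer blocks are short-circuited for that step.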