Compare commits

...

21 Commits

Author SHA1 Message Date
Aryan
a1da7752e5 Merge branch 'main' into feature/guiders 2025-04-05 10:55:02 +02:00
Aryan
b30cf5d452 spatio temporal guidance 2025-04-05 10:39:46 +02:00
Aryan
357f4f056b update 2025-04-04 22:44:02 +02:00
Aryan
53b6b9fcb6 perturbed attention guidance 2025-04-04 04:16:46 +02:00
Aryan
46643564a3 refactor 2025-04-04 01:41:34 +02:00
Aryan
77324c40c4 adaptive projected guidance 2025-04-03 05:01:54 +02:00
Aryan
05d74ef3e7 cfg zero star 2025-04-03 04:21:24 +02:00
Aryan
9997c223a8 more slg improvements 2025-04-03 03:50:30 +02:00
Aryan
d91d10737a update slg docstring 2025-04-03 03:43:10 +02:00
Aryan
5ac7f360b0 skip layer guidance 2025-04-03 03:26:55 +02:00
Aryan
594e8d663f classifier-free guidance 2025-04-03 00:13:15 +02:00
Aryan
c76e1cc17e update 2025-04-02 21:52:33 +02:00
Aryan
315e357a18 Merge branch 'main' into integrations/first-block-cache-2 2025-04-02 01:21:22 +02:00
Aryan
1f33ca276d support flux, ltx i2v, ltx condition 2025-04-02 01:21:09 +02:00
Aryan
41b0c473d2 fix controlnet flux 2025-04-02 01:20:53 +02:00
Aryan
0e232ac8c0 fix hs residual bug for single return outputs; support ltx 2025-04-02 00:38:11 +02:00
Aryan
2557238b4d cache context for different batches of data 2025-04-01 19:40:23 +02:00
Aryan
d71fe55895 update 2025-04-01 17:06:45 +02:00
Aryan
7ab424a15a remove debug logs 2025-04-01 01:39:00 +02:00
Aryan
dd69b41834 modify flux single blocks to make compatible with cache techniques (without too much model-specific intrusion code) 2025-04-01 01:28:09 +02:00
Aryan
406b1062f8 update 2025-03-31 04:27:35 +02:00
34 changed files with 2290 additions and 106 deletions

View File

@@ -11,33 +11,6 @@ specific language governing permissions and limitations under the License. -->
# Caching methods
## Pyramid Attention Broadcast
[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. The attention states are not very different between successive inference steps. The most prominent difference is in the spatial attention blocks, not as much in the temporal attention blocks, and finally the least in the cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by the temporal and spatial attention blocks. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request.
```python
import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of
# `spatial_attention_timestep_skip_range[1]` will decrease the interval in which pyramid attention
# broadcast is active, leader to slower inference speeds. However, large intervals can lead to
# poorer quality of generated videos.
config = PyramidAttentionBroadcastConfig(
spatial_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(100, 800),
current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)
```
## Faster Cache
[FasterCache](https://huggingface.co/papers/2410.19355) from Zhengyao Lv, Chenyang Si, Junhao Song, Zhenyu Yang, Yu Qiao, Ziwei Liu, Kwan-Yee K. Wong.
@@ -65,18 +38,68 @@ config = FasterCacheConfig(
pipe.transformer.enable_cache(config)
```
## First Block Cache
[First Block Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching) is a method that builds upon the ideas of [TeaCache](https://huggingface.co/papers/2411.19108) to speed up inference in diffusion transformers. The generation quality is superior with greatly reduced inference time. This method always computes the output of the first transformer block and computes the differences between past and current outputs of the first transformer block. If the difference is smaller than a predefined threshold, the computation of remaining transformer blocks is skipped, and otherwise the computation is performed as usual.
```python
import torch
from diffusers import CogVideoXPipeline, FirstBlockCacheConfig
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Increasing the threshold may lead to faster inference speeds, but may also lead to poorer quality of generated videos.
# Smaller values between 0.02-2.0 are recommended based on the model being used. The default value is 0.05.
config = FirstBlockCacheConfig(threshold=0.07)
pipe.transformer.enable_cache(config)
```
## Pyramid Attention Broadcast
[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. The attention states are not very different between successive inference steps. The most prominent difference is in the spatial attention blocks, not as much in the temporal attention blocks, and finally the least in the cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by the temporal and spatial attention blocks. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request.
```python
import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of
# `spatial_attention_timestep_skip_range[1]` will decrease the interval in which pyramid attention
# broadcast is active, leader to slower inference speeds. However, large intervals can lead to
# poorer quality of generated videos.
config = PyramidAttentionBroadcastConfig(
spatial_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(100, 800),
current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)
```
### CacheMixin
[[autodoc]] CacheMixin
### PyramidAttentionBroadcastConfig
[[autodoc]] PyramidAttentionBroadcastConfig
[[autodoc]] apply_pyramid_attention_broadcast
### FasterCacheConfig
[[autodoc]] FasterCacheConfig
[[autodoc]] apply_faster_cache
### FirstBlockCacheConfig
[[autodoc]] FirstBlockCacheConfig
[[autodoc]] apply_first_block_cache
### PyramidAttentionBroadcastConfig
[[autodoc]] PyramidAttentionBroadcastConfig
[[autodoc]] apply_pyramid_attention_broadcast

View File

@@ -33,6 +33,7 @@ from .utils import (
_import_structure = {
"configuration_utils": ["ConfigMixin"],
"guiders": [],
"hooks": [],
"loaders": ["FromOriginalModelMixin"],
"models": [],
@@ -129,12 +130,25 @@ except OptionalDependencyNotAvailable:
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
else:
_import_structure["guiders"].extend(
[
"AdaptiveProjectedGuidance",
"ClassifierFreeGuidance",
"ClassifierFreeZeroStarGuidance",
"PerturbedAttentionGuidance",
"SkipLayerGuidance",
]
)
_import_structure["hooks"].extend(
[
"FasterCacheConfig",
"FirstBlockCacheConfig",
"HookRegistry",
"LayerSkipConfig",
"PyramidAttentionBroadcastConfig",
"apply_faster_cache",
"apply_first_block_cache",
"apply_layer_skip",
"apply_pyramid_attention_broadcast",
]
)
@@ -708,11 +722,22 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from .utils.dummy_pt_objects import * # noqa F403
else:
from .guiders import (
AdaptiveProjectedGuidance,
ClassifierFreeGuidance,
ClassifierFreeZeroStarGuidance,
PerturbedAttentionGuidance,
SkipLayerGuidance,
)
from .hooks import (
FasterCacheConfig,
FirstBlockCacheConfig,
HookRegistry,
LayerSkipConfig,
PyramidAttentionBroadcastConfig,
apply_faster_cache,
apply_first_block_cache,
apply_layer_skip,
apply_pyramid_attention_broadcast,
)
from .models import (

View File

@@ -0,0 +1,24 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import is_torch_available
if is_torch_available():
from .adaptive_projected_guidance import AdaptiveProjectedGuidance
from .classifier_free_guidance import ClassifierFreeGuidance
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
from .guider_utils import GuidanceMixin, _raise_guidance_deprecation_warning
from .perturbed_attention_guidance import PerturbedAttentionGuidance
from .skip_layer_guidance import SkipLayerGuidance

View File

@@ -0,0 +1,145 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import torch
from .guider_utils import GuidanceMixin, rescale_noise_cfg
class AdaptiveProjectedGuidance(GuidanceMixin):
"""
Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
adaptive_projected_guidance_momentum (`float`, defaults to `None`):
The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
def __init__(
self,
guidance_scale: float = 7.5,
adaptive_projected_guidance_momentum: Optional[float] = None,
adaptive_projected_guidance_rescale: float = 15.0,
eta: float = 1.0,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
):
super().__init__()
self.guidance_scale = guidance_scale
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
self.eta = eta
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
self.momentum_buffer = None
def prepare_inputs(self, *args):
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
return super().prepare_inputs(*args)
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if self._is_cfg_enabled():
pred = pred_cond
else:
pred = normalized_guidance(
pred_cond,
pred_uncond,
self.guidance_scale,
self.momentum_buffer,
self.eta,
self.adaptive_projected_guidance_rescale,
self.use_original_formulation,
)
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if self.use_original_formulation:
return not math.isclose(self.guidance_scale, 0.0)
else:
return not math.isclose(self.guidance_scale, 1.0)
class MomentumBuffer:
def __init__(self, momentum: float):
self.momentum = momentum
self.running_average = 0
def update(self, update_value: torch.Tensor):
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average
def normalized_guidance(
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
guidance_scale: float,
momentum_buffer: Optional[MomentumBuffer] = None,
eta: float = 1.0,
norm_threshold: float = 0.0,
use_original_formulation: bool = False,
):
diff = pred_cond - pred_uncond
dim = [-i for i in range(1, len(diff.shape))]
if momentum_buffer is not None:
momentum_buffer.update(diff)
diff = momentum_buffer.running_average
if norm_threshold > 0:
ones = torch.ones_like(diff)
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
diff = diff * scale_factor
v0, v1 = diff.double(), pred_cond.double()
v1 = torch.nn.functional.normalize(v1, dim=dim)
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
normalized_update = diff_orthogonal + eta * diff_parallel
pred = pred_cond if use_original_formulation else pred_uncond
pred = pred + (guidance_scale - 1) * normalized_update
return pred

View File

@@ -0,0 +1,98 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import torch
from .guider_utils import GuidanceMixin, rescale_noise_cfg
class ClassifierFreeGuidance(GuidanceMixin):
"""
Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598
CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by
jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
inference. This allows the model to tradeoff between generation quality and sample diversity.
The original paper proposes scaling and shifting the conditional distribution based on the difference between
conditional and unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]
Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen
paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in
theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
The intution behind the original formulation can be thought of as moving the conditional distribution estimates
further away from the unconditional distribution estimates, while the diffusers-native implementation can be
thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of
the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.)
The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
def __init__(
self, guidance_scale: float = 7.5, guidance_rescale: float = 0.0, use_original_formulation: bool = False
):
super().__init__()
self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if not self._is_cfg_enabled():
pred = pred_cond
else:
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if self.use_original_formulation:
return not math.isclose(self.guidance_scale, 0.0)
else:
return not math.isclose(self.guidance_scale, 1.0)

View File

@@ -0,0 +1,110 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import torch
from .guider_utils import GuidanceMixin, rescale_noise_cfg
class ClassifierFreeZeroStarGuidance(GuidanceMixin):
"""
Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886
This is an implementation of the Classifier-Free Zero* guidance technique, which is a variant of classifier-free
guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion
process, and also introduces an optimal rescaling factor for the noise predictions, which can help in improving the
quality of generated images.
The authors of the paper suggest setting zero initialization in the first 4% of the inference steps.
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
zero_init_steps (`int`, defaults to `1`):
The number of inference steps for which the noise predictions are zeroed out (see Section 4.2).
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
def __init__(
self,
guidance_scale: float = 7.5,
zero_init_steps: int = 1,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
):
super().__init__()
self.guidance_scale = guidance_scale
self.zero_init_steps = zero_init_steps
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if self._step < self.zero_init_steps:
pred = torch.zeros_like(pred_cond)
elif self._is_cfg_enabled():
pred = pred_cond
else:
shift = pred_cond - pred_uncond
pred_cond_flat = pred_cond.flatten(1)
pred_uncond_flat = pred_uncond.flatten(1)
alpha = cfg_zero_star_scale(pred_cond_flat, pred_uncond_flat)
alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1))
pred_uncond = pred_uncond * alpha
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if self.use_original_formulation:
return not math.isclose(self.guidance_scale, 0.0)
else:
return not math.isclose(self.guidance_scale, 1.0)
def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
cond = cond.float()
uncond = uncond.float()
dot_product = torch.sum(cond * uncond, dim=1, keepdim=True)
squared_norm = torch.sum(uncond**2, dim=1, keepdim=True) + eps
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
scale = dot_product / squared_norm
return scale.type_as(cond)

View File

@@ -0,0 +1,213 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
from ..utils import deprecate, get_logger
if TYPE_CHECKING:
from ..models.attention_processor import AttentionProcessor
logger = get_logger(__name__) # pylint: disable=invalid-name
class GuidanceMixin:
r"""Base mixin class providing the skeleton for implementing guidance techniques."""
_input_predictions = None
def __init__(self):
self._step: int = None
self._num_inference_steps: int = None
self._timestep: torch.LongTensor = None
self._preds: Dict[str, torch.Tensor] = {}
self._num_outputs_prepared: int = 0
if self._input_predictions is None or not isinstance(self._input_predictions, list):
raise ValueError(
"`_input_predictions` must be a list of required prediction names for the guidance technique."
)
def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
self._step = step
self._num_inference_steps = num_inference_steps
self._timestep = timestep
self._preds = {}
self._num_outputs_prepared = 0
def prepare_models(self, denoiser: torch.nn.Module) -> None:
pass
def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
num_conditions = self.num_conditions
list_of_inputs = []
for arg in args:
if isinstance(arg, torch.Tensor):
list_of_inputs.append([arg] * num_conditions)
elif isinstance(arg, (tuple, list)):
if len(arg) != 2:
raise ValueError(
f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 "
f"with the first element being the conditional input and the second element being the unconditional input or None."
)
if arg[1] is None:
# Only conditioning inputs for all batches
list_of_inputs.append([arg[0]] * num_conditions)
else:
# Alternating conditional and unconditional inputs as batches
inputs = [arg[i % 2] for i in range(num_conditions)]
list_of_inputs.append(inputs)
else:
raise ValueError(
f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list."
)
return tuple(list_of_inputs)
def prepare_outputs(self, pred: torch.Tensor) -> None:
self._num_outputs_prepared += 1
if self._num_outputs_prepared > self.num_conditions:
raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
key = self._input_predictions[self._num_outputs_prepared - 1]
self._preds[key] = pred
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
pass
def __call__(self, **kwargs) -> Any:
if len(kwargs) != self.num_conditions:
raise ValueError(
f"Expected {self.num_conditions} arguments, but got {len(kwargs)}. Please provide the correct number of arguments."
)
return self.forward(**kwargs)
def forward(self, *args, **kwargs) -> Any:
raise NotImplementedError("GuidanceMixin::forward must be implemented in subclasses.")
@property
def num_conditions(self) -> int:
raise NotImplementedError("GuidanceMixin::num_conditions must be implemented in subclasses.")
@property
def outputs(self) -> Dict[str, torch.Tensor]:
return self._preds
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Args:
noise_cfg (`torch.Tensor`):
The predicted noise tensor for the guided diffusion process.
noise_pred_text (`torch.Tensor`):
The predicted noise tensor for the text-guided diffusion process.
guidance_rescale (`float`, *optional*, defaults to 0.0):
A rescale factor applied to the noise predictions.
Returns:
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg
def _replace_attention_processors(
module: torch.nn.Module,
pag_applied_layers: Optional[Union[str, List[str]]] = None,
skip_context_attention: bool = False,
processors: Optional[List[Tuple[torch.nn.Module, "AttentionProcessor"]]] = None,
metadata_name: Optional[str] = None,
) -> Optional[List[Tuple[torch.nn.Module, "AttentionProcessor"]]]:
if processors is not None and metadata_name is not None:
raise ValueError("Cannot pass both `processors` and `metadata_name` at the same time.")
if metadata_name is not None:
if isinstance(pag_applied_layers, str):
pag_applied_layers = [pag_applied_layers]
return _replace_layers_with_guidance_processors(
module, pag_applied_layers, skip_context_attention, metadata_name
)
if processors is not None:
_replace_layers_with_existing_processors(processors)
def _replace_layers_with_guidance_processors(
module: torch.nn.Module,
pag_applied_layers: List[str],
skip_context_attention: bool,
metadata_name: str,
) -> List[Tuple[torch.nn.Module, "AttentionProcessor"]]:
from ..hooks._common import _ATTENTION_CLASSES
from ..hooks._helpers import GuidanceMetadataRegistry
processors = []
for name, submodule in module.named_modules():
if (
(not isinstance(submodule, _ATTENTION_CLASSES))
or (getattr(submodule, "processor", None) is None)
or not (
any(
re.search(pag_layer, name) is not None and not _is_fake_integral_match(pag_layer, name)
for pag_layer in pag_applied_layers
)
)
):
continue
old_attention_processor = submodule.processor
metadata = GuidanceMetadataRegistry.get(old_attention_processor.__class__)
new_attention_processor_cls = getattr(metadata, metadata_name)
new_attention_processor = new_attention_processor_cls()
# !!! dunder methods cannot be replaced on instances !!!
# if "skip_context_attention" in inspect.signature(new_attention_processor.__call__).parameters:
# new_attention_processor.__call__ = partial(
# new_attention_processor.__call__, skip_context_attention=skip_context_attention
# )
submodule.processor = new_attention_processor
processors.append((submodule, old_attention_processor))
return processors
def _replace_layers_with_existing_processors(processors: List[Tuple[torch.nn.Module, "AttentionProcessor"]]) -> None:
for module, proc in processors:
module.processor = proc
def _is_fake_integral_match(layer_id, name):
layer_id = layer_id.split(".")[-1]
name = name.split(".")[-1]
return layer_id.isnumeric() and name.isnumeric() and layer_id == name
def _raise_guidance_deprecation_warning(
*,
guidance_scale: Optional[float] = None,
guidance_rescale: Optional[float] = None,
) -> None:
if guidance_scale is not None:
msg = "The `guidance_scale` argument is deprecated and will be removed in version 1.0.0. Please pass a `GuidanceMixin` object for the `guidance` argument instead."
deprecate("guidance_scale", "1.0.0", msg, standard_warn=False)
if guidance_rescale is not None:
msg = "The `guidance_rescale` argument is deprecated and will be removed in version 1.0.0. Please pass a `GuidanceMixin` object for the `guidance` argument instead."
deprecate("guidance_rescale", "1.0.0", msg, standard_warn=False)

View File

@@ -0,0 +1,180 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union
import torch
from .guider_utils import GuidanceMixin, _replace_attention_processors, rescale_noise_cfg
class PerturbedAttentionGuidance(GuidanceMixin):
"""
Perturbed Attention Guidance (PAB): https://huggingface.co/papers/2403.17377
Args:
pag_applied_layers (`str` or `List[str]`):
The name of the attention layers where Perturbed Attention Guidance is applied. This can be a single layer
name or a list of layer names. The names should either be FQNs (fully qualified names) to each attention
layer or a regex pattern that matches the FQNs of the attention layers. For example, if you want to apply
PAG to transformer blocks 10 and 20, you can set this to `["transformer_blocks.10",
"transformer_blocks.20"]`, or `"transformer_blocks.(10|20)"`.
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
pag_scale (`float`, defaults to `3.0`):
The scale parameter for perturbed attention guidance.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
"""
_input_predictions = ["pred_cond", "pred_uncond", "pred_perturbed"]
def __init__(
self,
pag_applied_layers: Union[str, List[str]],
guidance_scale: float = 7.5,
pag_scale: float = 3.0,
skip_context_attention: bool = False,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
):
super().__init__()
self.pag_applied_layers = pag_applied_layers
self.guidance_scale = guidance_scale
self.pag_scale = pag_scale
self.skip_context_attention = skip_context_attention
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
self._is_pag_batch = False
self._original_processors = None
self._denoiser = None
def prepare_models(self, denoiser: torch.nn.Module):
self._denoiser = denoiser
def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
num_conditions = self.num_conditions
list_of_inputs = []
for arg in args:
if isinstance(arg, torch.Tensor):
list_of_inputs.append([arg] * num_conditions)
elif isinstance(arg, (tuple, list)):
if len(arg) != 2:
raise ValueError(
f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 "
f"with the first element being the conditional input and the second element being the unconditional input or None."
)
if arg[1] is None:
# Only conditioning inputs for all batches
list_of_inputs.append([arg[0]] * num_conditions)
else:
list_of_inputs.append([arg[0], arg[1], arg[0]])
else:
raise ValueError(
f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list."
)
return tuple(list_of_inputs)
def prepare_outputs(self, pred: torch.Tensor) -> None:
self._num_outputs_prepared += 1
if self._num_outputs_prepared > self.num_conditions:
raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
key = self._input_predictions[self._num_outputs_prepared - 1]
if not self._is_cfg_enabled() and self._is_pag_enabled():
# If we're predicting pred_cond and pred_perturbed only, we need to set the key to pred_perturbed
# to avoid writing into pred_uncond which is not used
if self._num_outputs_prepared == 2:
key = "pred_perturbed"
self._preds[key] = pred
# Restore the original attention processors if previously replaced
if self._is_pag_batch:
_replace_attention_processors(self._denoiser, processors=self._original_processors)
self._is_pag_batch = False
self._original_processors = None
# Prepare denoiser for perturbed attention prediction if needed
if self._is_pag_enabled():
should_register_pag = (self._is_cfg_enabled() and self._num_outputs_prepared == 2) or (
not self._is_cfg_enabled() and self._num_outputs_prepared == 1
)
if should_register_pag:
self._is_pag_batch = True
self._original_processors = _replace_attention_processors(
self._denoiser,
self.pag_applied_layers,
skip_context_attention=self.skip_context_attention,
metadata_name="perturbed_attention_guidance_processor_cls",
)
def cleanup_models(self, denoiser: torch.nn.Module):
self._denoiser = None
def forward(
self,
pred_cond: torch.Tensor,
pred_uncond: Optional[torch.Tensor] = None,
pred_perturbed: Optional[torch.Tensor] = None,
) -> torch.Tensor:
pred = None
if not self._is_cfg_enabled() and not self._is_pag_enabled():
pred = pred_cond
elif not self._is_cfg_enabled():
shift = pred_cond - pred_perturbed
pred = pred_cond + self.pag_scale * shift
elif not self._is_pag_enabled():
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
else:
shift = pred_cond - pred_uncond
shift_perturbed = pred_cond - pred_perturbed
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift + self.pag_scale * shift_perturbed
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
if self._is_pag_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if self.use_original_formulation:
return not math.isclose(self.guidance_scale, 0.0)
else:
return not math.isclose(self.guidance_scale, 1.0)
def _is_pag_enabled(self) -> bool:
is_zero = math.isclose(self.pag_scale, 0.0)
return not is_zero

View File

@@ -0,0 +1,235 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union
import torch
from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import GuidanceMixin, rescale_noise_cfg
class SkipLayerGuidance(GuidanceMixin):
"""
Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5 Spatio-Temporal Guidance (STG):
https://huggingface.co/papers/2411.18664
SLG was introduced by StabilityAI for improving structure and anotomy coherence in generated images. It works by
skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional
batch of data, apart from the conditional and unconditional batches already used in CFG
([~guiders.classifier_free_guidance.ClassifierFreeGuidance]), and then scaling and shifting the CFG predictions
based on the difference between conditional without skipping and conditional with skipping predictions.
The intution behind SLG can be thought of as moving the CFG predicted distribution estimates further away from
worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse
version of the model for the conditional prediction).
STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving
generation quality in video diffusion models.
Additional reading:
- [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)
The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are
defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium.
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
skip_layer_guidance_scale (`float`, defaults to `2.8`):
The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher
values, but it may also lead to overexposure and saturation.
skip_layer_guidance_start (`float`, defaults to `0.01`):
The fraction of the total number of denoising steps after which skip layer guidance starts.
skip_layer_guidance_stop (`float`, defaults to `0.2`):
The fraction of the total number of denoising steps after which skip layer guidance stops.
skip_layer_guidance_layers (`int` or `List[int]`, *optional*):
The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion
3.5 Medium.
skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
`LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
"""
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
def __init__(
self,
guidance_scale: float = 7.5,
skip_layer_guidance_scale: float = 2.8,
skip_layer_guidance_start: float = 0.01,
skip_layer_guidance_stop: float = 0.2,
skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None,
skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
):
super().__init__()
self.guidance_scale = guidance_scale
self.skip_layer_guidance_scale = skip_layer_guidance_scale
self.skip_layer_guidance_start = skip_layer_guidance_start
self.skip_layer_guidance_stop = skip_layer_guidance_stop
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
if not (0.0 <= skip_layer_guidance_start < 1.0):
raise ValueError(
f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}."
)
if not (0.0 < skip_layer_guidance_stop <= 1.0):
raise ValueError(
f"Expected `skip_layer_guidance_stop` to be between 0.0 and 1.0, but got {skip_layer_guidance_stop}."
)
if skip_layer_guidance_layers is None and skip_layer_config is None:
raise ValueError(
"Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance."
)
if skip_layer_guidance_layers is not None and skip_layer_config is not None:
raise ValueError("Only one of `skip_layer_guidance_layers` or `skip_layer_config` can be provided.")
if skip_layer_guidance_layers is not None:
if isinstance(skip_layer_guidance_layers, int):
skip_layer_guidance_layers = [skip_layer_guidance_layers]
if not isinstance(skip_layer_guidance_layers, list):
raise ValueError(
f"Expected `skip_layer_guidance_layers` to be an int or a list of ints, but got {type(skip_layer_guidance_layers)}."
)
skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers]
if isinstance(skip_layer_config, LayerSkipConfig):
skip_layer_config = [skip_layer_config]
if not isinstance(skip_layer_config, list):
raise ValueError(
f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}."
)
self.skip_layer_config = skip_layer_config
self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
def prepare_models(self, denoiser: torch.nn.Module):
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
# Register the hooks for layer skipping if the step is within the specified range
if skip_start_step < self._step < skip_stop_step:
for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
_apply_layer_skip_hook(denoiser, config, name=name)
def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
num_conditions = self.num_conditions
list_of_inputs = []
for arg in args:
if isinstance(arg, torch.Tensor):
list_of_inputs.append([arg] * num_conditions)
elif isinstance(arg, (tuple, list)):
if len(arg) != 2:
raise ValueError(
f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 "
f"with the first element being the conditional input and the second element being the unconditional input or None."
)
if arg[1] is None:
# Only conditioning inputs for all batches
list_of_inputs.append([arg[0]] * num_conditions)
else:
list_of_inputs.append([arg[0], arg[1], arg[0]])
else:
raise ValueError(
f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list."
)
return tuple(list_of_inputs)
def prepare_outputs(self, pred: torch.Tensor) -> None:
self._num_outputs_prepared += 1
if self._num_outputs_prepared > self.num_conditions:
raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
key = self._input_predictions[self._num_outputs_prepared - 1]
if not self._is_cfg_enabled() and self._is_slg_enabled():
# If we're predicting pred_cond and pred_cond_skip only, we need to set the key to pred_cond_skip
# to avoid writing into pred_uncond which is not used
if self._num_outputs_prepared == 2:
key = "pred_cond_skip"
self._preds[key] = pred
def cleanup_models(self, denoiser: torch.nn.Module):
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
# Remove the hooks after inference
for hook_name in self._skip_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
def forward(
self,
pred_cond: torch.Tensor,
pred_uncond: Optional[torch.Tensor] = None,
pred_cond_skip: Optional[torch.Tensor] = None,
) -> torch.Tensor:
pred = None
if not self._is_cfg_enabled() and not self._is_slg_enabled():
pred = pred_cond
elif not self._is_cfg_enabled():
shift = pred_cond - pred_cond_skip
pred = pred_cond if self.use_original_formulation else pred_cond_skip
pred = pred + self.skip_layer_guidance_scale * shift
elif not self._is_slg_enabled():
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
else:
shift = pred_cond - pred_uncond
shift_skip = pred_cond - pred_cond_skip
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
if self._is_slg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if self.use_original_formulation:
return not math.isclose(self.guidance_scale, 0.0)
else:
return not math.isclose(self.guidance_scale, 1.0)
def _is_slg_enabled(self) -> bool:
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
is_within_range = skip_start_step < self._step < skip_stop_step
is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
return is_within_range and not is_zero

View File

@@ -1,9 +1,25 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import is_torch_available
if is_torch_available():
from .faster_cache import FasterCacheConfig, apply_faster_cache
from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
from .group_offloading import apply_group_offloading
from .hooks import HookRegistry, ModelHook
from .layer_skip import LayerSkipConfig, apply_layer_skip
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast

View File

@@ -0,0 +1,32 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..models.attention import FeedForward, LuminaFeedForward
from ..models.attention_processor import Attention, MochiAttention
_ATTENTION_CLASSES = (Attention, MochiAttention)
_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
{
*_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
*_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
*_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
}
)

View File

@@ -0,0 +1,276 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Callable, Type
from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
from ..models.transformers.transformer_cogview4 import (
CogView4AttnProcessor,
CogView4PAGAttnProcessor,
CogView4TransformerBlock,
)
from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
from ..models.transformers.transformer_hunyuan_video import (
HunyuanVideoSingleTransformerBlock,
HunyuanVideoTokenReplaceSingleTransformerBlock,
HunyuanVideoTokenReplaceTransformerBlock,
HunyuanVideoTransformerBlock,
)
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
from ..models.transformers.transformer_mochi import MochiTransformerBlock
from ..models.transformers.transformer_wan import WanTransformerBlock
@dataclass
class AttentionProcessorMetadata:
skip_processor_output_fn: Callable[[Any], Any]
@dataclass
class GuidanceMetadata:
perturbed_attention_guidance_processor_cls: Type = None
@dataclass
class TransformerBlockMetadata:
skip_block_output_fn: Callable[[Any], Any]
return_hidden_states_index: int = None
return_encoder_hidden_states_index: int = None
class AttentionProcessorRegistry:
_registry = {}
@classmethod
def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
cls._registry[model_class] = metadata
@classmethod
def get(cls, model_class: Type) -> AttentionProcessorMetadata:
if model_class not in cls._registry:
raise ValueError(f"Model class {model_class} not registered.")
return cls._registry[model_class]
class GuidanceMetadataRegistry:
_registry = {}
@classmethod
def register(cls, model_class: Type, metadata: GuidanceMetadata):
cls._registry[model_class] = metadata
@classmethod
def get(cls, model_class: Type) -> GuidanceMetadata:
if model_class not in cls._registry:
raise ValueError(f"Model class {model_class} not registered.")
return cls._registry[model_class]
class TransformerBlockRegistry:
_registry = {}
@classmethod
def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
cls._registry[model_class] = metadata
@classmethod
def get(cls, model_class: Type) -> TransformerBlockMetadata:
if model_class not in cls._registry:
raise ValueError(f"Model class {model_class} not registered.")
return cls._registry[model_class]
def _register_attention_processors_metadata():
# CogView4
AttentionProcessorRegistry.register(
model_class=CogView4AttnProcessor,
metadata=AttentionProcessorMetadata(
skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor,
),
)
def _register_guidance_metadata():
# CogView4
GuidanceMetadataRegistry.register(
model_class=CogView4AttnProcessor,
metadata=GuidanceMetadata(
perturbed_attention_guidance_processor_cls=CogView4PAGAttnProcessor,
),
)
def _register_transformer_blocks_metadata():
# CogVideoX
TransformerBlockRegistry.register(
model_class=CogVideoXBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_CogVideoXBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# CogView4
TransformerBlockRegistry.register(
model_class=CogView4TransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_CogView4TransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# Flux
TransformerBlockRegistry.register(
model_class=FluxTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_FluxTransformerBlock,
return_hidden_states_index=1,
return_encoder_hidden_states_index=0,
),
)
TransformerBlockRegistry.register(
model_class=FluxSingleTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_FluxSingleTransformerBlock,
return_hidden_states_index=1,
return_encoder_hidden_states_index=0,
),
)
# HunyuanVideo
TransformerBlockRegistry.register(
model_class=HunyuanVideoTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanVideoSingleTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoSingleTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanVideoTokenReplaceTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanVideoTokenReplaceSingleTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# LTXVideo
TransformerBlockRegistry.register(
model_class=LTXVideoTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_LTXVideoTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
),
)
# Mochi
TransformerBlockRegistry.register(
model_class=MochiTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_MochiTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# Wan
TransformerBlockRegistry.register(
model_class=WanTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_WanTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
),
)
# fmt: off
def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
if encoder_hidden_states is None and len(args) > 1:
encoder_hidden_states = args[1]
return hidden_states, encoder_hidden_states
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
return hidden_states
def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
if encoder_hidden_states is None and len(args) > 1:
encoder_hidden_states = args[1]
return hidden_states, encoder_hidden_states
def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
if encoder_hidden_states is None and len(args) > 1:
encoder_hidden_states = args[1]
return encoder_hidden_states, hidden_states
_skip_block_output_fn_CogVideoXBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_CogView4TransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_FluxTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states
_skip_block_output_fn_FluxSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states
_skip_block_output_fn_HunyuanVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_HunyuanVideoSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_LTXVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
_skip_block_output_fn_MochiTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_WanTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
# fmt: on
_register_attention_processors_metadata()
_register_guidance_metadata()
_register_transformer_blocks_metadata()

View File

@@ -0,0 +1,223 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Tuple, Union
import torch
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
from ._helpers import TransformerBlockRegistry
from .hooks import BaseMarkedState, HookRegistry, ModelHook
logger = get_logger(__name__) # pylint: disable=invalid-name
_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook"
_FBC_BLOCK_HOOK = "fbc_block_hook"
@dataclass
class FirstBlockCacheConfig:
r"""
Configuration for [First Block
Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching).
Args:
threshold (`float`, defaults to `0.05`):
The threshold to determine whether or not a forward pass through all layers of the model is required. A
higher threshold usually results in lower number of forward passes and faster inference, but might lead to
poorer generation quality. A lower threshold may not result in significant generation speedup. The
threshold is compared against the absmean difference of the residuals between the current and cached
outputs from the first transformer block. If the difference is below the threshold, the forward pass is
skipped.
"""
threshold: float = 0.05
class FBCSharedBlockState(BaseMarkedState):
def __init__(self) -> None:
super().__init__()
self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
self.head_block_residual: torch.Tensor = None
self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
self.should_compute: bool = True
def reset(self):
self.tail_block_residuals = None
self.should_compute = True
class FBCHeadBlockHook(ModelHook):
_is_stateful = True
def __init__(self, shared_state: FBCSharedBlockState, threshold: float):
self.shared_state = shared_state
self.threshold = threshold
self._metadata = None
def initialize_hook(self, module):
self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
outputs_if_skipped = self._metadata.skip_block_output_fn(module, *args, **kwargs)
original_hs = outputs_if_skipped[self._metadata.return_hidden_states_index]
output = self.fn_ref.original_forward(*args, **kwargs)
is_output_tuple = isinstance(output, tuple)
if is_output_tuple:
hs_residual = output[self._metadata.return_hidden_states_index] - original_hs
else:
hs_residual = output - original_hs
hs, ehs = None, None
should_compute = self._should_compute_remaining_blocks(hs_residual)
self.shared_state.should_compute = should_compute
if not should_compute:
# Apply caching
if is_output_tuple:
hs = self.shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index]
else:
hs = self.shared_state.tail_block_residuals[0] + output
if self._metadata.return_encoder_hidden_states_index is not None:
assert is_output_tuple
ehs = (
self.shared_state.tail_block_residuals[1]
+ output[self._metadata.return_encoder_hidden_states_index]
)
if is_output_tuple:
return_output = [None] * len(output)
return_output[self._metadata.return_hidden_states_index] = hs
return_output[self._metadata.return_encoder_hidden_states_index] = ehs
return_output = tuple(return_output)
else:
return_output = hs
output = return_output
else:
if is_output_tuple:
head_block_output = [None] * len(output)
head_block_output[0] = output[self._metadata.return_hidden_states_index]
head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index]
else:
head_block_output = output
self.shared_state.head_block_output = head_block_output
self.shared_state.head_block_residual = hs_residual
return output
def reset_state(self, module):
self.shared_state.reset()
return module
@torch.compiler.disable
def _should_compute_remaining_blocks(self, hs_residual: torch.Tensor) -> bool:
if self.shared_state.head_block_residual is None:
return True
prev_hs_residual = self.shared_state.head_block_residual
hs_absmean = (hs_residual - prev_hs_residual).abs().mean()
prev_hs_mean = prev_hs_residual.abs().mean()
diff = (hs_absmean / prev_hs_mean).item()
return diff > self.threshold
class FBCBlockHook(ModelHook):
def __init__(self, shared_state: FBCSharedBlockState, is_tail: bool = False):
super().__init__()
self.shared_state = shared_state
self.is_tail = is_tail
self._metadata = None
def initialize_hook(self, module):
self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
outputs_if_skipped = self._metadata.skip_block_output_fn(module, *args, **kwargs)
if not isinstance(outputs_if_skipped, tuple):
outputs_if_skipped = (outputs_if_skipped,)
original_hs = outputs_if_skipped[self._metadata.return_hidden_states_index]
original_ehs = None
if self._metadata.return_encoder_hidden_states_index is not None:
original_ehs = outputs_if_skipped[self._metadata.return_encoder_hidden_states_index]
if self.shared_state.should_compute:
output = self.fn_ref.original_forward(*args, **kwargs)
if self.is_tail:
hs_residual, ehs_residual = None, None
if isinstance(output, tuple):
hs_residual = (
output[self._metadata.return_hidden_states_index] - self.shared_state.head_block_output[0]
)
ehs_residual = (
output[self._metadata.return_encoder_hidden_states_index]
- self.shared_state.head_block_output[1]
)
else:
hs_residual = output - self.shared_state.head_block_output
self.shared_state.tail_block_residuals = (hs_residual, ehs_residual)
return output
output_count = len(outputs_if_skipped)
if output_count == 1:
return_output = original_hs
else:
return_output = [None] * output_count
return_output[self._metadata.return_hidden_states_index] = original_hs
return_output[self._metadata.return_encoder_hidden_states_index] = original_ehs
return return_output
def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
shared_state = FBCSharedBlockState()
remaining_blocks = []
for name, submodule in module.named_children():
if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
continue
for index, block in enumerate(submodule):
remaining_blocks.append((f"{name}.{index}", block))
head_block_name, head_block = remaining_blocks.pop(0)
tail_block_name, tail_block = remaining_blocks.pop(-1)
logger.debug(f"Apply FBCHeadBlockHook to '{head_block_name}'")
apply_fbc_head_block_hook(head_block, shared_state, config.threshold)
for name, block in remaining_blocks:
logger.debug(f"Apply FBCBlockHook to '{name}'")
apply_fbc_block_hook(block, shared_state)
logger.debug(f"Apply FBCBlockHook to tail block '{tail_block_name}'")
apply_fbc_block_hook(tail_block, shared_state, is_tail=True)
def apply_fbc_head_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, threshold: float) -> None:
registry = HookRegistry.check_if_exists_or_initialize(block)
hook = FBCHeadBlockHook(state, threshold)
registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK)
def apply_fbc_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, is_tail: bool = False) -> None:
registry = HookRegistry.check_if_exists_or_initialize(block)
hook = FBCBlockHook(state, is_tail)
registry.register_hook(hook, _FBC_BLOCK_HOOK)

View File

@@ -18,11 +18,76 @@ from typing import Any, Dict, Optional, Tuple
import torch
from ..utils.logging import get_logger
from ..utils.torch_utils import unwrap_module
logger = get_logger(__name__) # pylint: disable=invalid-name
class BaseState:
def reset(self, *args, **kwargs) -> None:
raise NotImplementedError(
"BaseState::reset is not implemented. Please implement this method in the derived class."
)
class BaseMarkedState(BaseState):
def __init__(self, init_args=None, init_kwargs=None):
super().__init__()
self._init_args = init_args if init_args is not None else ()
self._init_kwargs = init_kwargs if init_kwargs is not None else {}
self._mark_name = None
self._state_cache = {}
def get_current_state(self) -> "BaseMarkedState":
if self._mark_name is None:
# If no mark name is set, simply return a dummy object since we're not going to be using it
return self
if self._mark_name not in self._state_cache.keys():
self._state_cache[self._mark_name] = self.__class__(*self._init_args, **self._init_kwargs)
return self._state_cache[self._mark_name]
def mark_state(self, name: str) -> None:
self._mark_name = name
def reset(self, *args, **kwargs) -> None:
for name, state in list(self._state_cache.items()):
state.reset(*args, **kwargs)
self._state_cache.pop(name)
self._mark_name = None
def __getattribute__(self, name):
if name in (
"get_current_state",
"mark_state",
"reset",
"_init_args",
"_init_kwargs",
"_mark_name",
"_state_cache",
) or _is_dunder_method(name):
return object.__getattribute__(self, name)
else:
current_state = BaseMarkedState.get_current_state(self)
return object.__getattribute__(current_state, name)
def __setattr__(self, name, value):
if name in (
"get_current_state",
"mark_state",
"reset",
"_init_args",
"_init_kwargs",
"_mark_name",
"_state_cache",
) or _is_dunder_method(name):
object.__setattr__(self, name, value)
else:
current_state = BaseMarkedState.get_current_state(self)
object.__setattr__(current_state, name, value)
class ModelHook:
r"""
A hook that contains callbacks to be executed just before and after the forward method of a model.
@@ -99,6 +164,14 @@ class ModelHook:
raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
return module
def _mark_state(self, module: torch.nn.Module, name: str) -> None:
# Iterate over all attributes of the hook to see if any of them have the type `BaseMarkedState`. If so, call `mark_state` on them.
for attr_name in dir(self):
attr = getattr(self, attr_name)
if isinstance(attr, BaseMarkedState):
attr.mark_state(name)
return module
class HookFunctionReference:
def __init__(self) -> None:
@@ -211,9 +284,10 @@ class HookRegistry:
hook.reset_state(self._module_ref)
if recurse:
for module_name, module in self._module_ref.named_modules():
for module_name, module in unwrap_module(self._module_ref).named_modules():
if module_name == "":
continue
module = unwrap_module(module)
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook.reset_stateful_hooks(recurse=False)
@@ -223,6 +297,19 @@ class HookRegistry:
module._diffusers_hook = cls(module)
return module._diffusers_hook
def _mark_state(self, name: str) -> None:
for hook_name in reversed(self._hook_order):
hook = self.hooks[hook_name]
if hook._is_stateful:
hook._mark_state(self._module_ref, name)
for module_name, module in unwrap_module(self._module_ref).named_modules():
if module_name == "":
continue
module = unwrap_module(module)
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook._mark_state(name)
def __repr__(self) -> str:
registry_repr = ""
for i, hook_name in enumerate(self._hook_order):
@@ -234,3 +321,7 @@ class HookRegistry:
if i < len(self._hook_order) - 1:
registry_repr += "\n"
return f"HookRegistry(\n{registry_repr}\n)"
def _is_dunder_method(name: str) -> bool:
return name.startswith("__") and name.endswith("__") and name in dir(object)

View File

@@ -0,0 +1,182 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Callable, List, Optional
import torch
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES
from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
from .hooks import HookRegistry, ModelHook
logger = get_logger(__name__) # pylint: disable=invalid-name
_LAYER_SKIP_HOOK = "layer_skip_hook"
@dataclass
class LayerSkipConfig:
r"""
Configuration for skipping internal transformer blocks when executing a transformer model.
Args:
indices (`List[int]`):
The indices of the layer to skip. This is typically the first layer in the transformer block.
fqn (`str`, defaults to `"auto"`):
The fully qualified name identifying the stack of transformer blocks. Typically, this is
`transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
"""
indices: List[int]
fqn: str = "auto"
skip_attention: bool = True
skip_attention_scores: bool = False
skip_ff: bool = True
class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
def __init__(self) -> None:
super().__init__()
def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func is torch.nn.functional.scaled_dot_product_attention:
value = kwargs.get("value", None)
if value is None:
value = args[2]
return value
return func(*args, **kwargs)
class AttentionProcessorSkipHook(ModelHook):
def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False):
self.skip_processor_output_fn = skip_processor_output_fn
self.skip_attention_scores = skip_attention_scores
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
if self.skip_attention_scores:
with AttentionScoreSkipFunctionMode():
return self.fn_ref.original_forward(*args, **kwargs)
else:
return self.skip_processor_output_fn(module, *args, **kwargs)
class FeedForwardSkipHook(ModelHook):
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
output = kwargs.get("hidden_states", None)
if output is None:
output = kwargs.get("x", None)
if output is None and len(args) > 0:
output = args[0]
return output
class TransformerBlockSkipHook(ModelHook):
def initialize_hook(self, module):
self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
return self._metadata.skip_block_output_fn(module, *args, **kwargs)
def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None:
r"""
Apply layer skipping to internal layers of a transformer.
Args:
module (`torch.nn.Module`):
The transformer model to which the layer skip hook should be applied.
config (`LayerSkipConfig`):
The configuration for the layer skip hook.
Example:
```python
>>> from diffusers import apply_layer_skip_hook, CogVideoXTransformer3DModel, LayerSkipConfig
>>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
>>> config = LayerSkipConfig(layer_index=[10, 20], fqn="transformer_blocks")
>>> apply_layer_skip_hook(transformer, config)
```
"""
_apply_layer_skip_hook(module, config)
def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None:
name = name or _LAYER_SKIP_HOOK
if config.skip_attention and config.skip_attention_scores:
raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.")
if config.fqn == "auto":
for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
if hasattr(module, identifier):
config.fqn = identifier
break
else:
raise ValueError(
"Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
"`fqn` (fully qualified name) that identifies a stack of transformer blocks."
)
transformer_blocks = getattr(module, config.fqn, None)
if transformer_blocks is None or not isinstance(transformer_blocks, torch.nn.ModuleList):
raise ValueError(
f"Could not find {config.fqn} in the provided module, or configured `fqn` (fully qualified name) does not identify "
f"a `torch.nn.ModuleList`. Please provide a valid `fqn` that identifies a stack of transformer blocks."
)
if len(config.indices) == 0:
raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.")
blocks_found = False
for i, block in enumerate(transformer_blocks):
if i not in config.indices:
continue
blocks_found = True
if config.skip_attention and config.skip_ff:
logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'")
registry = HookRegistry.check_if_exists_or_initialize(block)
hook = TransformerBlockSkipHook()
registry.register_hook(hook, name)
elif config.skip_attention or config.skip_attention_scores:
for submodule_name, submodule in block.named_modules():
if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention:
logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'")
output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn
registry = HookRegistry.check_if_exists_or_initialize(submodule)
hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores)
registry.register_hook(hook, name)
elif config.skip_ff:
for submodule_name, submodule in block.named_modules():
if isinstance(submodule, _FEEDFORWARD_CLASSES):
logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'")
registry = HookRegistry.check_if_exists_or_initialize(submodule)
hook = FeedForwardSkipHook()
registry.register_hook(hook, name)
else:
raise ValueError(
"At least one of `skip_attention`, `skip_attention_scores`, or `skip_ff` must be set to True."
)
if not blocks_found:
raise ValueError(
f"Could not find any transformer blocks matching the provided indices {config.indices} and "
f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
)

View File

@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from ..utils.logging import get_logger
@@ -25,6 +27,7 @@ class CacheMixin:
Supported caching techniques:
- [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
- [FasterCache](https://huggingface.co/papers/2410.19355)
- [FirstBlockCache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching)
"""
_cache_config = None
@@ -62,8 +65,10 @@ class CacheMixin:
from ..hooks import (
FasterCacheConfig,
FirstBlockCacheConfig,
PyramidAttentionBroadcastConfig,
apply_faster_cache,
apply_first_block_cache,
apply_pyramid_attention_broadcast,
)
@@ -72,31 +77,36 @@ class CacheMixin:
f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first."
)
if isinstance(config, PyramidAttentionBroadcastConfig):
apply_pyramid_attention_broadcast(self, config)
elif isinstance(config, FasterCacheConfig):
if isinstance(config, FasterCacheConfig):
apply_faster_cache(self, config)
elif isinstance(config, FirstBlockCacheConfig):
apply_first_block_cache(self, config)
elif isinstance(config, PyramidAttentionBroadcastConfig):
apply_pyramid_attention_broadcast(self, config)
else:
raise ValueError(f"Cache config {type(config)} is not supported.")
self._cache_config = config
def disable_cache(self) -> None:
from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
if self._cache_config is None:
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
return
if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
registry = HookRegistry.check_if_exists_or_initialize(self)
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
elif isinstance(self._cache_config, FasterCacheConfig):
registry = HookRegistry.check_if_exists_or_initialize(self)
registry = HookRegistry.check_if_exists_or_initialize(self)
if isinstance(self._cache_config, FasterCacheConfig):
registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True)
registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True)
elif isinstance(self._cache_config, FirstBlockCacheConfig):
registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True)
registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True)
elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
else:
raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
@@ -106,3 +116,21 @@ class CacheMixin:
from ..hooks import HookRegistry
HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse)
@contextmanager
def _cache_context(self):
r"""Context manager that provides additional methods for cache management."""
cache_context = _CacheContextManager(self)
yield cache_context
class _CacheContextManager:
def __init__(self, model: CacheMixin):
self.model = model
def mark_state(self, name: str) -> None:
from ..hooks import HookRegistry
if self.model.is_cache_enabled:
registry = HookRegistry.check_if_exists_or_initialize(self.model)
registry._mark_state(name)

View File

@@ -343,25 +343,25 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
)
block_samples = block_samples + (hidden_states,)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
single_block_samples = ()
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
)
else:
hidden_states = block(
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
single_block_samples = single_block_samples + (hidden_states,)
# controlnet block
controlnet_block_samples = ()

View File

@@ -460,3 +460,84 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
### ===== Custom attention processors for guidance methods ===== ###
class CogView4PAGAttnProcessor:
"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogView4AttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
skip_context_attention: bool = False,
) -> torch.Tensor:
batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
batch_size, image_seq_length, embed_dim = hidden_states.shape
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
# 1. QKV projections
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
# 2. QK normalization
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# 3. Rotational positional embeddings applied to latent stream
if image_rotary_emb is not None:
from ..embeddings import apply_rotary_emb
query[:, :, text_seq_length:, :] = apply_rotary_emb(
query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
)
key[:, :, text_seq_length:, :] = apply_rotary_emb(
key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
)
# 4. Attention
if skip_context_attention:
hidden_states = value
else:
# PAG uses a custom attention mask for perturbed attention path:
# - Create attention mask with `float("-inf")` for image tokens and `0.0` for text tokens
# - Set diagonal to `0.0` for attention between image tokens
seq_length = text_seq_length + image_seq_length
perturbed_attention_mask = hidden_states.new_full((seq_length, seq_length), float("-inf"))
perturbed_attention_mask[:text_seq_length, :text_seq_length] = 0.0
perturbed_attention_mask.fill_diagonal_(0.0)
perturbed_attention_mask = perturbed_attention_mask.unsqueeze(0).unsqueeze(0)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=perturbed_attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.type_as(query)
# 5. Output projection
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states

View File

@@ -79,10 +79,14 @@ class FluxSingleTransformerBlock(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
text_seq_len = encoder_hidden_states.shape[1]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
@@ -100,7 +104,8 @@ class FluxSingleTransformerBlock(nn.Module):
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
return hidden_states
encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
return encoder_hidden_states, hidden_states
@maybe_allow_in_graph
@@ -508,20 +513,21 @@ class FluxTransformer2DModel(
)
else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
)
else:
hidden_states = block(
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
@@ -531,12 +537,7 @@ class FluxTransformer2DModel(
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ controlnet_single_block_samples[index_block // interval_control]
)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)

View File

@@ -21,6 +21,7 @@ import torch
from transformers import AutoTokenizer, GlmModel
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...guiders import ClassifierFreeGuidance, GuidanceMixin, _raise_guidance_deprecation_warning
from ...image_processor import VaeImageProcessor
from ...loaders import CogView4LoraLoaderMixin
from ...models import AutoencoderKL, CogView4Transformer2DModel
@@ -426,6 +427,7 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 1024,
guidance: Optional[GuidanceMixin] = None,
) -> Union[CogView4PipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
@@ -514,6 +516,10 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
_raise_guidance_deprecation_warning(guidance_scale=guidance_scale)
if guidance is None:
guidance = ClassifierFreeGuidance(guidance_scale=guidance_scale)
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
@@ -608,46 +614,47 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
transformer_dtype = self.transformer.dtype
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
conds = [prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left]
prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left = [[v] for v in conds]
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
for i, t in enumerate(timesteps):
self._current_timestep = t
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = latents.to(transformer_dtype)
guidance.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
guidance.prepare_models(self.transformer)
latents, prompt_embeds, original_size, target_size, crops_coords_top_left = guidance.prepare_inputs(
latents,
(prompt_embeds[0], negative_prompt_embeds[0]),
original_size[0],
target_size[0],
crops_coords_top_left[0],
)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0])
noise_pred_cond = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
original_size=original_size,
target_size=target_size,
crop_coords=crops_coords_top_left,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=negative_prompt_embeds,
for batch_index, (latent, condition, original_size_c, target_size_c, crop_coord_c) in enumerate(
zip(latents, prompt_embeds, original_size, target_size, crops_coords_top_left)
):
cc.mark_state(f"batch_{batch_index}")
latent = latent.to(transformer_dtype)
timestep = t.expand(latent.shape[0])
noise_pred = self.transformer(
hidden_states=latent,
encoder_hidden_states=condition,
timestep=timestep,
original_size=original_size,
target_size=target_size,
crop_coords=crops_coords_top_left,
original_size=original_size_c,
target_size=target_size_c,
crop_coords=crop_coord_c,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
guidance.prepare_outputs(noise_pred)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
else:
noise_pred = noise_pred_cond
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
outputs = guidance.outputs
noise_pred = guidance(**outputs)
latents = self.scheduler.step(noise_pred, t, latents[0], return_dict=False)[0]
guidance.cleanup_models(self.transformer)
# call the callback, if provided
if callback_on_step_end is not None:
@@ -656,8 +663,10 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
prompt_embeds = [callback_outputs.pop("prompt_embeds", prompt_embeds[0])]
negative_prompt_embeds = [
callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds[0])
]
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()

View File

@@ -906,7 +906,7 @@ class FluxPipeline(
)
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
@@ -917,6 +917,7 @@ class FluxPipeline(
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
cc.mark_state("cond")
noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
@@ -932,6 +933,8 @@ class FluxPipeline(
if do_true_cfg:
if negative_image_embeds is not None:
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
cc.mark_state("uncond")
neg_noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,

View File

@@ -683,7 +683,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
@@ -693,6 +693,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
cc.mark_state("cond")
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
@@ -705,6 +706,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
)[0]
if do_true_cfg:
cc.mark_state("uncond")
neg_noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,

View File

@@ -706,7 +706,7 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
)
# 7. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
@@ -719,6 +719,7 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
cc.mark_state("cond_uncond")
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,

View File

@@ -1072,7 +1072,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
self._num_timesteps = len(timesteps)
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
@@ -1105,6 +1105,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
if is_conditioning_image_or_video:
timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
cc.mark_state("cond_uncond")
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,

View File

@@ -778,7 +778,7 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
)
# 7. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
@@ -792,6 +792,7 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
timestep = t.expand(latent_model_input.shape[0])
timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
cc.mark_state("cond_uncond")
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,

View File

@@ -519,7 +519,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
@@ -528,6 +528,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
latent_model_input = latents.to(transformer_dtype)
timestep = t.expand(latents.shape[0])
cc.mark_state("cond")
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
@@ -537,6 +538,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
)[0]
if self.do_classifier_free_guidance:
cc.mark_state("uncond")
noise_uncond = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,

View File

@@ -2,6 +2,81 @@
from ..utils import DummyObject, requires_backends
class AdaptiveProjectedGuidance(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class ClassifierFreeGuidance(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class ClassifierFreeZeroStarGuidance(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class PerturbedAttentionGuidance(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class SkipLayerGuidance(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class FasterCacheConfig(metaclass=DummyObject):
_backends = ["torch"]
@@ -17,6 +92,21 @@ class FasterCacheConfig(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class FirstBlockCacheConfig(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class HookRegistry(metaclass=DummyObject):
_backends = ["torch"]
@@ -32,6 +122,21 @@ class HookRegistry(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class LayerSkipConfig(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class PyramidAttentionBroadcastConfig(metaclass=DummyObject):
_backends = ["torch"]
@@ -51,6 +156,14 @@ def apply_faster_cache(*args, **kwargs):
requires_backends(apply_faster_cache, ["torch"])
def apply_first_block_cache(*args, **kwargs):
requires_backends(apply_first_block_cache, ["torch"])
def apply_layer_skip(*args, **kwargs):
requires_backends(apply_layer_skip, ["torch"])
def apply_pyramid_attention_broadcast(*args, **kwargs):
requires_backends(apply_pyramid_attention_broadcast, ["torch"])

View File

@@ -90,6 +90,11 @@ def is_compiled_module(module) -> bool:
return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
def unwrap_module(module):
"""Unwraps a module if it was compiled with torch.compile()"""
return module._orig_mod if is_compiled_module(module) else module
def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
"""Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).

View File

@@ -32,6 +32,7 @@ from diffusers.utils.testing_utils import (
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
@@ -44,7 +45,11 @@ enable_full_determinism()
class CogVideoXPipelineFastTests(
PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
unittest.TestCase,
):
pipeline_class = CogVideoXPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}

View File

@@ -25,6 +25,7 @@ from diffusers.utils.testing_utils import (
from ..test_pipelines_common import (
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
FluxIPAdapterTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
@@ -34,11 +35,12 @@ from ..test_pipelines_common import (
class FluxPipelineFastTests(
unittest.TestCase,
PipelineTesterMixin,
FluxIPAdapterTesterMixin,
PyramidAttentionBroadcastTesterMixin,
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
unittest.TestCase,
):
pipeline_class = FluxPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])

View File

@@ -33,6 +33,7 @@ from diffusers.utils.testing_utils import (
from ..test_pipelines_common import (
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
to_np,
@@ -43,7 +44,11 @@ enable_full_determinism()
class HunyuanVideoPipelineFastTests(
PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
unittest.TestCase,
):
pipeline_class = HunyuanVideoPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])

View File

@@ -23,13 +23,13 @@ from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LT
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
from ..test_pipelines_common import FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np
enable_full_determinism()
class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
class LTXPipelineFastTests(PipelineTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase):
pipeline_class = LTXPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
@@ -49,7 +49,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self):
def get_dummy_components(self, num_layers: int = 1):
torch.manual_seed(0)
transformer = LTXVideoTransformer3DModel(
in_channels=8,
@@ -59,7 +59,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
num_attention_heads=4,
attention_head_dim=8,
cross_attention_dim=32,
num_layers=1,
num_layers=num_layers,
caption_channels=32,
)

View File

@@ -33,13 +33,15 @@ from diffusers.utils.testing_utils import (
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import FasterCacheTesterMixin, PipelineTesterMixin, to_np
from ..test_pipelines_common import FasterCacheTesterMixin, FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np
enable_full_determinism()
class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unittest.TestCase):
class MochiPipelineFastTests(
PipelineTesterMixin, FasterCacheTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase
):
pipeline_class = MochiPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS

View File

@@ -33,6 +33,7 @@ from diffusers import (
)
from diffusers.hooks import apply_group_offloading
from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook
from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
@@ -2631,7 +2632,7 @@ class FasterCacheTesterMixin:
self.faster_cache_config.current_timestep_callback = lambda: pipe.current_timestep
pipe = create_pipe()
pipe.transformer.enable_cache(self.faster_cache_config)
output = run_forward(pipe).flatten().flatten()
output = run_forward(pipe).flatten()
image_slice_faster_cache_enabled = np.concatenate((output[:8], output[-8:]))
# Run inference with FasterCache disabled
@@ -2738,6 +2739,55 @@ class FasterCacheTesterMixin:
self.assertTrue(state.cache is None, "Cache should be reset to None.")
# TODO(aryan, dhruv): the cache tester mixins should probably be rewritten so that more models can be tested out
# of the box once there is better cache support/implementation
class FirstBlockCacheTesterMixin:
# threshold is intentionally set higher than usual values since we're testing with random unconverged models
# that will not satisfy the expected properties of the denoiser for caching to be effective
first_block_cache_config = FirstBlockCacheConfig(threshold=0.8)
def test_first_block_cache_inference(self, expected_atol: float = 0.1):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
def create_pipe():
torch.manual_seed(0)
num_layers = 2
components = self.get_dummy_components(num_layers=num_layers)
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
return pipe
def run_forward(pipe):
torch.manual_seed(0)
inputs = self.get_dummy_inputs(device)
inputs["num_inference_steps"] = 4
return pipe(**inputs)[0]
# Run inference without FirstBlockCache
pipe = create_pipe()
output = run_forward(pipe).flatten()
original_image_slice = np.concatenate((output[:8], output[-8:]))
# Run inference with FirstBlockCache enabled
pipe = create_pipe()
pipe.transformer.enable_cache(self.first_block_cache_config)
output = run_forward(pipe).flatten()
image_slice_fbc_enabled = np.concatenate((output[:8], output[-8:]))
# Run inference with FirstBlockCache disabled
pipe.transformer.disable_cache()
output = run_forward(pipe).flatten()
image_slice_fbc_disabled = np.concatenate((output[:8], output[-8:]))
assert np.allclose(
original_image_slice, image_slice_fbc_enabled, atol=expected_atol
), "FirstBlockCache outputs should not differ much."
assert np.allclose(
original_image_slice, image_slice_fbc_disabled, atol=1e-4
), "Outputs from normal inference and after disabling cache should not differ."
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
# reference image.