mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 20:44:33 +08:00
Compare commits
21 Commits
kontext-re
...
feature/gu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a1da7752e5 | ||
|
|
b30cf5d452 | ||
|
|
357f4f056b | ||
|
|
53b6b9fcb6 | ||
|
|
46643564a3 | ||
|
|
77324c40c4 | ||
|
|
05d74ef3e7 | ||
|
|
9997c223a8 | ||
|
|
d91d10737a | ||
|
|
5ac7f360b0 | ||
|
|
594e8d663f | ||
|
|
c76e1cc17e | ||
|
|
315e357a18 | ||
|
|
1f33ca276d | ||
|
|
41b0c473d2 | ||
|
|
0e232ac8c0 | ||
|
|
2557238b4d | ||
|
|
d71fe55895 | ||
|
|
7ab424a15a | ||
|
|
dd69b41834 | ||
|
|
406b1062f8 |
@@ -11,33 +11,6 @@ specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# Caching methods
|
||||
|
||||
## Pyramid Attention Broadcast
|
||||
|
||||
[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
|
||||
|
||||
Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. The attention states are not very different between successive inference steps. The most prominent difference is in the spatial attention blocks, not as much in the temporal attention blocks, and finally the least in the cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by the temporal and spatial attention blocks. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
|
||||
|
||||
Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
|
||||
|
||||
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of
|
||||
# `spatial_attention_timestep_skip_range[1]` will decrease the interval in which pyramid attention
|
||||
# broadcast is active, leader to slower inference speeds. However, large intervals can lead to
|
||||
# poorer quality of generated videos.
|
||||
config = PyramidAttentionBroadcastConfig(
|
||||
spatial_attention_block_skip_range=2,
|
||||
spatial_attention_timestep_skip_range=(100, 800),
|
||||
current_timestep_callback=lambda: pipe.current_timestep,
|
||||
)
|
||||
pipe.transformer.enable_cache(config)
|
||||
```
|
||||
|
||||
## Faster Cache
|
||||
|
||||
[FasterCache](https://huggingface.co/papers/2410.19355) from Zhengyao Lv, Chenyang Si, Junhao Song, Zhenyu Yang, Yu Qiao, Ziwei Liu, Kwan-Yee K. Wong.
|
||||
@@ -65,18 +38,68 @@ config = FasterCacheConfig(
|
||||
pipe.transformer.enable_cache(config)
|
||||
```
|
||||
|
||||
## First Block Cache
|
||||
|
||||
[First Block Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching) is a method that builds upon the ideas of [TeaCache](https://huggingface.co/papers/2411.19108) to speed up inference in diffusion transformers. The generation quality is superior with greatly reduced inference time. This method always computes the output of the first transformer block and computes the differences between past and current outputs of the first transformer block. If the difference is smaller than a predefined threshold, the computation of remaining transformer blocks is skipped, and otherwise the computation is performed as usual.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import CogVideoXPipeline, FirstBlockCacheConfig
|
||||
|
||||
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Increasing the threshold may lead to faster inference speeds, but may also lead to poorer quality of generated videos.
|
||||
# Smaller values between 0.02-2.0 are recommended based on the model being used. The default value is 0.05.
|
||||
config = FirstBlockCacheConfig(threshold=0.07)
|
||||
pipe.transformer.enable_cache(config)
|
||||
```
|
||||
|
||||
## Pyramid Attention Broadcast
|
||||
|
||||
[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
|
||||
|
||||
Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. The attention states are not very different between successive inference steps. The most prominent difference is in the spatial attention blocks, not as much in the temporal attention blocks, and finally the least in the cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by the temporal and spatial attention blocks. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
|
||||
|
||||
Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
|
||||
|
||||
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of
|
||||
# `spatial_attention_timestep_skip_range[1]` will decrease the interval in which pyramid attention
|
||||
# broadcast is active, leader to slower inference speeds. However, large intervals can lead to
|
||||
# poorer quality of generated videos.
|
||||
config = PyramidAttentionBroadcastConfig(
|
||||
spatial_attention_block_skip_range=2,
|
||||
spatial_attention_timestep_skip_range=(100, 800),
|
||||
current_timestep_callback=lambda: pipe.current_timestep,
|
||||
)
|
||||
pipe.transformer.enable_cache(config)
|
||||
```
|
||||
|
||||
### CacheMixin
|
||||
|
||||
[[autodoc]] CacheMixin
|
||||
|
||||
### PyramidAttentionBroadcastConfig
|
||||
|
||||
[[autodoc]] PyramidAttentionBroadcastConfig
|
||||
|
||||
[[autodoc]] apply_pyramid_attention_broadcast
|
||||
|
||||
### FasterCacheConfig
|
||||
|
||||
[[autodoc]] FasterCacheConfig
|
||||
|
||||
[[autodoc]] apply_faster_cache
|
||||
|
||||
### FirstBlockCacheConfig
|
||||
|
||||
[[autodoc]] FirstBlockCacheConfig
|
||||
|
||||
[[autodoc]] apply_first_block_cache
|
||||
|
||||
### PyramidAttentionBroadcastConfig
|
||||
|
||||
[[autodoc]] PyramidAttentionBroadcastConfig
|
||||
|
||||
[[autodoc]] apply_pyramid_attention_broadcast
|
||||
|
||||
@@ -33,6 +33,7 @@ from .utils import (
|
||||
|
||||
_import_structure = {
|
||||
"configuration_utils": ["ConfigMixin"],
|
||||
"guiders": [],
|
||||
"hooks": [],
|
||||
"loaders": ["FromOriginalModelMixin"],
|
||||
"models": [],
|
||||
@@ -129,12 +130,25 @@ except OptionalDependencyNotAvailable:
|
||||
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
|
||||
|
||||
else:
|
||||
_import_structure["guiders"].extend(
|
||||
[
|
||||
"AdaptiveProjectedGuidance",
|
||||
"ClassifierFreeGuidance",
|
||||
"ClassifierFreeZeroStarGuidance",
|
||||
"PerturbedAttentionGuidance",
|
||||
"SkipLayerGuidance",
|
||||
]
|
||||
)
|
||||
_import_structure["hooks"].extend(
|
||||
[
|
||||
"FasterCacheConfig",
|
||||
"FirstBlockCacheConfig",
|
||||
"HookRegistry",
|
||||
"LayerSkipConfig",
|
||||
"PyramidAttentionBroadcastConfig",
|
||||
"apply_faster_cache",
|
||||
"apply_first_block_cache",
|
||||
"apply_layer_skip",
|
||||
"apply_pyramid_attention_broadcast",
|
||||
]
|
||||
)
|
||||
@@ -708,11 +722,22 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_pt_objects import * # noqa F403
|
||||
else:
|
||||
from .guiders import (
|
||||
AdaptiveProjectedGuidance,
|
||||
ClassifierFreeGuidance,
|
||||
ClassifierFreeZeroStarGuidance,
|
||||
PerturbedAttentionGuidance,
|
||||
SkipLayerGuidance,
|
||||
)
|
||||
from .hooks import (
|
||||
FasterCacheConfig,
|
||||
FirstBlockCacheConfig,
|
||||
HookRegistry,
|
||||
LayerSkipConfig,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
apply_faster_cache,
|
||||
apply_first_block_cache,
|
||||
apply_layer_skip,
|
||||
apply_pyramid_attention_broadcast,
|
||||
)
|
||||
from .models import (
|
||||
|
||||
24
src/diffusers/guiders/__init__.py
Normal file
24
src/diffusers/guiders/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ..utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .adaptive_projected_guidance import AdaptiveProjectedGuidance
|
||||
from .classifier_free_guidance import ClassifierFreeGuidance
|
||||
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
|
||||
from .guider_utils import GuidanceMixin, _raise_guidance_deprecation_warning
|
||||
from .perturbed_attention_guidance import PerturbedAttentionGuidance
|
||||
from .skip_layer_guidance import SkipLayerGuidance
|
||||
145
src/diffusers/guiders/adaptive_projected_guidance.py
Normal file
145
src/diffusers/guiders/adaptive_projected_guidance.py
Normal file
@@ -0,0 +1,145 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .guider_utils import GuidanceMixin, rescale_noise_cfg
|
||||
|
||||
|
||||
class AdaptiveProjectedGuidance(GuidanceMixin):
|
||||
"""
|
||||
Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
adaptive_projected_guidance_momentum (`float`, defaults to `None`):
|
||||
The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
|
||||
adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 7.5,
|
||||
adaptive_projected_guidance_momentum: Optional[float] = None,
|
||||
adaptive_projected_guidance_rescale: float = 15.0,
|
||||
eta: float = 1.0,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
|
||||
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
|
||||
self.eta = eta
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
self.momentum_buffer = None
|
||||
|
||||
def prepare_inputs(self, *args):
|
||||
if self._step == 0:
|
||||
if self.adaptive_projected_guidance_momentum is not None:
|
||||
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
|
||||
return super().prepare_inputs(*args)
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
pred = None
|
||||
|
||||
if self._is_cfg_enabled():
|
||||
pred = pred_cond
|
||||
else:
|
||||
pred = normalized_guidance(
|
||||
pred_cond,
|
||||
pred_uncond,
|
||||
self.guidance_scale,
|
||||
self.momentum_buffer,
|
||||
self.eta,
|
||||
self.adaptive_projected_guidance_rescale,
|
||||
self.use_original_formulation,
|
||||
)
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
|
||||
return pred
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_cfg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_cfg_enabled(self) -> bool:
|
||||
if self.use_original_formulation:
|
||||
return not math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
return not math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
|
||||
class MomentumBuffer:
|
||||
def __init__(self, momentum: float):
|
||||
self.momentum = momentum
|
||||
self.running_average = 0
|
||||
|
||||
def update(self, update_value: torch.Tensor):
|
||||
new_average = self.momentum * self.running_average
|
||||
self.running_average = update_value + new_average
|
||||
|
||||
|
||||
def normalized_guidance(
|
||||
pred_cond: torch.Tensor,
|
||||
pred_uncond: torch.Tensor,
|
||||
guidance_scale: float,
|
||||
momentum_buffer: Optional[MomentumBuffer] = None,
|
||||
eta: float = 1.0,
|
||||
norm_threshold: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
):
|
||||
diff = pred_cond - pred_uncond
|
||||
dim = [-i for i in range(1, len(diff.shape))]
|
||||
if momentum_buffer is not None:
|
||||
momentum_buffer.update(diff)
|
||||
diff = momentum_buffer.running_average
|
||||
if norm_threshold > 0:
|
||||
ones = torch.ones_like(diff)
|
||||
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
|
||||
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
|
||||
diff = diff * scale_factor
|
||||
v0, v1 = diff.double(), pred_cond.double()
|
||||
v1 = torch.nn.functional.normalize(v1, dim=dim)
|
||||
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
|
||||
v0_orthogonal = v0 - v0_parallel
|
||||
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
|
||||
normalized_update = diff_orthogonal + eta * diff_parallel
|
||||
pred = pred_cond if use_original_formulation else pred_uncond
|
||||
pred = pred + (guidance_scale - 1) * normalized_update
|
||||
return pred
|
||||
98
src/diffusers/guiders/classifier_free_guidance.py
Normal file
98
src/diffusers/guiders/classifier_free_guidance.py
Normal file
@@ -0,0 +1,98 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .guider_utils import GuidanceMixin, rescale_noise_cfg
|
||||
|
||||
|
||||
class ClassifierFreeGuidance(GuidanceMixin):
|
||||
"""
|
||||
Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598
|
||||
|
||||
CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by
|
||||
jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
|
||||
inference. This allows the model to tradeoff between generation quality and sample diversity.
|
||||
|
||||
The original paper proposes scaling and shifting the conditional distribution based on the difference between
|
||||
conditional and unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]
|
||||
|
||||
Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen
|
||||
paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in
|
||||
theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
|
||||
|
||||
The intution behind the original formulation can be thought of as moving the conditional distribution estimates
|
||||
further away from the unconditional distribution estimates, while the diffusers-native implementation can be
|
||||
thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of
|
||||
the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.)
|
||||
|
||||
The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
|
||||
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond"]
|
||||
|
||||
def __init__(
|
||||
self, guidance_scale: float = 7.5, guidance_rescale: float = 0.0, use_original_formulation: bool = False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
pred = None
|
||||
|
||||
if not self._is_cfg_enabled():
|
||||
pred = pred_cond
|
||||
else:
|
||||
shift = pred_cond - pred_uncond
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
|
||||
return pred
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_cfg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_cfg_enabled(self) -> bool:
|
||||
if self.use_original_formulation:
|
||||
return not math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
return not math.isclose(self.guidance_scale, 1.0)
|
||||
110
src/diffusers/guiders/classifier_free_zero_star_guidance.py
Normal file
110
src/diffusers/guiders/classifier_free_zero_star_guidance.py
Normal file
@@ -0,0 +1,110 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .guider_utils import GuidanceMixin, rescale_noise_cfg
|
||||
|
||||
|
||||
class ClassifierFreeZeroStarGuidance(GuidanceMixin):
|
||||
"""
|
||||
Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886
|
||||
|
||||
This is an implementation of the Classifier-Free Zero* guidance technique, which is a variant of classifier-free
|
||||
guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion
|
||||
process, and also introduces an optimal rescaling factor for the noise predictions, which can help in improving the
|
||||
quality of generated images.
|
||||
|
||||
The authors of the paper suggest setting zero initialization in the first 4% of the inference steps.
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
zero_init_steps (`int`, defaults to `1`):
|
||||
The number of inference steps for which the noise predictions are zeroed out (see Section 4.2).
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 7.5,
|
||||
zero_init_steps: int = 1,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.zero_init_steps = zero_init_steps
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
pred = None
|
||||
|
||||
if self._step < self.zero_init_steps:
|
||||
pred = torch.zeros_like(pred_cond)
|
||||
elif self._is_cfg_enabled():
|
||||
pred = pred_cond
|
||||
else:
|
||||
shift = pred_cond - pred_uncond
|
||||
pred_cond_flat = pred_cond.flatten(1)
|
||||
pred_uncond_flat = pred_uncond.flatten(1)
|
||||
alpha = cfg_zero_star_scale(pred_cond_flat, pred_uncond_flat)
|
||||
alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1))
|
||||
pred_uncond = pred_uncond * alpha
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
|
||||
return pred
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_cfg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_cfg_enabled(self) -> bool:
|
||||
if self.use_original_formulation:
|
||||
return not math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
return not math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
|
||||
def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
|
||||
cond = cond.float()
|
||||
uncond = uncond.float()
|
||||
dot_product = torch.sum(cond * uncond, dim=1, keepdim=True)
|
||||
squared_norm = torch.sum(uncond**2, dim=1, keepdim=True) + eps
|
||||
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
|
||||
scale = dot_product / squared_norm
|
||||
return scale.type_as(cond)
|
||||
213
src/diffusers/guiders/guider_utils.py
Normal file
213
src/diffusers/guiders/guider_utils.py
Normal file
@@ -0,0 +1,213 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import deprecate, get_logger
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models.attention_processor import AttentionProcessor
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class GuidanceMixin:
|
||||
r"""Base mixin class providing the skeleton for implementing guidance techniques."""
|
||||
|
||||
_input_predictions = None
|
||||
|
||||
def __init__(self):
|
||||
self._step: int = None
|
||||
self._num_inference_steps: int = None
|
||||
self._timestep: torch.LongTensor = None
|
||||
self._preds: Dict[str, torch.Tensor] = {}
|
||||
self._num_outputs_prepared: int = 0
|
||||
|
||||
if self._input_predictions is None or not isinstance(self._input_predictions, list):
|
||||
raise ValueError(
|
||||
"`_input_predictions` must be a list of required prediction names for the guidance technique."
|
||||
)
|
||||
|
||||
def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
|
||||
self._step = step
|
||||
self._num_inference_steps = num_inference_steps
|
||||
self._timestep = timestep
|
||||
self._preds = {}
|
||||
self._num_outputs_prepared = 0
|
||||
|
||||
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
||||
pass
|
||||
|
||||
def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
|
||||
num_conditions = self.num_conditions
|
||||
list_of_inputs = []
|
||||
for arg in args:
|
||||
if isinstance(arg, torch.Tensor):
|
||||
list_of_inputs.append([arg] * num_conditions)
|
||||
elif isinstance(arg, (tuple, list)):
|
||||
if len(arg) != 2:
|
||||
raise ValueError(
|
||||
f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 "
|
||||
f"with the first element being the conditional input and the second element being the unconditional input or None."
|
||||
)
|
||||
if arg[1] is None:
|
||||
# Only conditioning inputs for all batches
|
||||
list_of_inputs.append([arg[0]] * num_conditions)
|
||||
else:
|
||||
# Alternating conditional and unconditional inputs as batches
|
||||
inputs = [arg[i % 2] for i in range(num_conditions)]
|
||||
list_of_inputs.append(inputs)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list."
|
||||
)
|
||||
return tuple(list_of_inputs)
|
||||
|
||||
def prepare_outputs(self, pred: torch.Tensor) -> None:
|
||||
self._num_outputs_prepared += 1
|
||||
if self._num_outputs_prepared > self.num_conditions:
|
||||
raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
|
||||
key = self._input_predictions[self._num_outputs_prepared - 1]
|
||||
self._preds[key] = pred
|
||||
|
||||
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
|
||||
pass
|
||||
|
||||
def __call__(self, **kwargs) -> Any:
|
||||
if len(kwargs) != self.num_conditions:
|
||||
raise ValueError(
|
||||
f"Expected {self.num_conditions} arguments, but got {len(kwargs)}. Please provide the correct number of arguments."
|
||||
)
|
||||
return self.forward(**kwargs)
|
||||
|
||||
def forward(self, *args, **kwargs) -> Any:
|
||||
raise NotImplementedError("GuidanceMixin::forward must be implemented in subclasses.")
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
raise NotImplementedError("GuidanceMixin::num_conditions must be implemented in subclasses.")
|
||||
|
||||
@property
|
||||
def outputs(self) -> Dict[str, torch.Tensor]:
|
||||
return self._preds
|
||||
|
||||
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
r"""
|
||||
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
|
||||
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
||||
|
||||
Args:
|
||||
noise_cfg (`torch.Tensor`):
|
||||
The predicted noise tensor for the guided diffusion process.
|
||||
noise_pred_text (`torch.Tensor`):
|
||||
The predicted noise tensor for the text-guided diffusion process.
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
||||
A rescale factor applied to the noise predictions.
|
||||
|
||||
Returns:
|
||||
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
|
||||
"""
|
||||
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
||||
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
||||
# rescale the results from guidance (fixes overexposure)
|
||||
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
||||
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
||||
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
||||
return noise_cfg
|
||||
|
||||
|
||||
def _replace_attention_processors(
|
||||
module: torch.nn.Module,
|
||||
pag_applied_layers: Optional[Union[str, List[str]]] = None,
|
||||
skip_context_attention: bool = False,
|
||||
processors: Optional[List[Tuple[torch.nn.Module, "AttentionProcessor"]]] = None,
|
||||
metadata_name: Optional[str] = None,
|
||||
) -> Optional[List[Tuple[torch.nn.Module, "AttentionProcessor"]]]:
|
||||
if processors is not None and metadata_name is not None:
|
||||
raise ValueError("Cannot pass both `processors` and `metadata_name` at the same time.")
|
||||
if metadata_name is not None:
|
||||
if isinstance(pag_applied_layers, str):
|
||||
pag_applied_layers = [pag_applied_layers]
|
||||
return _replace_layers_with_guidance_processors(
|
||||
module, pag_applied_layers, skip_context_attention, metadata_name
|
||||
)
|
||||
if processors is not None:
|
||||
_replace_layers_with_existing_processors(processors)
|
||||
|
||||
|
||||
def _replace_layers_with_guidance_processors(
|
||||
module: torch.nn.Module,
|
||||
pag_applied_layers: List[str],
|
||||
skip_context_attention: bool,
|
||||
metadata_name: str,
|
||||
) -> List[Tuple[torch.nn.Module, "AttentionProcessor"]]:
|
||||
from ..hooks._common import _ATTENTION_CLASSES
|
||||
from ..hooks._helpers import GuidanceMetadataRegistry
|
||||
|
||||
processors = []
|
||||
for name, submodule in module.named_modules():
|
||||
if (
|
||||
(not isinstance(submodule, _ATTENTION_CLASSES))
|
||||
or (getattr(submodule, "processor", None) is None)
|
||||
or not (
|
||||
any(
|
||||
re.search(pag_layer, name) is not None and not _is_fake_integral_match(pag_layer, name)
|
||||
for pag_layer in pag_applied_layers
|
||||
)
|
||||
)
|
||||
):
|
||||
continue
|
||||
old_attention_processor = submodule.processor
|
||||
metadata = GuidanceMetadataRegistry.get(old_attention_processor.__class__)
|
||||
new_attention_processor_cls = getattr(metadata, metadata_name)
|
||||
new_attention_processor = new_attention_processor_cls()
|
||||
# !!! dunder methods cannot be replaced on instances !!!
|
||||
# if "skip_context_attention" in inspect.signature(new_attention_processor.__call__).parameters:
|
||||
# new_attention_processor.__call__ = partial(
|
||||
# new_attention_processor.__call__, skip_context_attention=skip_context_attention
|
||||
# )
|
||||
submodule.processor = new_attention_processor
|
||||
processors.append((submodule, old_attention_processor))
|
||||
return processors
|
||||
|
||||
|
||||
def _replace_layers_with_existing_processors(processors: List[Tuple[torch.nn.Module, "AttentionProcessor"]]) -> None:
|
||||
for module, proc in processors:
|
||||
module.processor = proc
|
||||
|
||||
|
||||
def _is_fake_integral_match(layer_id, name):
|
||||
layer_id = layer_id.split(".")[-1]
|
||||
name = name.split(".")[-1]
|
||||
return layer_id.isnumeric() and name.isnumeric() and layer_id == name
|
||||
|
||||
|
||||
def _raise_guidance_deprecation_warning(
|
||||
*,
|
||||
guidance_scale: Optional[float] = None,
|
||||
guidance_rescale: Optional[float] = None,
|
||||
) -> None:
|
||||
if guidance_scale is not None:
|
||||
msg = "The `guidance_scale` argument is deprecated and will be removed in version 1.0.0. Please pass a `GuidanceMixin` object for the `guidance` argument instead."
|
||||
deprecate("guidance_scale", "1.0.0", msg, standard_warn=False)
|
||||
if guidance_rescale is not None:
|
||||
msg = "The `guidance_rescale` argument is deprecated and will be removed in version 1.0.0. Please pass a `GuidanceMixin` object for the `guidance` argument instead."
|
||||
deprecate("guidance_rescale", "1.0.0", msg, standard_warn=False)
|
||||
180
src/diffusers/guiders/perturbed_attention_guidance.py
Normal file
180
src/diffusers/guiders/perturbed_attention_guidance.py
Normal file
@@ -0,0 +1,180 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from .guider_utils import GuidanceMixin, _replace_attention_processors, rescale_noise_cfg
|
||||
|
||||
|
||||
class PerturbedAttentionGuidance(GuidanceMixin):
|
||||
"""
|
||||
Perturbed Attention Guidance (PAB): https://huggingface.co/papers/2403.17377
|
||||
|
||||
Args:
|
||||
pag_applied_layers (`str` or `List[str]`):
|
||||
The name of the attention layers where Perturbed Attention Guidance is applied. This can be a single layer
|
||||
name or a list of layer names. The names should either be FQNs (fully qualified names) to each attention
|
||||
layer or a regex pattern that matches the FQNs of the attention layers. For example, if you want to apply
|
||||
PAG to transformer blocks 10 and 20, you can set this to `["transformer_blocks.10",
|
||||
"transformer_blocks.20"]`, or `"transformer_blocks.(10|20)"`.
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
pag_scale (`float`, defaults to `3.0`):
|
||||
The scale parameter for perturbed attention guidance.
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond", "pred_perturbed"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pag_applied_layers: Union[str, List[str]],
|
||||
guidance_scale: float = 7.5,
|
||||
pag_scale: float = 3.0,
|
||||
skip_context_attention: bool = False,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.pag_applied_layers = pag_applied_layers
|
||||
self.guidance_scale = guidance_scale
|
||||
self.pag_scale = pag_scale
|
||||
self.skip_context_attention = skip_context_attention
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
self._is_pag_batch = False
|
||||
self._original_processors = None
|
||||
self._denoiser = None
|
||||
|
||||
def prepare_models(self, denoiser: torch.nn.Module):
|
||||
self._denoiser = denoiser
|
||||
|
||||
def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
|
||||
num_conditions = self.num_conditions
|
||||
list_of_inputs = []
|
||||
for arg in args:
|
||||
if isinstance(arg, torch.Tensor):
|
||||
list_of_inputs.append([arg] * num_conditions)
|
||||
elif isinstance(arg, (tuple, list)):
|
||||
if len(arg) != 2:
|
||||
raise ValueError(
|
||||
f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 "
|
||||
f"with the first element being the conditional input and the second element being the unconditional input or None."
|
||||
)
|
||||
if arg[1] is None:
|
||||
# Only conditioning inputs for all batches
|
||||
list_of_inputs.append([arg[0]] * num_conditions)
|
||||
else:
|
||||
list_of_inputs.append([arg[0], arg[1], arg[0]])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list."
|
||||
)
|
||||
return tuple(list_of_inputs)
|
||||
|
||||
def prepare_outputs(self, pred: torch.Tensor) -> None:
|
||||
self._num_outputs_prepared += 1
|
||||
if self._num_outputs_prepared > self.num_conditions:
|
||||
raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
|
||||
key = self._input_predictions[self._num_outputs_prepared - 1]
|
||||
if not self._is_cfg_enabled() and self._is_pag_enabled():
|
||||
# If we're predicting pred_cond and pred_perturbed only, we need to set the key to pred_perturbed
|
||||
# to avoid writing into pred_uncond which is not used
|
||||
if self._num_outputs_prepared == 2:
|
||||
key = "pred_perturbed"
|
||||
self._preds[key] = pred
|
||||
|
||||
# Restore the original attention processors if previously replaced
|
||||
if self._is_pag_batch:
|
||||
_replace_attention_processors(self._denoiser, processors=self._original_processors)
|
||||
self._is_pag_batch = False
|
||||
self._original_processors = None
|
||||
|
||||
# Prepare denoiser for perturbed attention prediction if needed
|
||||
if self._is_pag_enabled():
|
||||
should_register_pag = (self._is_cfg_enabled() and self._num_outputs_prepared == 2) or (
|
||||
not self._is_cfg_enabled() and self._num_outputs_prepared == 1
|
||||
)
|
||||
if should_register_pag:
|
||||
self._is_pag_batch = True
|
||||
self._original_processors = _replace_attention_processors(
|
||||
self._denoiser,
|
||||
self.pag_applied_layers,
|
||||
skip_context_attention=self.skip_context_attention,
|
||||
metadata_name="perturbed_attention_guidance_processor_cls",
|
||||
)
|
||||
|
||||
def cleanup_models(self, denoiser: torch.nn.Module):
|
||||
self._denoiser = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred_cond: torch.Tensor,
|
||||
pred_uncond: Optional[torch.Tensor] = None,
|
||||
pred_perturbed: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
pred = None
|
||||
|
||||
if not self._is_cfg_enabled() and not self._is_pag_enabled():
|
||||
pred = pred_cond
|
||||
elif not self._is_cfg_enabled():
|
||||
shift = pred_cond - pred_perturbed
|
||||
pred = pred_cond + self.pag_scale * shift
|
||||
elif not self._is_pag_enabled():
|
||||
shift = pred_cond - pred_uncond
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift
|
||||
else:
|
||||
shift = pred_cond - pred_uncond
|
||||
shift_perturbed = pred_cond - pred_perturbed
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift + self.pag_scale * shift_perturbed
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
|
||||
return pred
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_cfg_enabled():
|
||||
num_conditions += 1
|
||||
if self._is_pag_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_cfg_enabled(self) -> bool:
|
||||
if self.use_original_formulation:
|
||||
return not math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
return not math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
def _is_pag_enabled(self) -> bool:
|
||||
is_zero = math.isclose(self.pag_scale, 0.0)
|
||||
return not is_zero
|
||||
235
src/diffusers/guiders/skip_layer_guidance.py
Normal file
235
src/diffusers/guiders/skip_layer_guidance.py
Normal file
@@ -0,0 +1,235 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..hooks import HookRegistry, LayerSkipConfig
|
||||
from ..hooks.layer_skip import _apply_layer_skip_hook
|
||||
from .guider_utils import GuidanceMixin, rescale_noise_cfg
|
||||
|
||||
|
||||
class SkipLayerGuidance(GuidanceMixin):
|
||||
"""
|
||||
Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5 Spatio-Temporal Guidance (STG):
|
||||
https://huggingface.co/papers/2411.18664
|
||||
|
||||
SLG was introduced by StabilityAI for improving structure and anotomy coherence in generated images. It works by
|
||||
skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional
|
||||
batch of data, apart from the conditional and unconditional batches already used in CFG
|
||||
([~guiders.classifier_free_guidance.ClassifierFreeGuidance]), and then scaling and shifting the CFG predictions
|
||||
based on the difference between conditional without skipping and conditional with skipping predictions.
|
||||
|
||||
The intution behind SLG can be thought of as moving the CFG predicted distribution estimates further away from
|
||||
worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse
|
||||
version of the model for the conditional prediction).
|
||||
|
||||
STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving
|
||||
generation quality in video diffusion models.
|
||||
|
||||
Additional reading:
|
||||
- [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)
|
||||
|
||||
The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are
|
||||
defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium.
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
skip_layer_guidance_scale (`float`, defaults to `2.8`):
|
||||
The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher
|
||||
values, but it may also lead to overexposure and saturation.
|
||||
skip_layer_guidance_start (`float`, defaults to `0.01`):
|
||||
The fraction of the total number of denoising steps after which skip layer guidance starts.
|
||||
skip_layer_guidance_stop (`float`, defaults to `0.2`):
|
||||
The fraction of the total number of denoising steps after which skip layer guidance stops.
|
||||
skip_layer_guidance_layers (`int` or `List[int]`, *optional*):
|
||||
The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
|
||||
provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion
|
||||
3.5 Medium.
|
||||
skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
|
||||
The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
|
||||
`LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 7.5,
|
||||
skip_layer_guidance_scale: float = 2.8,
|
||||
skip_layer_guidance_start: float = 0.01,
|
||||
skip_layer_guidance_stop: float = 0.2,
|
||||
skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None,
|
||||
skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.skip_layer_guidance_scale = skip_layer_guidance_scale
|
||||
self.skip_layer_guidance_start = skip_layer_guidance_start
|
||||
self.skip_layer_guidance_stop = skip_layer_guidance_stop
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
if not (0.0 <= skip_layer_guidance_start < 1.0):
|
||||
raise ValueError(
|
||||
f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}."
|
||||
)
|
||||
if not (0.0 < skip_layer_guidance_stop <= 1.0):
|
||||
raise ValueError(
|
||||
f"Expected `skip_layer_guidance_stop` to be between 0.0 and 1.0, but got {skip_layer_guidance_stop}."
|
||||
)
|
||||
|
||||
if skip_layer_guidance_layers is None and skip_layer_config is None:
|
||||
raise ValueError(
|
||||
"Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance."
|
||||
)
|
||||
if skip_layer_guidance_layers is not None and skip_layer_config is not None:
|
||||
raise ValueError("Only one of `skip_layer_guidance_layers` or `skip_layer_config` can be provided.")
|
||||
|
||||
if skip_layer_guidance_layers is not None:
|
||||
if isinstance(skip_layer_guidance_layers, int):
|
||||
skip_layer_guidance_layers = [skip_layer_guidance_layers]
|
||||
if not isinstance(skip_layer_guidance_layers, list):
|
||||
raise ValueError(
|
||||
f"Expected `skip_layer_guidance_layers` to be an int or a list of ints, but got {type(skip_layer_guidance_layers)}."
|
||||
)
|
||||
skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers]
|
||||
|
||||
if isinstance(skip_layer_config, LayerSkipConfig):
|
||||
skip_layer_config = [skip_layer_config]
|
||||
|
||||
if not isinstance(skip_layer_config, list):
|
||||
raise ValueError(
|
||||
f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}."
|
||||
)
|
||||
|
||||
self.skip_layer_config = skip_layer_config
|
||||
self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
|
||||
|
||||
def prepare_models(self, denoiser: torch.nn.Module):
|
||||
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
|
||||
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
|
||||
|
||||
# Register the hooks for layer skipping if the step is within the specified range
|
||||
if skip_start_step < self._step < skip_stop_step:
|
||||
for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
|
||||
_apply_layer_skip_hook(denoiser, config, name=name)
|
||||
|
||||
def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
|
||||
num_conditions = self.num_conditions
|
||||
list_of_inputs = []
|
||||
for arg in args:
|
||||
if isinstance(arg, torch.Tensor):
|
||||
list_of_inputs.append([arg] * num_conditions)
|
||||
elif isinstance(arg, (tuple, list)):
|
||||
if len(arg) != 2:
|
||||
raise ValueError(
|
||||
f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 "
|
||||
f"with the first element being the conditional input and the second element being the unconditional input or None."
|
||||
)
|
||||
if arg[1] is None:
|
||||
# Only conditioning inputs for all batches
|
||||
list_of_inputs.append([arg[0]] * num_conditions)
|
||||
else:
|
||||
list_of_inputs.append([arg[0], arg[1], arg[0]])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list."
|
||||
)
|
||||
return tuple(list_of_inputs)
|
||||
|
||||
def prepare_outputs(self, pred: torch.Tensor) -> None:
|
||||
self._num_outputs_prepared += 1
|
||||
if self._num_outputs_prepared > self.num_conditions:
|
||||
raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
|
||||
key = self._input_predictions[self._num_outputs_prepared - 1]
|
||||
if not self._is_cfg_enabled() and self._is_slg_enabled():
|
||||
# If we're predicting pred_cond and pred_cond_skip only, we need to set the key to pred_cond_skip
|
||||
# to avoid writing into pred_uncond which is not used
|
||||
if self._num_outputs_prepared == 2:
|
||||
key = "pred_cond_skip"
|
||||
self._preds[key] = pred
|
||||
|
||||
def cleanup_models(self, denoiser: torch.nn.Module):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
||||
# Remove the hooks after inference
|
||||
for hook_name in self._skip_layer_hook_names:
|
||||
registry.remove_hook(hook_name, recurse=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred_cond: torch.Tensor,
|
||||
pred_uncond: Optional[torch.Tensor] = None,
|
||||
pred_cond_skip: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
pred = None
|
||||
|
||||
if not self._is_cfg_enabled() and not self._is_slg_enabled():
|
||||
pred = pred_cond
|
||||
elif not self._is_cfg_enabled():
|
||||
shift = pred_cond - pred_cond_skip
|
||||
pred = pred_cond if self.use_original_formulation else pred_cond_skip
|
||||
pred = pred + self.skip_layer_guidance_scale * shift
|
||||
elif not self._is_slg_enabled():
|
||||
shift = pred_cond - pred_uncond
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift
|
||||
else:
|
||||
shift = pred_cond - pred_uncond
|
||||
shift_skip = pred_cond - pred_cond_skip
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
|
||||
return pred
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_cfg_enabled():
|
||||
num_conditions += 1
|
||||
if self._is_slg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_cfg_enabled(self) -> bool:
|
||||
if self.use_original_formulation:
|
||||
return not math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
return not math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
def _is_slg_enabled(self) -> bool:
|
||||
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
|
||||
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
|
||||
is_within_range = skip_start_step < self._step < skip_stop_step
|
||||
is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
|
||||
return is_within_range and not is_zero
|
||||
@@ -1,9 +1,25 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ..utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .faster_cache import FasterCacheConfig, apply_faster_cache
|
||||
from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
|
||||
from .group_offloading import apply_group_offloading
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
from .layer_skip import LayerSkipConfig, apply_layer_skip
|
||||
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
|
||||
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
|
||||
|
||||
32
src/diffusers/hooks/_common.py
Normal file
32
src/diffusers/hooks/_common.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ..models.attention import FeedForward, LuminaFeedForward
|
||||
from ..models.attention_processor import Attention, MochiAttention
|
||||
|
||||
|
||||
_ATTENTION_CLASSES = (Attention, MochiAttention)
|
||||
_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)
|
||||
|
||||
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
|
||||
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
|
||||
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
|
||||
|
||||
_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
|
||||
{
|
||||
*_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
|
||||
*_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
|
||||
*_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
|
||||
}
|
||||
)
|
||||
276
src/diffusers/hooks/_helpers.py
Normal file
276
src/diffusers/hooks/_helpers.py
Normal file
@@ -0,0 +1,276 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Type
|
||||
|
||||
from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
|
||||
from ..models.transformers.transformer_cogview4 import (
|
||||
CogView4AttnProcessor,
|
||||
CogView4PAGAttnProcessor,
|
||||
CogView4TransformerBlock,
|
||||
)
|
||||
from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
|
||||
from ..models.transformers.transformer_hunyuan_video import (
|
||||
HunyuanVideoSingleTransformerBlock,
|
||||
HunyuanVideoTokenReplaceSingleTransformerBlock,
|
||||
HunyuanVideoTokenReplaceTransformerBlock,
|
||||
HunyuanVideoTransformerBlock,
|
||||
)
|
||||
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
|
||||
from ..models.transformers.transformer_mochi import MochiTransformerBlock
|
||||
from ..models.transformers.transformer_wan import WanTransformerBlock
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttentionProcessorMetadata:
|
||||
skip_processor_output_fn: Callable[[Any], Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class GuidanceMetadata:
|
||||
perturbed_attention_guidance_processor_cls: Type = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransformerBlockMetadata:
|
||||
skip_block_output_fn: Callable[[Any], Any]
|
||||
return_hidden_states_index: int = None
|
||||
return_encoder_hidden_states_index: int = None
|
||||
|
||||
|
||||
class AttentionProcessorRegistry:
|
||||
_registry = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
|
||||
cls._registry[model_class] = metadata
|
||||
|
||||
@classmethod
|
||||
def get(cls, model_class: Type) -> AttentionProcessorMetadata:
|
||||
if model_class not in cls._registry:
|
||||
raise ValueError(f"Model class {model_class} not registered.")
|
||||
return cls._registry[model_class]
|
||||
|
||||
|
||||
class GuidanceMetadataRegistry:
|
||||
_registry = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, model_class: Type, metadata: GuidanceMetadata):
|
||||
cls._registry[model_class] = metadata
|
||||
|
||||
@classmethod
|
||||
def get(cls, model_class: Type) -> GuidanceMetadata:
|
||||
if model_class not in cls._registry:
|
||||
raise ValueError(f"Model class {model_class} not registered.")
|
||||
return cls._registry[model_class]
|
||||
|
||||
|
||||
class TransformerBlockRegistry:
|
||||
_registry = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
|
||||
cls._registry[model_class] = metadata
|
||||
|
||||
@classmethod
|
||||
def get(cls, model_class: Type) -> TransformerBlockMetadata:
|
||||
if model_class not in cls._registry:
|
||||
raise ValueError(f"Model class {model_class} not registered.")
|
||||
return cls._registry[model_class]
|
||||
|
||||
|
||||
def _register_attention_processors_metadata():
|
||||
# CogView4
|
||||
AttentionProcessorRegistry.register(
|
||||
model_class=CogView4AttnProcessor,
|
||||
metadata=AttentionProcessorMetadata(
|
||||
skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _register_guidance_metadata():
|
||||
# CogView4
|
||||
GuidanceMetadataRegistry.register(
|
||||
model_class=CogView4AttnProcessor,
|
||||
metadata=GuidanceMetadata(
|
||||
perturbed_attention_guidance_processor_cls=CogView4PAGAttnProcessor,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _register_transformer_blocks_metadata():
|
||||
# CogVideoX
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=CogVideoXBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_CogVideoXBlock,
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
|
||||
# CogView4
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=CogView4TransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_CogView4TransformerBlock,
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
|
||||
# Flux
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=FluxTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_FluxTransformerBlock,
|
||||
return_hidden_states_index=1,
|
||||
return_encoder_hidden_states_index=0,
|
||||
),
|
||||
)
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=FluxSingleTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_FluxSingleTransformerBlock,
|
||||
return_hidden_states_index=1,
|
||||
return_encoder_hidden_states_index=0,
|
||||
),
|
||||
)
|
||||
|
||||
# HunyuanVideo
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=HunyuanVideoTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTransformerBlock,
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=HunyuanVideoSingleTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoSingleTransformerBlock,
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=HunyuanVideoTokenReplaceTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock,
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=HunyuanVideoTokenReplaceSingleTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock,
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
|
||||
# LTXVideo
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=LTXVideoTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_LTXVideoTransformerBlock,
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=None,
|
||||
),
|
||||
)
|
||||
|
||||
# Mochi
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=MochiTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_MochiTransformerBlock,
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
|
||||
# Wan
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=WanTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_WanTransformerBlock,
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=None,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# fmt: off
|
||||
def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
|
||||
hidden_states = kwargs.get("hidden_states", None)
|
||||
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
|
||||
if hidden_states is None and len(args) > 0:
|
||||
hidden_states = args[0]
|
||||
if encoder_hidden_states is None and len(args) > 1:
|
||||
encoder_hidden_states = args[1]
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
|
||||
|
||||
|
||||
def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):
|
||||
hidden_states = kwargs.get("hidden_states", None)
|
||||
if hidden_states is None and len(args) > 0:
|
||||
hidden_states = args[0]
|
||||
return hidden_states
|
||||
|
||||
|
||||
def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
|
||||
hidden_states = kwargs.get("hidden_states", None)
|
||||
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
|
||||
if hidden_states is None and len(args) > 0:
|
||||
hidden_states = args[0]
|
||||
if encoder_hidden_states is None and len(args) > 1:
|
||||
encoder_hidden_states = args[1]
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states(self, *args, **kwargs):
|
||||
hidden_states = kwargs.get("hidden_states", None)
|
||||
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
|
||||
if hidden_states is None and len(args) > 0:
|
||||
hidden_states = args[0]
|
||||
if encoder_hidden_states is None and len(args) > 1:
|
||||
encoder_hidden_states = args[1]
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
|
||||
_skip_block_output_fn_CogVideoXBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
|
||||
_skip_block_output_fn_CogView4TransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
|
||||
_skip_block_output_fn_FluxTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states
|
||||
_skip_block_output_fn_FluxSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states
|
||||
_skip_block_output_fn_HunyuanVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
|
||||
_skip_block_output_fn_HunyuanVideoSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
|
||||
_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
|
||||
_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
|
||||
_skip_block_output_fn_LTXVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
|
||||
_skip_block_output_fn_MochiTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
|
||||
_skip_block_output_fn_WanTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
|
||||
# fmt: on
|
||||
|
||||
|
||||
_register_attention_processors_metadata()
|
||||
_register_guidance_metadata()
|
||||
_register_transformer_blocks_metadata()
|
||||
223
src/diffusers/hooks/first_block_cache.py
Normal file
223
src/diffusers/hooks/first_block_cache.py
Normal file
@@ -0,0 +1,223 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import get_logger
|
||||
from ..utils.torch_utils import unwrap_module
|
||||
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
|
||||
from ._helpers import TransformerBlockRegistry
|
||||
from .hooks import BaseMarkedState, HookRegistry, ModelHook
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook"
|
||||
_FBC_BLOCK_HOOK = "fbc_block_hook"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FirstBlockCacheConfig:
|
||||
r"""
|
||||
Configuration for [First Block
|
||||
Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching).
|
||||
|
||||
Args:
|
||||
threshold (`float`, defaults to `0.05`):
|
||||
The threshold to determine whether or not a forward pass through all layers of the model is required. A
|
||||
higher threshold usually results in lower number of forward passes and faster inference, but might lead to
|
||||
poorer generation quality. A lower threshold may not result in significant generation speedup. The
|
||||
threshold is compared against the absmean difference of the residuals between the current and cached
|
||||
outputs from the first transformer block. If the difference is below the threshold, the forward pass is
|
||||
skipped.
|
||||
"""
|
||||
|
||||
threshold: float = 0.05
|
||||
|
||||
|
||||
class FBCSharedBlockState(BaseMarkedState):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
|
||||
self.head_block_residual: torch.Tensor = None
|
||||
self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
|
||||
self.should_compute: bool = True
|
||||
|
||||
def reset(self):
|
||||
self.tail_block_residuals = None
|
||||
self.should_compute = True
|
||||
|
||||
|
||||
class FBCHeadBlockHook(ModelHook):
|
||||
_is_stateful = True
|
||||
|
||||
def __init__(self, shared_state: FBCSharedBlockState, threshold: float):
|
||||
self.shared_state = shared_state
|
||||
self.threshold = threshold
|
||||
self._metadata = None
|
||||
|
||||
def initialize_hook(self, module):
|
||||
self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
|
||||
return module
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
outputs_if_skipped = self._metadata.skip_block_output_fn(module, *args, **kwargs)
|
||||
original_hs = outputs_if_skipped[self._metadata.return_hidden_states_index]
|
||||
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
is_output_tuple = isinstance(output, tuple)
|
||||
|
||||
if is_output_tuple:
|
||||
hs_residual = output[self._metadata.return_hidden_states_index] - original_hs
|
||||
else:
|
||||
hs_residual = output - original_hs
|
||||
|
||||
hs, ehs = None, None
|
||||
should_compute = self._should_compute_remaining_blocks(hs_residual)
|
||||
self.shared_state.should_compute = should_compute
|
||||
|
||||
if not should_compute:
|
||||
# Apply caching
|
||||
if is_output_tuple:
|
||||
hs = self.shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index]
|
||||
else:
|
||||
hs = self.shared_state.tail_block_residuals[0] + output
|
||||
|
||||
if self._metadata.return_encoder_hidden_states_index is not None:
|
||||
assert is_output_tuple
|
||||
ehs = (
|
||||
self.shared_state.tail_block_residuals[1]
|
||||
+ output[self._metadata.return_encoder_hidden_states_index]
|
||||
)
|
||||
|
||||
if is_output_tuple:
|
||||
return_output = [None] * len(output)
|
||||
return_output[self._metadata.return_hidden_states_index] = hs
|
||||
return_output[self._metadata.return_encoder_hidden_states_index] = ehs
|
||||
return_output = tuple(return_output)
|
||||
else:
|
||||
return_output = hs
|
||||
output = return_output
|
||||
else:
|
||||
if is_output_tuple:
|
||||
head_block_output = [None] * len(output)
|
||||
head_block_output[0] = output[self._metadata.return_hidden_states_index]
|
||||
head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index]
|
||||
else:
|
||||
head_block_output = output
|
||||
self.shared_state.head_block_output = head_block_output
|
||||
self.shared_state.head_block_residual = hs_residual
|
||||
|
||||
return output
|
||||
|
||||
def reset_state(self, module):
|
||||
self.shared_state.reset()
|
||||
return module
|
||||
|
||||
@torch.compiler.disable
|
||||
def _should_compute_remaining_blocks(self, hs_residual: torch.Tensor) -> bool:
|
||||
if self.shared_state.head_block_residual is None:
|
||||
return True
|
||||
prev_hs_residual = self.shared_state.head_block_residual
|
||||
hs_absmean = (hs_residual - prev_hs_residual).abs().mean()
|
||||
prev_hs_mean = prev_hs_residual.abs().mean()
|
||||
diff = (hs_absmean / prev_hs_mean).item()
|
||||
return diff > self.threshold
|
||||
|
||||
|
||||
class FBCBlockHook(ModelHook):
|
||||
def __init__(self, shared_state: FBCSharedBlockState, is_tail: bool = False):
|
||||
super().__init__()
|
||||
self.shared_state = shared_state
|
||||
self.is_tail = is_tail
|
||||
self._metadata = None
|
||||
|
||||
def initialize_hook(self, module):
|
||||
self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
|
||||
return module
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
outputs_if_skipped = self._metadata.skip_block_output_fn(module, *args, **kwargs)
|
||||
if not isinstance(outputs_if_skipped, tuple):
|
||||
outputs_if_skipped = (outputs_if_skipped,)
|
||||
original_hs = outputs_if_skipped[self._metadata.return_hidden_states_index]
|
||||
original_ehs = None
|
||||
if self._metadata.return_encoder_hidden_states_index is not None:
|
||||
original_ehs = outputs_if_skipped[self._metadata.return_encoder_hidden_states_index]
|
||||
|
||||
if self.shared_state.should_compute:
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
if self.is_tail:
|
||||
hs_residual, ehs_residual = None, None
|
||||
if isinstance(output, tuple):
|
||||
hs_residual = (
|
||||
output[self._metadata.return_hidden_states_index] - self.shared_state.head_block_output[0]
|
||||
)
|
||||
ehs_residual = (
|
||||
output[self._metadata.return_encoder_hidden_states_index]
|
||||
- self.shared_state.head_block_output[1]
|
||||
)
|
||||
else:
|
||||
hs_residual = output - self.shared_state.head_block_output
|
||||
self.shared_state.tail_block_residuals = (hs_residual, ehs_residual)
|
||||
return output
|
||||
|
||||
output_count = len(outputs_if_skipped)
|
||||
if output_count == 1:
|
||||
return_output = original_hs
|
||||
else:
|
||||
return_output = [None] * output_count
|
||||
return_output[self._metadata.return_hidden_states_index] = original_hs
|
||||
return_output[self._metadata.return_encoder_hidden_states_index] = original_ehs
|
||||
return return_output
|
||||
|
||||
|
||||
def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
|
||||
shared_state = FBCSharedBlockState()
|
||||
remaining_blocks = []
|
||||
|
||||
for name, submodule in module.named_children():
|
||||
if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
|
||||
continue
|
||||
for index, block in enumerate(submodule):
|
||||
remaining_blocks.append((f"{name}.{index}", block))
|
||||
|
||||
head_block_name, head_block = remaining_blocks.pop(0)
|
||||
tail_block_name, tail_block = remaining_blocks.pop(-1)
|
||||
|
||||
logger.debug(f"Apply FBCHeadBlockHook to '{head_block_name}'")
|
||||
apply_fbc_head_block_hook(head_block, shared_state, config.threshold)
|
||||
|
||||
for name, block in remaining_blocks:
|
||||
logger.debug(f"Apply FBCBlockHook to '{name}'")
|
||||
apply_fbc_block_hook(block, shared_state)
|
||||
|
||||
logger.debug(f"Apply FBCBlockHook to tail block '{tail_block_name}'")
|
||||
apply_fbc_block_hook(tail_block, shared_state, is_tail=True)
|
||||
|
||||
|
||||
def apply_fbc_head_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, threshold: float) -> None:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(block)
|
||||
hook = FBCHeadBlockHook(state, threshold)
|
||||
registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK)
|
||||
|
||||
|
||||
def apply_fbc_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, is_tail: bool = False) -> None:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(block)
|
||||
hook = FBCBlockHook(state, is_tail)
|
||||
registry.register_hook(hook, _FBC_BLOCK_HOOK)
|
||||
@@ -18,11 +18,76 @@ from typing import Any, Dict, Optional, Tuple
|
||||
import torch
|
||||
|
||||
from ..utils.logging import get_logger
|
||||
from ..utils.torch_utils import unwrap_module
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class BaseState:
|
||||
def reset(self, *args, **kwargs) -> None:
|
||||
raise NotImplementedError(
|
||||
"BaseState::reset is not implemented. Please implement this method in the derived class."
|
||||
)
|
||||
|
||||
|
||||
class BaseMarkedState(BaseState):
|
||||
def __init__(self, init_args=None, init_kwargs=None):
|
||||
super().__init__()
|
||||
|
||||
self._init_args = init_args if init_args is not None else ()
|
||||
self._init_kwargs = init_kwargs if init_kwargs is not None else {}
|
||||
self._mark_name = None
|
||||
self._state_cache = {}
|
||||
|
||||
def get_current_state(self) -> "BaseMarkedState":
|
||||
if self._mark_name is None:
|
||||
# If no mark name is set, simply return a dummy object since we're not going to be using it
|
||||
return self
|
||||
if self._mark_name not in self._state_cache.keys():
|
||||
self._state_cache[self._mark_name] = self.__class__(*self._init_args, **self._init_kwargs)
|
||||
return self._state_cache[self._mark_name]
|
||||
|
||||
def mark_state(self, name: str) -> None:
|
||||
self._mark_name = name
|
||||
|
||||
def reset(self, *args, **kwargs) -> None:
|
||||
for name, state in list(self._state_cache.items()):
|
||||
state.reset(*args, **kwargs)
|
||||
self._state_cache.pop(name)
|
||||
self._mark_name = None
|
||||
|
||||
def __getattribute__(self, name):
|
||||
if name in (
|
||||
"get_current_state",
|
||||
"mark_state",
|
||||
"reset",
|
||||
"_init_args",
|
||||
"_init_kwargs",
|
||||
"_mark_name",
|
||||
"_state_cache",
|
||||
) or _is_dunder_method(name):
|
||||
return object.__getattribute__(self, name)
|
||||
else:
|
||||
current_state = BaseMarkedState.get_current_state(self)
|
||||
return object.__getattribute__(current_state, name)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name in (
|
||||
"get_current_state",
|
||||
"mark_state",
|
||||
"reset",
|
||||
"_init_args",
|
||||
"_init_kwargs",
|
||||
"_mark_name",
|
||||
"_state_cache",
|
||||
) or _is_dunder_method(name):
|
||||
object.__setattr__(self, name, value)
|
||||
else:
|
||||
current_state = BaseMarkedState.get_current_state(self)
|
||||
object.__setattr__(current_state, name, value)
|
||||
|
||||
|
||||
class ModelHook:
|
||||
r"""
|
||||
A hook that contains callbacks to be executed just before and after the forward method of a model.
|
||||
@@ -99,6 +164,14 @@ class ModelHook:
|
||||
raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
|
||||
return module
|
||||
|
||||
def _mark_state(self, module: torch.nn.Module, name: str) -> None:
|
||||
# Iterate over all attributes of the hook to see if any of them have the type `BaseMarkedState`. If so, call `mark_state` on them.
|
||||
for attr_name in dir(self):
|
||||
attr = getattr(self, attr_name)
|
||||
if isinstance(attr, BaseMarkedState):
|
||||
attr.mark_state(name)
|
||||
return module
|
||||
|
||||
|
||||
class HookFunctionReference:
|
||||
def __init__(self) -> None:
|
||||
@@ -211,9 +284,10 @@ class HookRegistry:
|
||||
hook.reset_state(self._module_ref)
|
||||
|
||||
if recurse:
|
||||
for module_name, module in self._module_ref.named_modules():
|
||||
for module_name, module in unwrap_module(self._module_ref).named_modules():
|
||||
if module_name == "":
|
||||
continue
|
||||
module = unwrap_module(module)
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
module._diffusers_hook.reset_stateful_hooks(recurse=False)
|
||||
|
||||
@@ -223,6 +297,19 @@ class HookRegistry:
|
||||
module._diffusers_hook = cls(module)
|
||||
return module._diffusers_hook
|
||||
|
||||
def _mark_state(self, name: str) -> None:
|
||||
for hook_name in reversed(self._hook_order):
|
||||
hook = self.hooks[hook_name]
|
||||
if hook._is_stateful:
|
||||
hook._mark_state(self._module_ref, name)
|
||||
|
||||
for module_name, module in unwrap_module(self._module_ref).named_modules():
|
||||
if module_name == "":
|
||||
continue
|
||||
module = unwrap_module(module)
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
module._diffusers_hook._mark_state(name)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
registry_repr = ""
|
||||
for i, hook_name in enumerate(self._hook_order):
|
||||
@@ -234,3 +321,7 @@ class HookRegistry:
|
||||
if i < len(self._hook_order) - 1:
|
||||
registry_repr += "\n"
|
||||
return f"HookRegistry(\n{registry_repr}\n)"
|
||||
|
||||
|
||||
def _is_dunder_method(name: str) -> bool:
|
||||
return name.startswith("__") and name.endswith("__") and name in dir(object)
|
||||
|
||||
182
src/diffusers/hooks/layer_skip.py
Normal file
182
src/diffusers/hooks/layer_skip.py
Normal file
@@ -0,0 +1,182 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import get_logger
|
||||
from ..utils.torch_utils import unwrap_module
|
||||
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES
|
||||
from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
_LAYER_SKIP_HOOK = "layer_skip_hook"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayerSkipConfig:
|
||||
r"""
|
||||
Configuration for skipping internal transformer blocks when executing a transformer model.
|
||||
|
||||
Args:
|
||||
indices (`List[int]`):
|
||||
The indices of the layer to skip. This is typically the first layer in the transformer block.
|
||||
fqn (`str`, defaults to `"auto"`):
|
||||
The fully qualified name identifying the stack of transformer blocks. Typically, this is
|
||||
`transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
|
||||
"""
|
||||
|
||||
indices: List[int]
|
||||
fqn: str = "auto"
|
||||
skip_attention: bool = True
|
||||
skip_attention_scores: bool = False
|
||||
skip_ff: bool = True
|
||||
|
||||
|
||||
class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def __torch_function__(self, func, types, args=(), kwargs=None):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
if func is torch.nn.functional.scaled_dot_product_attention:
|
||||
value = kwargs.get("value", None)
|
||||
if value is None:
|
||||
value = args[2]
|
||||
return value
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
class AttentionProcessorSkipHook(ModelHook):
|
||||
def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False):
|
||||
self.skip_processor_output_fn = skip_processor_output_fn
|
||||
self.skip_attention_scores = skip_attention_scores
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
if self.skip_attention_scores:
|
||||
with AttentionScoreSkipFunctionMode():
|
||||
return self.fn_ref.original_forward(*args, **kwargs)
|
||||
else:
|
||||
return self.skip_processor_output_fn(module, *args, **kwargs)
|
||||
|
||||
|
||||
class FeedForwardSkipHook(ModelHook):
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
output = kwargs.get("hidden_states", None)
|
||||
if output is None:
|
||||
output = kwargs.get("x", None)
|
||||
if output is None and len(args) > 0:
|
||||
output = args[0]
|
||||
return output
|
||||
|
||||
|
||||
class TransformerBlockSkipHook(ModelHook):
|
||||
def initialize_hook(self, module):
|
||||
self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
|
||||
return module
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
return self._metadata.skip_block_output_fn(module, *args, **kwargs)
|
||||
|
||||
|
||||
def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None:
|
||||
r"""
|
||||
Apply layer skipping to internal layers of a transformer.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The transformer model to which the layer skip hook should be applied.
|
||||
config (`LayerSkipConfig`):
|
||||
The configuration for the layer skip hook.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from diffusers import apply_layer_skip_hook, CogVideoXTransformer3DModel, LayerSkipConfig
|
||||
|
||||
>>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
||||
>>> config = LayerSkipConfig(layer_index=[10, 20], fqn="transformer_blocks")
|
||||
>>> apply_layer_skip_hook(transformer, config)
|
||||
```
|
||||
"""
|
||||
_apply_layer_skip_hook(module, config)
|
||||
|
||||
|
||||
def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None:
|
||||
name = name or _LAYER_SKIP_HOOK
|
||||
|
||||
if config.skip_attention and config.skip_attention_scores:
|
||||
raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.")
|
||||
|
||||
if config.fqn == "auto":
|
||||
for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
|
||||
if hasattr(module, identifier):
|
||||
config.fqn = identifier
|
||||
break
|
||||
else:
|
||||
raise ValueError(
|
||||
"Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
|
||||
"`fqn` (fully qualified name) that identifies a stack of transformer blocks."
|
||||
)
|
||||
|
||||
transformer_blocks = getattr(module, config.fqn, None)
|
||||
if transformer_blocks is None or not isinstance(transformer_blocks, torch.nn.ModuleList):
|
||||
raise ValueError(
|
||||
f"Could not find {config.fqn} in the provided module, or configured `fqn` (fully qualified name) does not identify "
|
||||
f"a `torch.nn.ModuleList`. Please provide a valid `fqn` that identifies a stack of transformer blocks."
|
||||
)
|
||||
if len(config.indices) == 0:
|
||||
raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.")
|
||||
|
||||
blocks_found = False
|
||||
for i, block in enumerate(transformer_blocks):
|
||||
if i not in config.indices:
|
||||
continue
|
||||
blocks_found = True
|
||||
if config.skip_attention and config.skip_ff:
|
||||
logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'")
|
||||
registry = HookRegistry.check_if_exists_or_initialize(block)
|
||||
hook = TransformerBlockSkipHook()
|
||||
registry.register_hook(hook, name)
|
||||
elif config.skip_attention or config.skip_attention_scores:
|
||||
for submodule_name, submodule in block.named_modules():
|
||||
if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention:
|
||||
logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'")
|
||||
output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn
|
||||
registry = HookRegistry.check_if_exists_or_initialize(submodule)
|
||||
hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores)
|
||||
registry.register_hook(hook, name)
|
||||
elif config.skip_ff:
|
||||
for submodule_name, submodule in block.named_modules():
|
||||
if isinstance(submodule, _FEEDFORWARD_CLASSES):
|
||||
logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'")
|
||||
registry = HookRegistry.check_if_exists_or_initialize(submodule)
|
||||
hook = FeedForwardSkipHook()
|
||||
registry.register_hook(hook, name)
|
||||
else:
|
||||
raise ValueError(
|
||||
"At least one of `skip_attention`, `skip_attention_scores`, or `skip_ff` must be set to True."
|
||||
)
|
||||
|
||||
if not blocks_found:
|
||||
raise ValueError(
|
||||
f"Could not find any transformer blocks matching the provided indices {config.indices} and "
|
||||
f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
|
||||
)
|
||||
@@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
from ..utils.logging import get_logger
|
||||
|
||||
|
||||
@@ -25,6 +27,7 @@ class CacheMixin:
|
||||
Supported caching techniques:
|
||||
- [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
|
||||
- [FasterCache](https://huggingface.co/papers/2410.19355)
|
||||
- [FirstBlockCache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching)
|
||||
"""
|
||||
|
||||
_cache_config = None
|
||||
@@ -62,8 +65,10 @@ class CacheMixin:
|
||||
|
||||
from ..hooks import (
|
||||
FasterCacheConfig,
|
||||
FirstBlockCacheConfig,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
apply_faster_cache,
|
||||
apply_first_block_cache,
|
||||
apply_pyramid_attention_broadcast,
|
||||
)
|
||||
|
||||
@@ -72,31 +77,36 @@ class CacheMixin:
|
||||
f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first."
|
||||
)
|
||||
|
||||
if isinstance(config, PyramidAttentionBroadcastConfig):
|
||||
apply_pyramid_attention_broadcast(self, config)
|
||||
elif isinstance(config, FasterCacheConfig):
|
||||
if isinstance(config, FasterCacheConfig):
|
||||
apply_faster_cache(self, config)
|
||||
elif isinstance(config, FirstBlockCacheConfig):
|
||||
apply_first_block_cache(self, config)
|
||||
elif isinstance(config, PyramidAttentionBroadcastConfig):
|
||||
apply_pyramid_attention_broadcast(self, config)
|
||||
else:
|
||||
raise ValueError(f"Cache config {type(config)} is not supported.")
|
||||
|
||||
self._cache_config = config
|
||||
|
||||
def disable_cache(self) -> None:
|
||||
from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
|
||||
from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
|
||||
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
|
||||
from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
|
||||
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
|
||||
|
||||
if self._cache_config is None:
|
||||
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
|
||||
return
|
||||
|
||||
if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self)
|
||||
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
|
||||
elif isinstance(self._cache_config, FasterCacheConfig):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self)
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self)
|
||||
if isinstance(self._cache_config, FasterCacheConfig):
|
||||
registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True)
|
||||
registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True)
|
||||
elif isinstance(self._cache_config, FirstBlockCacheConfig):
|
||||
registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True)
|
||||
registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True)
|
||||
elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
|
||||
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
|
||||
else:
|
||||
raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
|
||||
|
||||
@@ -106,3 +116,21 @@ class CacheMixin:
|
||||
from ..hooks import HookRegistry
|
||||
|
||||
HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse)
|
||||
|
||||
@contextmanager
|
||||
def _cache_context(self):
|
||||
r"""Context manager that provides additional methods for cache management."""
|
||||
cache_context = _CacheContextManager(self)
|
||||
yield cache_context
|
||||
|
||||
|
||||
class _CacheContextManager:
|
||||
def __init__(self, model: CacheMixin):
|
||||
self.model = model
|
||||
|
||||
def mark_state(self, name: str) -> None:
|
||||
from ..hooks import HookRegistry
|
||||
|
||||
if self.model.is_cache_enabled:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self.model)
|
||||
registry._mark_state(name)
|
||||
|
||||
@@ -343,25 +343,25 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
)
|
||||
block_samples = block_samples + (hidden_states,)
|
||||
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
single_block_samples = ()
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = block(
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
|
||||
single_block_samples = single_block_samples + (hidden_states,)
|
||||
|
||||
# controlnet block
|
||||
controlnet_block_samples = ()
|
||||
|
||||
@@ -460,3 +460,84 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
|
||||
|
||||
### ===== Custom attention processors for guidance methods ===== ###
|
||||
|
||||
|
||||
class CogView4PAGAttnProcessor:
|
||||
"""
|
||||
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
||||
query and key vectors, but does not include spatial normalization.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("CogView4AttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
skip_context_attention: bool = False,
|
||||
) -> torch.Tensor:
|
||||
batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
|
||||
batch_size, image_seq_length, embed_dim = hidden_states.shape
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
# 1. QKV projections
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
|
||||
# 2. QK normalization
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# 3. Rotational positional embeddings applied to latent stream
|
||||
if image_rotary_emb is not None:
|
||||
from ..embeddings import apply_rotary_emb
|
||||
|
||||
query[:, :, text_seq_length:, :] = apply_rotary_emb(
|
||||
query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
|
||||
)
|
||||
key[:, :, text_seq_length:, :] = apply_rotary_emb(
|
||||
key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
|
||||
)
|
||||
|
||||
# 4. Attention
|
||||
if skip_context_attention:
|
||||
hidden_states = value
|
||||
else:
|
||||
# PAG uses a custom attention mask for perturbed attention path:
|
||||
# - Create attention mask with `float("-inf")` for image tokens and `0.0` for text tokens
|
||||
# - Set diagonal to `0.0` for attention between image tokens
|
||||
seq_length = text_seq_length + image_seq_length
|
||||
perturbed_attention_mask = hidden_states.new_full((seq_length, seq_length), float("-inf"))
|
||||
perturbed_attention_mask[:text_seq_length, :text_seq_length] = 0.0
|
||||
perturbed_attention_mask.fill_diagonal_(0.0)
|
||||
perturbed_attention_mask = perturbed_attention_mask.unsqueeze(0).unsqueeze(0)
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=perturbed_attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
|
||||
hidden_states = hidden_states.type_as(query)
|
||||
|
||||
# 5. Output projection
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
encoder_hidden_states, hidden_states = hidden_states.split(
|
||||
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
||||
)
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
@@ -79,10 +79,14 @@ class FluxSingleTransformerBlock(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> torch.Tensor:
|
||||
text_seq_len = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
residual = hidden_states
|
||||
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
||||
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
||||
@@ -100,7 +104,8 @@ class FluxSingleTransformerBlock(nn.Module):
|
||||
if hidden_states.dtype == torch.float16:
|
||||
hidden_states = hidden_states.clip(-65504, 65504)
|
||||
|
||||
return hidden_states
|
||||
encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
@@ -508,20 +513,21 @@ class FluxTransformer2DModel(
|
||||
)
|
||||
else:
|
||||
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = block(
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
@@ -531,12 +537,7 @@ class FluxTransformer2DModel(
|
||||
if controlnet_single_block_samples is not None:
|
||||
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
|
||||
interval_control = int(np.ceil(interval_control))
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
||||
+ controlnet_single_block_samples[index_block // interval_control]
|
||||
)
|
||||
|
||||
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
||||
hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
@@ -21,6 +21,7 @@ import torch
|
||||
from transformers import AutoTokenizer, GlmModel
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...guiders import ClassifierFreeGuidance, GuidanceMixin, _raise_guidance_deprecation_warning
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import CogView4LoraLoaderMixin
|
||||
from ...models import AutoencoderKL, CogView4Transformer2DModel
|
||||
@@ -426,6 +427,7 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 1024,
|
||||
guidance: Optional[GuidanceMixin] = None,
|
||||
) -> Union[CogView4PipelineOutput, Tuple]:
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -514,6 +516,10 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
|
||||
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
_raise_guidance_deprecation_warning(guidance_scale=guidance_scale)
|
||||
if guidance is None:
|
||||
guidance = ClassifierFreeGuidance(guidance_scale=guidance_scale)
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
@@ -608,46 +614,47 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
|
||||
transformer_dtype = self.transformer.dtype
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
conds = [prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left]
|
||||
prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left = [[v] for v in conds]
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
|
||||
for i, t in enumerate(timesteps):
|
||||
self._current_timestep = t
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = latents.to(transformer_dtype)
|
||||
guidance.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
|
||||
guidance.prepare_models(self.transformer)
|
||||
latents, prompt_embeds, original_size, target_size, crops_coords_top_left = guidance.prepare_inputs(
|
||||
latents,
|
||||
(prompt_embeds[0], negative_prompt_embeds[0]),
|
||||
original_size[0],
|
||||
target_size[0],
|
||||
crops_coords_top_left[0],
|
||||
)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0])
|
||||
|
||||
noise_pred_cond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
original_size=original_size,
|
||||
target_size=target_size,
|
||||
crop_coords=crops_coords_top_left,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
for batch_index, (latent, condition, original_size_c, target_size_c, crop_coord_c) in enumerate(
|
||||
zip(latents, prompt_embeds, original_size, target_size, crops_coords_top_left)
|
||||
):
|
||||
cc.mark_state(f"batch_{batch_index}")
|
||||
latent = latent.to(transformer_dtype)
|
||||
timestep = t.expand(latent.shape[0])
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent,
|
||||
encoder_hidden_states=condition,
|
||||
timestep=timestep,
|
||||
original_size=original_size,
|
||||
target_size=target_size,
|
||||
crop_coords=crops_coords_top_left,
|
||||
original_size=original_size_c,
|
||||
target_size=target_size_c,
|
||||
crop_coords=crop_coord_c,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
guidance.prepare_outputs(noise_pred)
|
||||
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
||||
else:
|
||||
noise_pred = noise_pred_cond
|
||||
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
outputs = guidance.outputs
|
||||
noise_pred = guidance(**outputs)
|
||||
latents = self.scheduler.step(noise_pred, t, latents[0], return_dict=False)[0]
|
||||
guidance.cleanup_models(self.transformer)
|
||||
|
||||
# call the callback, if provided
|
||||
if callback_on_step_end is not None:
|
||||
@@ -656,8 +663,10 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs)
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
prompt_embeds = [callback_outputs.pop("prompt_embeds", prompt_embeds[0])]
|
||||
negative_prompt_embeds = [
|
||||
callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds[0])
|
||||
]
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
@@ -906,7 +906,7 @@ class FluxPipeline(
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
@@ -917,6 +917,7 @@ class FluxPipeline(
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
cc.mark_state("cond")
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latents,
|
||||
timestep=timestep / 1000,
|
||||
@@ -932,6 +933,8 @@ class FluxPipeline(
|
||||
if do_true_cfg:
|
||||
if negative_image_embeds is not None:
|
||||
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
|
||||
|
||||
cc.mark_state("uncond")
|
||||
neg_noise_pred = self.transformer(
|
||||
hidden_states=latents,
|
||||
timestep=timestep / 1000,
|
||||
|
||||
@@ -683,7 +683,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
@@ -693,6 +693,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
cc.mark_state("cond")
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
@@ -705,6 +706,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
|
||||
)[0]
|
||||
|
||||
if do_true_cfg:
|
||||
cc.mark_state("uncond")
|
||||
neg_noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
|
||||
@@ -706,7 +706,7 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
|
||||
)
|
||||
|
||||
# 7. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
@@ -719,6 +719,7 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
cc.mark_state("cond_uncond")
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
|
||||
@@ -1072,7 +1072,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
@@ -1105,6 +1105,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
if is_conditioning_image_or_video:
|
||||
timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
|
||||
|
||||
cc.mark_state("cond_uncond")
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
|
||||
@@ -778,7 +778,7 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
|
||||
)
|
||||
|
||||
# 7. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
@@ -792,6 +792,7 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
|
||||
|
||||
cc.mark_state("cond_uncond")
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
|
||||
@@ -519,7 +519,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
@@ -528,6 +528,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
latent_model_input = latents.to(transformer_dtype)
|
||||
timestep = t.expand(latents.shape[0])
|
||||
|
||||
cc.mark_state("cond")
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
@@ -537,6 +538,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
cc.mark_state("uncond")
|
||||
noise_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
|
||||
@@ -2,6 +2,81 @@
|
||||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
class AdaptiveProjectedGuidance(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class ClassifierFreeGuidance(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class ClassifierFreeZeroStarGuidance(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class PerturbedAttentionGuidance(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class SkipLayerGuidance(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class FasterCacheConfig(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -17,6 +92,21 @@ class FasterCacheConfig(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class FirstBlockCacheConfig(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class HookRegistry(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -32,6 +122,21 @@ class HookRegistry(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class LayerSkipConfig(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class PyramidAttentionBroadcastConfig(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -51,6 +156,14 @@ def apply_faster_cache(*args, **kwargs):
|
||||
requires_backends(apply_faster_cache, ["torch"])
|
||||
|
||||
|
||||
def apply_first_block_cache(*args, **kwargs):
|
||||
requires_backends(apply_first_block_cache, ["torch"])
|
||||
|
||||
|
||||
def apply_layer_skip(*args, **kwargs):
|
||||
requires_backends(apply_layer_skip, ["torch"])
|
||||
|
||||
|
||||
def apply_pyramid_attention_broadcast(*args, **kwargs):
|
||||
requires_backends(apply_pyramid_attention_broadcast, ["torch"])
|
||||
|
||||
|
||||
@@ -90,6 +90,11 @@ def is_compiled_module(module) -> bool:
|
||||
return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
|
||||
|
||||
|
||||
def unwrap_module(module):
|
||||
"""Unwraps a module if it was compiled with torch.compile()"""
|
||||
return module._orig_mod if is_compiled_module(module) else module
|
||||
|
||||
|
||||
def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
|
||||
"""Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ from diffusers.utils.testing_utils import (
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import (
|
||||
FasterCacheTesterMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
check_qkv_fusion_matches_attn_procs_length,
|
||||
@@ -44,7 +45,11 @@ enable_full_determinism()
|
||||
|
||||
|
||||
class CogVideoXPipelineFastTests(
|
||||
PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
|
||||
PipelineTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
FasterCacheTesterMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
pipeline_class = CogVideoXPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
|
||||
@@ -25,6 +25,7 @@ from diffusers.utils.testing_utils import (
|
||||
|
||||
from ..test_pipelines_common import (
|
||||
FasterCacheTesterMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
FluxIPAdapterTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
@@ -34,11 +35,12 @@ from ..test_pipelines_common import (
|
||||
|
||||
|
||||
class FluxPipelineFastTests(
|
||||
unittest.TestCase,
|
||||
PipelineTesterMixin,
|
||||
FluxIPAdapterTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
FasterCacheTesterMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
pipeline_class = FluxPipeline
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
|
||||
|
||||
@@ -33,6 +33,7 @@ from diffusers.utils.testing_utils import (
|
||||
|
||||
from ..test_pipelines_common import (
|
||||
FasterCacheTesterMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
to_np,
|
||||
@@ -43,7 +44,11 @@ enable_full_determinism()
|
||||
|
||||
|
||||
class HunyuanVideoPipelineFastTests(
|
||||
PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
|
||||
PipelineTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
FasterCacheTesterMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
pipeline_class = HunyuanVideoPipeline
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
|
||||
|
||||
@@ -23,13 +23,13 @@ from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LT
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin, to_np
|
||||
from ..test_pipelines_common import FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
class LTXPipelineFastTests(PipelineTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase):
|
||||
pipeline_class = LTXPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
@@ -49,7 +49,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
def get_dummy_components(self, num_layers: int = 1):
|
||||
torch.manual_seed(0)
|
||||
transformer = LTXVideoTransformer3DModel(
|
||||
in_channels=8,
|
||||
@@ -59,7 +59,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
num_attention_heads=4,
|
||||
attention_head_dim=8,
|
||||
cross_attention_dim=32,
|
||||
num_layers=1,
|
||||
num_layers=num_layers,
|
||||
caption_channels=32,
|
||||
)
|
||||
|
||||
|
||||
@@ -33,13 +33,15 @@ from diffusers.utils.testing_utils import (
|
||||
)
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import FasterCacheTesterMixin, PipelineTesterMixin, to_np
|
||||
from ..test_pipelines_common import FasterCacheTesterMixin, FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unittest.TestCase):
|
||||
class MochiPipelineFastTests(
|
||||
PipelineTesterMixin, FasterCacheTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase
|
||||
):
|
||||
pipeline_class = MochiPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
|
||||
@@ -33,6 +33,7 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.hooks import apply_group_offloading
|
||||
from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook
|
||||
from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
|
||||
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
|
||||
@@ -2631,7 +2632,7 @@ class FasterCacheTesterMixin:
|
||||
self.faster_cache_config.current_timestep_callback = lambda: pipe.current_timestep
|
||||
pipe = create_pipe()
|
||||
pipe.transformer.enable_cache(self.faster_cache_config)
|
||||
output = run_forward(pipe).flatten().flatten()
|
||||
output = run_forward(pipe).flatten()
|
||||
image_slice_faster_cache_enabled = np.concatenate((output[:8], output[-8:]))
|
||||
|
||||
# Run inference with FasterCache disabled
|
||||
@@ -2738,6 +2739,55 @@ class FasterCacheTesterMixin:
|
||||
self.assertTrue(state.cache is None, "Cache should be reset to None.")
|
||||
|
||||
|
||||
# TODO(aryan, dhruv): the cache tester mixins should probably be rewritten so that more models can be tested out
|
||||
# of the box once there is better cache support/implementation
|
||||
class FirstBlockCacheTesterMixin:
|
||||
# threshold is intentionally set higher than usual values since we're testing with random unconverged models
|
||||
# that will not satisfy the expected properties of the denoiser for caching to be effective
|
||||
first_block_cache_config = FirstBlockCacheConfig(threshold=0.8)
|
||||
|
||||
def test_first_block_cache_inference(self, expected_atol: float = 0.1):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
|
||||
def create_pipe():
|
||||
torch.manual_seed(0)
|
||||
num_layers = 2
|
||||
components = self.get_dummy_components(num_layers=num_layers)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
return pipe
|
||||
|
||||
def run_forward(pipe):
|
||||
torch.manual_seed(0)
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["num_inference_steps"] = 4
|
||||
return pipe(**inputs)[0]
|
||||
|
||||
# Run inference without FirstBlockCache
|
||||
pipe = create_pipe()
|
||||
output = run_forward(pipe).flatten()
|
||||
original_image_slice = np.concatenate((output[:8], output[-8:]))
|
||||
|
||||
# Run inference with FirstBlockCache enabled
|
||||
pipe = create_pipe()
|
||||
pipe.transformer.enable_cache(self.first_block_cache_config)
|
||||
output = run_forward(pipe).flatten()
|
||||
image_slice_fbc_enabled = np.concatenate((output[:8], output[-8:]))
|
||||
|
||||
# Run inference with FirstBlockCache disabled
|
||||
pipe.transformer.disable_cache()
|
||||
output = run_forward(pipe).flatten()
|
||||
image_slice_fbc_disabled = np.concatenate((output[:8], output[-8:]))
|
||||
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_fbc_enabled, atol=expected_atol
|
||||
), "FirstBlockCache outputs should not differ much."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_fbc_disabled, atol=1e-4
|
||||
), "Outputs from normal inference and after disabling cache should not differ."
|
||||
|
||||
|
||||
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
|
||||
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
|
||||
# reference image.
|
||||
|
||||
Reference in New Issue
Block a user