mirror of
https://github.com/huggingface/diffusers.git
synced 2026-02-05 02:15:13 +08:00
Compare commits
1 Commits
modular-te
...
sayakpaul-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8da128067c |
@@ -111,57 +111,3 @@ config = TaylorSeerCacheConfig(
|
||||
)
|
||||
pipe.transformer.enable_cache(config)
|
||||
```
|
||||
|
||||
## MagCache
|
||||
|
||||
[MagCache](https://github.com/Zehong-Ma/MagCache) accelerates inference by skipping transformer blocks based on the magnitude of the residual update. It observes that the magnitude of updates (Output - Input) decays predictably over the diffusion process. By accumulating an "error budget" based on pre-computed magnitude ratios, it dynamically decides when to skip computation and reuse the previous residual.
|
||||
|
||||
MagCache relies on **Magnitude Ratios** (`mag_ratios`), which describe this decay curve. These ratios are specific to the model checkpoint and scheduler.
|
||||
|
||||
### Usage
|
||||
|
||||
To use MagCache, you typically follow a two-step process: **Calibration** and **Inference**.
|
||||
|
||||
1. **Calibration**: Run inference once with `calibrate=True`. The hook will measure the residual magnitudes and print the calculated ratios to the console.
|
||||
2. **Inference**: Pass these ratios to `MagCacheConfig` to enable acceleration.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxPipeline, MagCacheConfig
|
||||
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-schnell",
|
||||
torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
|
||||
# 1. Calibration Step
|
||||
# Run full inference to measure model behavior.
|
||||
calib_config = MagCacheConfig(calibrate=True, num_inference_steps=4)
|
||||
pipe.transformer.enable_cache(calib_config)
|
||||
|
||||
# Run a prompt to trigger calibration
|
||||
pipe("A cat playing chess", num_inference_steps=4)
|
||||
# Logs will print something like: "MagCache Calibration Results: [1.0, 1.37, 0.97, 0.87]"
|
||||
|
||||
# 2. Inference Step
|
||||
# Apply the specific ratios obtained from calibration for optimized speed.
|
||||
# Note: For Flux models, you can also import defaults:
|
||||
# from diffusers.hooks.mag_cache import FLUX_MAG_RATIOS
|
||||
mag_config = MagCacheConfig(
|
||||
mag_ratios=[1.0, 1.37, 0.97, 0.87],
|
||||
num_inference_steps=4
|
||||
)
|
||||
|
||||
pipe.transformer.enable_cache(mag_config)
|
||||
|
||||
image = pipe("A cat playing chess", num_inference_steps=4).images[0]
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> `mag_ratios` represent the model's intrinsic magnitude decay curve. Ratios calibrated for a high number of steps (e.g., 50) can be reused for lower step counts (e.g., 20). The implementation uses interpolation to map the curve to the current number of inference steps.
|
||||
|
||||
> [!TIP]
|
||||
> For pipelines that run Classifier-Free Guidance sequentially (like Kandinsky 5.0), the calibration log might print two arrays: one for the Conditional pass and one for the Unconditional pass. In most cases, you should use the first array (Conditional).
|
||||
|
||||
> [!TIP]
|
||||
> For pipelines that run Classifier-Free Guidance in a **batched** manner (like SDXL or Flux), the `hidden_states` processed by the model contain both conditional and unconditional branches concatenated together. The calibration process automatically accounts for this, producing a single array of ratios that represents the joint behavior. You can use this resulting array directly without modification.
|
||||
|
||||
@@ -66,7 +66,7 @@ from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConf
|
||||
from torchao.quantization import Int4WeightOnlyConfig
|
||||
|
||||
pipeline_quant_config = PipelineQuantizationConfig(
|
||||
quant_mapping={"transformer": TorchAoConfig(Int4WeightOnlyConfig(group_size=128)))}
|
||||
quant_mapping={"transformer": TorchAoConfig(Int4WeightOnlyConfig(group_size=128))}
|
||||
)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
|
||||
@@ -168,14 +168,12 @@ else:
|
||||
"FirstBlockCacheConfig",
|
||||
"HookRegistry",
|
||||
"LayerSkipConfig",
|
||||
"MagCacheConfig",
|
||||
"PyramidAttentionBroadcastConfig",
|
||||
"SmoothedEnergyGuidanceConfig",
|
||||
"TaylorSeerCacheConfig",
|
||||
"apply_faster_cache",
|
||||
"apply_first_block_cache",
|
||||
"apply_layer_skip",
|
||||
"apply_mag_cache",
|
||||
"apply_pyramid_attention_broadcast",
|
||||
"apply_taylorseer_cache",
|
||||
]
|
||||
@@ -934,14 +932,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FirstBlockCacheConfig,
|
||||
HookRegistry,
|
||||
LayerSkipConfig,
|
||||
MagCacheConfig,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
SmoothedEnergyGuidanceConfig,
|
||||
TaylorSeerCacheConfig,
|
||||
apply_faster_cache,
|
||||
apply_first_block_cache,
|
||||
apply_layer_skip,
|
||||
apply_mag_cache,
|
||||
apply_pyramid_attention_broadcast,
|
||||
apply_taylorseer_cache,
|
||||
)
|
||||
|
||||
@@ -23,7 +23,6 @@ if is_torch_available():
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
from .layer_skip import LayerSkipConfig, apply_layer_skip
|
||||
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
|
||||
from .mag_cache import MagCacheConfig, apply_mag_cache
|
||||
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
|
||||
from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
|
||||
from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache
|
||||
|
||||
@@ -23,13 +23,7 @@ from ..models.attention_processor import Attention, MochiAttention
|
||||
_ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin)
|
||||
_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)
|
||||
|
||||
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = (
|
||||
"blocks",
|
||||
"transformer_blocks",
|
||||
"single_transformer_blocks",
|
||||
"layers",
|
||||
"visual_transformer_blocks",
|
||||
)
|
||||
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
|
||||
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
|
||||
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
|
||||
|
||||
|
||||
@@ -26,7 +26,6 @@ class AttentionProcessorMetadata:
|
||||
class TransformerBlockMetadata:
|
||||
return_hidden_states_index: int = None
|
||||
return_encoder_hidden_states_index: int = None
|
||||
hidden_states_argument_name: str = "hidden_states"
|
||||
|
||||
_cls: Type = None
|
||||
_cached_parameter_indices: Dict[str, int] = None
|
||||
@@ -170,7 +169,7 @@ def _register_attention_processors_metadata():
|
||||
|
||||
|
||||
def _register_transformer_blocks_metadata():
|
||||
from ..models.attention import BasicTransformerBlock, JointTransformerBlock
|
||||
from ..models.attention import BasicTransformerBlock
|
||||
from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
|
||||
from ..models.transformers.transformer_bria import BriaTransformerBlock
|
||||
from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
|
||||
@@ -185,7 +184,6 @@ def _register_transformer_blocks_metadata():
|
||||
HunyuanImageSingleTransformerBlock,
|
||||
HunyuanImageTransformerBlock,
|
||||
)
|
||||
from ..models.transformers.transformer_kandinsky import Kandinsky5TransformerDecoderBlock
|
||||
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
|
||||
from ..models.transformers.transformer_mochi import MochiTransformerBlock
|
||||
from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
|
||||
@@ -333,24 +331,6 @@ def _register_transformer_blocks_metadata():
|
||||
),
|
||||
)
|
||||
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=JointTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=1,
|
||||
return_encoder_hidden_states_index=0,
|
||||
),
|
||||
)
|
||||
|
||||
# Kandinsky 5.0 (Kandinsky5TransformerDecoderBlock)
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=Kandinsky5TransformerDecoderBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=None,
|
||||
hidden_states_argument_name="visual_embed",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# fmt: off
|
||||
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
|
||||
|
||||
@@ -1,468 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import get_logger
|
||||
from ..utils.torch_utils import unwrap_module
|
||||
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
|
||||
from ._helpers import TransformerBlockRegistry
|
||||
from .hooks import BaseState, HookRegistry, ModelHook, StateManager
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
_MAG_CACHE_LEADER_BLOCK_HOOK = "mag_cache_leader_block_hook"
|
||||
_MAG_CACHE_BLOCK_HOOK = "mag_cache_block_hook"
|
||||
|
||||
# Default Mag Ratios for Flux models (Dev/Schnell) are provided for convenience.
|
||||
# Users must explicitly pass these to the config if using Flux.
|
||||
# Reference: https://github.com/Zehong-Ma/MagCache
|
||||
FLUX_MAG_RATIOS = torch.tensor(
|
||||
[1.0]
|
||||
+ [
|
||||
1.21094,
|
||||
1.11719,
|
||||
1.07812,
|
||||
1.0625,
|
||||
1.03906,
|
||||
1.03125,
|
||||
1.03906,
|
||||
1.02344,
|
||||
1.03125,
|
||||
1.02344,
|
||||
0.98047,
|
||||
1.01562,
|
||||
1.00781,
|
||||
1.0,
|
||||
1.00781,
|
||||
1.0,
|
||||
1.00781,
|
||||
1.0,
|
||||
1.0,
|
||||
0.99609,
|
||||
0.99609,
|
||||
0.98047,
|
||||
0.98828,
|
||||
0.96484,
|
||||
0.95703,
|
||||
0.93359,
|
||||
0.89062,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def nearest_interp(src_array: torch.Tensor, target_length: int) -> torch.Tensor:
|
||||
"""
|
||||
Interpolate the source array to the target length using nearest neighbor interpolation.
|
||||
"""
|
||||
src_length = len(src_array)
|
||||
if target_length == 1:
|
||||
return src_array[-1:]
|
||||
|
||||
scale = (src_length - 1) / (target_length - 1)
|
||||
grid = torch.arange(target_length, device=src_array.device, dtype=torch.float32)
|
||||
mapped_indices = torch.round(grid * scale).long()
|
||||
return src_array[mapped_indices]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MagCacheConfig:
|
||||
r"""
|
||||
Configuration for [MagCache](https://github.com/Zehong-Ma/MagCache).
|
||||
|
||||
Args:
|
||||
threshold (`float`, defaults to `0.06`):
|
||||
The threshold for the accumulated error. If the accumulated error is below this threshold, the block
|
||||
computation is skipped. A higher threshold allows for more aggressive skipping (faster) but may degrade
|
||||
quality.
|
||||
max_skip_steps (`int`, defaults to `3`):
|
||||
The maximum number of consecutive steps that can be skipped (K in the paper).
|
||||
retention_ratio (`float`, defaults to `0.2`):
|
||||
The fraction of initial steps during which skipping is disabled to ensure stability. For example, if
|
||||
`num_inference_steps` is 28 and `retention_ratio` is 0.2, the first 6 steps will never be skipped.
|
||||
num_inference_steps (`int`, defaults to `28`):
|
||||
The number of inference steps used in the pipeline. This is required to interpolate `mag_ratios` correctly.
|
||||
mag_ratios (`torch.Tensor`, *optional*):
|
||||
The pre-computed magnitude ratios for the model. These are checkpoint-dependent. If not provided, you must
|
||||
set `calibrate=True` to calculate them for your specific model. For Flux models, you can use
|
||||
`diffusers.hooks.mag_cache.FLUX_MAG_RATIOS`.
|
||||
calibrate (`bool`, defaults to `False`):
|
||||
If True, enables calibration mode. In this mode, no blocks are skipped. Instead, the hook calculates the
|
||||
magnitude ratios for the current run and logs them at the end. Use this to obtain `mag_ratios` for new
|
||||
models or schedulers.
|
||||
"""
|
||||
|
||||
threshold: float = 0.06
|
||||
max_skip_steps: int = 3
|
||||
retention_ratio: float = 0.2
|
||||
num_inference_steps: int = 28
|
||||
mag_ratios: Optional[Union[torch.Tensor, List[float]]] = None
|
||||
calibrate: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# User MUST provide ratios OR enable calibration.
|
||||
if self.mag_ratios is None and not self.calibrate:
|
||||
raise ValueError(
|
||||
" `mag_ratios` must be provided for MagCache inference because these ratios are model-dependent.\n"
|
||||
"To get them for your model:\n"
|
||||
"1. Initialize `MagCacheConfig(calibrate=True, ...)`\n"
|
||||
"2. Run inference on your model once.\n"
|
||||
"3. Copy the printed ratios array and pass it to `mag_ratios` in the config.\n"
|
||||
"For Flux models, you can import `FLUX_MAG_RATIOS` from `diffusers.hooks.mag_cache`."
|
||||
)
|
||||
|
||||
if not self.calibrate and self.mag_ratios is not None:
|
||||
if not torch.is_tensor(self.mag_ratios):
|
||||
self.mag_ratios = torch.tensor(self.mag_ratios)
|
||||
|
||||
if len(self.mag_ratios) != self.num_inference_steps:
|
||||
logger.debug(
|
||||
f"Interpolating mag_ratios from length {len(self.mag_ratios)} to {self.num_inference_steps}"
|
||||
)
|
||||
self.mag_ratios = nearest_interp(self.mag_ratios, self.num_inference_steps)
|
||||
|
||||
|
||||
class MagCacheState(BaseState):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# Cache for the residual (output - input) from the *previous* timestep
|
||||
self.previous_residual: torch.Tensor = None
|
||||
|
||||
# State inputs/outputs for the current forward pass
|
||||
self.head_block_input: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
|
||||
self.should_compute: bool = True
|
||||
|
||||
# MagCache accumulators
|
||||
self.accumulated_ratio: float = 1.0
|
||||
self.accumulated_err: float = 0.0
|
||||
self.accumulated_steps: int = 0
|
||||
|
||||
# Current step counter (timestep index)
|
||||
self.step_index: int = 0
|
||||
|
||||
# Calibration storage
|
||||
self.calibration_ratios: List[float] = []
|
||||
|
||||
def reset(self):
|
||||
self.previous_residual = None
|
||||
self.should_compute = True
|
||||
self.accumulated_ratio = 1.0
|
||||
self.accumulated_err = 0.0
|
||||
self.accumulated_steps = 0
|
||||
self.step_index = 0
|
||||
self.calibration_ratios = []
|
||||
|
||||
|
||||
class MagCacheHeadHook(ModelHook):
|
||||
_is_stateful = True
|
||||
|
||||
def __init__(self, state_manager: StateManager, config: MagCacheConfig):
|
||||
self.state_manager = state_manager
|
||||
self.config = config
|
||||
self._metadata = None
|
||||
|
||||
def initialize_hook(self, module):
|
||||
unwrapped_module = unwrap_module(module)
|
||||
self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
|
||||
return module
|
||||
|
||||
@torch.compiler.disable
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
if self.state_manager._current_context is None:
|
||||
self.state_manager.set_context("inference")
|
||||
|
||||
arg_name = self._metadata.hidden_states_argument_name
|
||||
hidden_states = self._metadata._get_parameter_from_args_kwargs(arg_name, args, kwargs)
|
||||
|
||||
state: MagCacheState = self.state_manager.get_state()
|
||||
state.head_block_input = hidden_states
|
||||
|
||||
should_compute = True
|
||||
|
||||
if self.config.calibrate:
|
||||
# Never skip during calibration
|
||||
should_compute = True
|
||||
else:
|
||||
# MagCache Logic
|
||||
current_step = state.step_index
|
||||
if current_step >= len(self.config.mag_ratios):
|
||||
current_scale = 1.0
|
||||
else:
|
||||
current_scale = self.config.mag_ratios[current_step]
|
||||
|
||||
retention_step = int(self.config.retention_ratio * self.config.num_inference_steps + 0.5)
|
||||
|
||||
if current_step >= retention_step:
|
||||
state.accumulated_ratio *= current_scale
|
||||
state.accumulated_steps += 1
|
||||
state.accumulated_err += abs(1.0 - state.accumulated_ratio)
|
||||
|
||||
if (
|
||||
state.previous_residual is not None
|
||||
and state.accumulated_err <= self.config.threshold
|
||||
and state.accumulated_steps <= self.config.max_skip_steps
|
||||
):
|
||||
should_compute = False
|
||||
else:
|
||||
state.accumulated_ratio = 1.0
|
||||
state.accumulated_steps = 0
|
||||
state.accumulated_err = 0.0
|
||||
|
||||
state.should_compute = should_compute
|
||||
|
||||
if not should_compute:
|
||||
logger.debug(f"MagCache: Skipping step {state.step_index}")
|
||||
# Apply MagCache: Output = Input + Previous Residual
|
||||
|
||||
output = hidden_states
|
||||
res = state.previous_residual
|
||||
|
||||
if res.device != output.device:
|
||||
res = res.to(output.device)
|
||||
|
||||
# Attempt to apply residual handling shape mismatches (e.g., text+image vs image only)
|
||||
if res.shape == output.shape:
|
||||
output = output + res
|
||||
elif (
|
||||
output.ndim == 3
|
||||
and res.ndim == 3
|
||||
and output.shape[0] == res.shape[0]
|
||||
and output.shape[2] == res.shape[2]
|
||||
):
|
||||
# Assuming concatenation where image part is at the end (standard in Flux/SD3)
|
||||
diff = output.shape[1] - res.shape[1]
|
||||
if diff > 0:
|
||||
output = output.clone()
|
||||
output[:, diff:, :] = output[:, diff:, :] + res
|
||||
else:
|
||||
logger.warning(
|
||||
f"MagCache: Dimension mismatch. Input {output.shape}, Residual {res.shape}. "
|
||||
"Cannot apply residual safely. Returning input without residual."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"MagCache: Dimension mismatch. Input {output.shape}, Residual {res.shape}. "
|
||||
"Cannot apply residual safely. Returning input without residual."
|
||||
)
|
||||
|
||||
if self._metadata.return_encoder_hidden_states_index is not None:
|
||||
original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
|
||||
"encoder_hidden_states", args, kwargs
|
||||
)
|
||||
max_idx = max(
|
||||
self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index
|
||||
)
|
||||
ret_list = [None] * (max_idx + 1)
|
||||
ret_list[self._metadata.return_hidden_states_index] = output
|
||||
ret_list[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states
|
||||
return tuple(ret_list)
|
||||
else:
|
||||
return output
|
||||
|
||||
else:
|
||||
# Compute original forward
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
return output
|
||||
|
||||
def reset_state(self, module):
|
||||
self.state_manager.reset()
|
||||
return module
|
||||
|
||||
|
||||
class MagCacheBlockHook(ModelHook):
|
||||
def __init__(self, state_manager: StateManager, is_tail: bool = False, config: MagCacheConfig = None):
|
||||
super().__init__()
|
||||
self.state_manager = state_manager
|
||||
self.is_tail = is_tail
|
||||
self.config = config
|
||||
self._metadata = None
|
||||
|
||||
def initialize_hook(self, module):
|
||||
unwrapped_module = unwrap_module(module)
|
||||
self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
|
||||
return module
|
||||
|
||||
@torch.compiler.disable
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
if self.state_manager._current_context is None:
|
||||
self.state_manager.set_context("inference")
|
||||
state: MagCacheState = self.state_manager.get_state()
|
||||
|
||||
if not state.should_compute:
|
||||
arg_name = self._metadata.hidden_states_argument_name
|
||||
hidden_states = self._metadata._get_parameter_from_args_kwargs(arg_name, args, kwargs)
|
||||
|
||||
if self.is_tail:
|
||||
# Still need to advance step index even if we skip
|
||||
self._advance_step(state)
|
||||
|
||||
if self._metadata.return_encoder_hidden_states_index is not None:
|
||||
encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
|
||||
"encoder_hidden_states", args, kwargs
|
||||
)
|
||||
max_idx = max(
|
||||
self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index
|
||||
)
|
||||
ret_list = [None] * (max_idx + 1)
|
||||
ret_list[self._metadata.return_hidden_states_index] = hidden_states
|
||||
ret_list[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states
|
||||
return tuple(ret_list)
|
||||
|
||||
return hidden_states
|
||||
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
|
||||
if self.is_tail:
|
||||
# Calculate residual for next steps
|
||||
if isinstance(output, tuple):
|
||||
out_hidden = output[self._metadata.return_hidden_states_index]
|
||||
else:
|
||||
out_hidden = output
|
||||
|
||||
in_hidden = state.head_block_input
|
||||
|
||||
if in_hidden is None:
|
||||
return output
|
||||
|
||||
# Determine residual
|
||||
if out_hidden.shape == in_hidden.shape:
|
||||
residual = out_hidden - in_hidden
|
||||
elif out_hidden.ndim == 3 and in_hidden.ndim == 3 and out_hidden.shape[2] == in_hidden.shape[2]:
|
||||
diff = in_hidden.shape[1] - out_hidden.shape[1]
|
||||
if diff == 0:
|
||||
residual = out_hidden - in_hidden
|
||||
else:
|
||||
residual = out_hidden - in_hidden # Fallback to matching tail
|
||||
else:
|
||||
# Fallback for completely mismatched shapes
|
||||
residual = out_hidden
|
||||
|
||||
if self.config.calibrate:
|
||||
self._perform_calibration_step(state, residual)
|
||||
|
||||
state.previous_residual = residual
|
||||
self._advance_step(state)
|
||||
|
||||
return output
|
||||
|
||||
def _perform_calibration_step(self, state: MagCacheState, current_residual: torch.Tensor):
|
||||
if state.previous_residual is None:
|
||||
# First step has no previous residual to compare against.
|
||||
# log 1.0 as a neutral starting point.
|
||||
ratio = 1.0
|
||||
else:
|
||||
# MagCache Calibration Formula: mean(norm(curr) / norm(prev))
|
||||
# norm(dim=-1) gives magnitude of each token vector
|
||||
curr_norm = torch.linalg.norm(current_residual.float(), dim=-1)
|
||||
prev_norm = torch.linalg.norm(state.previous_residual.float(), dim=-1)
|
||||
|
||||
# Avoid division by zero
|
||||
ratio = (curr_norm / (prev_norm + 1e-8)).mean().item()
|
||||
|
||||
state.calibration_ratios.append(ratio)
|
||||
|
||||
def _advance_step(self, state: MagCacheState):
|
||||
state.step_index += 1
|
||||
if state.step_index >= self.config.num_inference_steps:
|
||||
# End of inference loop
|
||||
if self.config.calibrate:
|
||||
print("\n[MagCache] Calibration Complete. Copy these values to MagCacheConfig(mag_ratios=...):")
|
||||
print(f"{state.calibration_ratios}\n")
|
||||
logger.info(f"MagCache Calibration Results: {state.calibration_ratios}")
|
||||
|
||||
# Reset state
|
||||
state.step_index = 0
|
||||
state.accumulated_ratio = 1.0
|
||||
state.accumulated_steps = 0
|
||||
state.accumulated_err = 0.0
|
||||
state.previous_residual = None
|
||||
state.calibration_ratios = []
|
||||
|
||||
|
||||
def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None:
|
||||
"""
|
||||
Applies MagCache to a given module (typically a Transformer).
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module to apply MagCache to.
|
||||
config (`MagCacheConfig`):
|
||||
The configuration for MagCache.
|
||||
"""
|
||||
# Initialize registry on the root module so the Pipeline can set context.
|
||||
HookRegistry.check_if_exists_or_initialize(module)
|
||||
|
||||
state_manager = StateManager(MagCacheState, (), {})
|
||||
remaining_blocks = []
|
||||
|
||||
for name, submodule in module.named_children():
|
||||
if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
|
||||
continue
|
||||
for index, block in enumerate(submodule):
|
||||
remaining_blocks.append((f"{name}.{index}", block))
|
||||
|
||||
if not remaining_blocks:
|
||||
logger.warning("MagCache: No transformer blocks found to apply hooks.")
|
||||
return
|
||||
|
||||
# Handle single-block models
|
||||
if len(remaining_blocks) == 1:
|
||||
name, block = remaining_blocks[0]
|
||||
logger.info(f"MagCache: Applying Head+Tail Hooks to single block '{name}'")
|
||||
_apply_mag_cache_block_hook(block, state_manager, config, is_tail=True)
|
||||
_apply_mag_cache_head_hook(block, state_manager, config)
|
||||
return
|
||||
|
||||
head_block_name, head_block = remaining_blocks.pop(0)
|
||||
tail_block_name, tail_block = remaining_blocks.pop(-1)
|
||||
|
||||
logger.info(f"MagCache: Applying Head Hook to {head_block_name}")
|
||||
_apply_mag_cache_head_hook(head_block, state_manager, config)
|
||||
|
||||
for name, block in remaining_blocks:
|
||||
_apply_mag_cache_block_hook(block, state_manager, config)
|
||||
|
||||
logger.info(f"MagCache: Applying Tail Hook to {tail_block_name}")
|
||||
_apply_mag_cache_block_hook(tail_block, state_manager, config, is_tail=True)
|
||||
|
||||
|
||||
def _apply_mag_cache_head_hook(block: torch.nn.Module, state_manager: StateManager, config: MagCacheConfig) -> None:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(block)
|
||||
|
||||
# Automatically remove existing hook to allow re-application (e.g. switching modes)
|
||||
if registry.get_hook(_MAG_CACHE_LEADER_BLOCK_HOOK) is not None:
|
||||
registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK)
|
||||
|
||||
hook = MagCacheHeadHook(state_manager, config)
|
||||
registry.register_hook(hook, _MAG_CACHE_LEADER_BLOCK_HOOK)
|
||||
|
||||
|
||||
def _apply_mag_cache_block_hook(
|
||||
block: torch.nn.Module,
|
||||
state_manager: StateManager,
|
||||
config: MagCacheConfig,
|
||||
is_tail: bool = False,
|
||||
) -> None:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(block)
|
||||
|
||||
# Automatically remove existing hook to allow re-application
|
||||
if registry.get_hook(_MAG_CACHE_BLOCK_HOOK) is not None:
|
||||
registry.remove_hook(_MAG_CACHE_BLOCK_HOOK)
|
||||
|
||||
hook = MagCacheBlockHook(state_manager, is_tail, config)
|
||||
registry.register_hook(hook, _MAG_CACHE_BLOCK_HOOK)
|
||||
@@ -68,12 +68,10 @@ class CacheMixin:
|
||||
from ..hooks import (
|
||||
FasterCacheConfig,
|
||||
FirstBlockCacheConfig,
|
||||
MagCacheConfig,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
TaylorSeerCacheConfig,
|
||||
apply_faster_cache,
|
||||
apply_first_block_cache,
|
||||
apply_mag_cache,
|
||||
apply_pyramid_attention_broadcast,
|
||||
apply_taylorseer_cache,
|
||||
)
|
||||
@@ -87,8 +85,6 @@ class CacheMixin:
|
||||
apply_faster_cache(self, config)
|
||||
elif isinstance(config, FirstBlockCacheConfig):
|
||||
apply_first_block_cache(self, config)
|
||||
elif isinstance(config, MagCacheConfig):
|
||||
apply_mag_cache(self, config)
|
||||
elif isinstance(config, PyramidAttentionBroadcastConfig):
|
||||
apply_pyramid_attention_broadcast(self, config)
|
||||
elif isinstance(config, TaylorSeerCacheConfig):
|
||||
@@ -103,13 +99,11 @@ class CacheMixin:
|
||||
FasterCacheConfig,
|
||||
FirstBlockCacheConfig,
|
||||
HookRegistry,
|
||||
MagCacheConfig,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
TaylorSeerCacheConfig,
|
||||
)
|
||||
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
|
||||
from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
|
||||
from ..hooks.mag_cache import _MAG_CACHE_BLOCK_HOOK, _MAG_CACHE_LEADER_BLOCK_HOOK
|
||||
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
|
||||
from ..hooks.taylorseer_cache import _TAYLORSEER_CACHE_HOOK
|
||||
|
||||
@@ -124,9 +118,6 @@ class CacheMixin:
|
||||
elif isinstance(self._cache_config, FirstBlockCacheConfig):
|
||||
registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True)
|
||||
registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True)
|
||||
elif isinstance(self._cache_config, MagCacheConfig):
|
||||
registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK, recurse=True)
|
||||
registry.remove_hook(_MAG_CACHE_BLOCK_HOOK, recurse=True)
|
||||
elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
|
||||
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
|
||||
elif isinstance(self._cache_config, TaylorSeerCacheConfig):
|
||||
|
||||
@@ -302,7 +302,7 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("prompt", required=True),
|
||||
InputParam("prompt"),
|
||||
InputParam("prompt_2"),
|
||||
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
|
||||
InputParam("joint_attention_kwargs"),
|
||||
|
||||
@@ -80,7 +80,7 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("prompt", required=True),
|
||||
InputParam("prompt"),
|
||||
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
|
||||
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(10, 20, 30), required=False),
|
||||
]
|
||||
@@ -99,7 +99,7 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
|
||||
@staticmethod
|
||||
def check_inputs(block_state):
|
||||
prompt = block_state.prompt
|
||||
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
||||
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
@staticmethod
|
||||
@@ -193,7 +193,7 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("prompt", required=True),
|
||||
InputParam("prompt"),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -210,7 +210,7 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
|
||||
@staticmethod
|
||||
def check_inputs(block_state):
|
||||
prompt = block_state.prompt
|
||||
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
||||
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -270,7 +270,7 @@ class Flux2KleinTextEncoderStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("prompt", required=True),
|
||||
InputParam("prompt"),
|
||||
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
|
||||
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(9, 18, 27), required=False),
|
||||
]
|
||||
@@ -290,7 +290,7 @@ class Flux2KleinTextEncoderStep(ModularPipelineBlocks):
|
||||
def check_inputs(block_state):
|
||||
prompt = block_state.prompt
|
||||
|
||||
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
||||
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
@staticmethod
|
||||
@@ -405,7 +405,7 @@ class Flux2KleinBaseTextEncoderStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("prompt", required=True),
|
||||
InputParam("prompt"),
|
||||
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
|
||||
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(9, 18, 27), required=False),
|
||||
]
|
||||
@@ -431,7 +431,7 @@ class Flux2KleinBaseTextEncoderStep(ModularPipelineBlocks):
|
||||
def check_inputs(block_state):
|
||||
prompt = block_state.prompt
|
||||
|
||||
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
||||
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -715,7 +715,7 @@ class QwenImageTextEncoderStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam.template("prompt", required=True),
|
||||
InputParam.template("prompt"),
|
||||
InputParam.template("negative_prompt"),
|
||||
InputParam.template("max_sequence_length", default=1024),
|
||||
]
|
||||
@@ -844,7 +844,7 @@ class QwenImageEditTextEncoderStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam.template("prompt", required=True),
|
||||
InputParam.template("prompt"),
|
||||
InputParam.template("negative_prompt"),
|
||||
InputParam(
|
||||
name="resized_image",
|
||||
|
||||
@@ -244,7 +244,7 @@ class StableDiffusionXLTextEncoderStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("prompt", required=True),
|
||||
InputParam("prompt"),
|
||||
InputParam("prompt_2"),
|
||||
InputParam("negative_prompt"),
|
||||
InputParam("negative_prompt_2"),
|
||||
|
||||
@@ -179,7 +179,7 @@ class WanTextEncoderStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("prompt", required=True),
|
||||
InputParam("prompt"),
|
||||
InputParam("negative_prompt"),
|
||||
InputParam("max_sequence_length", default=512),
|
||||
]
|
||||
|
||||
@@ -149,7 +149,7 @@ class ZImageTextEncoderStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("prompt", required=True),
|
||||
InputParam("prompt"),
|
||||
InputParam("negative_prompt"),
|
||||
InputParam("max_sequence_length", default=512),
|
||||
]
|
||||
|
||||
@@ -227,21 +227,6 @@ class LayerSkipConfig(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class MagCacheConfig(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class PyramidAttentionBroadcastConfig(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -299,10 +284,6 @@ def apply_layer_skip(*args, **kwargs):
|
||||
requires_backends(apply_layer_skip, ["torch"])
|
||||
|
||||
|
||||
def apply_mag_cache(*args, **kwargs):
|
||||
requires_backends(apply_mag_cache, ["torch"])
|
||||
|
||||
|
||||
def apply_pyramid_attention_broadcast(*args, **kwargs):
|
||||
requires_backends(apply_pyramid_attention_broadcast, ["torch"])
|
||||
|
||||
|
||||
@@ -1,244 +0,0 @@
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import MagCacheConfig, apply_mag_cache
|
||||
from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry
|
||||
from diffusers.models import ModelMixin
|
||||
from diffusers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class DummyBlock(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
|
||||
# Output is double input
|
||||
# This ensures Residual = 2*Input - Input = Input
|
||||
return hidden_states * 2.0
|
||||
|
||||
|
||||
class DummyTransformer(ModelMixin):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.transformer_blocks = torch.nn.ModuleList([DummyBlock(), DummyBlock()])
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None):
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class TupleOutputBlock(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
|
||||
# Returns a tuple
|
||||
return hidden_states * 2.0, encoder_hidden_states
|
||||
|
||||
|
||||
class TupleTransformer(ModelMixin):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.transformer_blocks = torch.nn.ModuleList([TupleOutputBlock()])
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None):
|
||||
for block in self.transformer_blocks:
|
||||
# Emulate Flux-like behavior
|
||||
output = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||
hidden_states = output[0]
|
||||
encoder_hidden_states = output[1]
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class MagCacheTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# Register standard dummy block
|
||||
TransformerBlockRegistry.register(
|
||||
DummyBlock,
|
||||
TransformerBlockMetadata(return_hidden_states_index=None, return_encoder_hidden_states_index=None),
|
||||
)
|
||||
# Register tuple block (Flux style)
|
||||
TransformerBlockRegistry.register(
|
||||
TupleOutputBlock,
|
||||
TransformerBlockMetadata(return_hidden_states_index=0, return_encoder_hidden_states_index=1),
|
||||
)
|
||||
|
||||
def _set_context(self, model, context_name):
|
||||
"""Helper to set context on all hooks in the model."""
|
||||
for module in model.modules():
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
module._diffusers_hook._set_context(context_name)
|
||||
|
||||
def _get_calibration_data(self, model):
|
||||
for module in model.modules():
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
hook = module._diffusers_hook.get_hook("mag_cache_block_hook")
|
||||
if hook:
|
||||
return hook.state_manager.get_state().calibration_ratios
|
||||
return []
|
||||
|
||||
def test_mag_cache_validation(self):
|
||||
"""Test that missing mag_ratios raises ValueError."""
|
||||
with self.assertRaises(ValueError):
|
||||
MagCacheConfig(num_inference_steps=10, calibrate=False)
|
||||
|
||||
def test_mag_cache_skipping_logic(self):
|
||||
"""
|
||||
Tests that MagCache correctly calculates residuals and skips blocks when conditions are met.
|
||||
"""
|
||||
model = DummyTransformer()
|
||||
|
||||
# Dummy ratios: [1.0, 1.0] implies 0 accumulated error if we skip
|
||||
ratios = np.array([1.0, 1.0])
|
||||
|
||||
config = MagCacheConfig(
|
||||
threshold=100.0,
|
||||
num_inference_steps=2,
|
||||
retention_ratio=0.0, # Enable immediate skipping
|
||||
max_skip_steps=5,
|
||||
mag_ratios=ratios,
|
||||
)
|
||||
|
||||
apply_mag_cache(model, config)
|
||||
self._set_context(model, "test_context")
|
||||
|
||||
# Step 0: Input 10.0 -> Output 40.0 (2 blocks * 2x each)
|
||||
# HeadInput=10. Output=40. Residual=30.
|
||||
input_t0 = torch.tensor([[[10.0]]])
|
||||
output_t0 = model(input_t0)
|
||||
self.assertTrue(torch.allclose(output_t0, torch.tensor([[[40.0]]])), "Step 0 failed")
|
||||
|
||||
# Step 1: Input 11.0.
|
||||
# If Skipped: Output = Input(11) + Residual(30) = 41.0
|
||||
# If Computed: Output = 11 * 4 = 44.0
|
||||
input_t1 = torch.tensor([[[11.0]]])
|
||||
output_t1 = model(input_t1)
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(output_t1, torch.tensor([[[41.0]]])), f"Expected Skip (41.0), got {output_t1.item()}"
|
||||
)
|
||||
|
||||
def test_mag_cache_retention(self):
|
||||
"""Test that retention_ratio prevents skipping even if error is low."""
|
||||
model = DummyTransformer()
|
||||
# Ratios that imply 0 error, so it *would* skip if retention allowed it
|
||||
ratios = np.array([1.0, 1.0])
|
||||
|
||||
config = MagCacheConfig(
|
||||
threshold=100.0,
|
||||
num_inference_steps=2,
|
||||
retention_ratio=1.0, # Force retention for ALL steps
|
||||
mag_ratios=ratios,
|
||||
)
|
||||
|
||||
apply_mag_cache(model, config)
|
||||
self._set_context(model, "test_context")
|
||||
|
||||
# Step 0
|
||||
model(torch.tensor([[[10.0]]]))
|
||||
|
||||
# Step 1: Should COMPUTE (44.0) not SKIP (41.0) because of retention
|
||||
input_t1 = torch.tensor([[[11.0]]])
|
||||
output_t1 = model(input_t1)
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(output_t1, torch.tensor([[[44.0]]])),
|
||||
f"Expected Compute (44.0) due to retention, got {output_t1.item()}",
|
||||
)
|
||||
|
||||
def test_mag_cache_tuple_outputs(self):
|
||||
"""Test compatibility with models returning (hidden, encoder_hidden) like Flux."""
|
||||
model = TupleTransformer()
|
||||
ratios = np.array([1.0, 1.0])
|
||||
|
||||
config = MagCacheConfig(threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=ratios)
|
||||
|
||||
apply_mag_cache(model, config)
|
||||
self._set_context(model, "test_context")
|
||||
|
||||
# Step 0: Compute. Input 10.0 -> Output 20.0 (1 block * 2x)
|
||||
# Residual = 10.0
|
||||
input_t0 = torch.tensor([[[10.0]]])
|
||||
enc_t0 = torch.tensor([[[1.0]]])
|
||||
out_0, _ = model(input_t0, encoder_hidden_states=enc_t0)
|
||||
self.assertTrue(torch.allclose(out_0, torch.tensor([[[20.0]]])))
|
||||
|
||||
# Step 1: Skip. Input 11.0.
|
||||
# Skipped Output = 11 + 10 = 21.0
|
||||
input_t1 = torch.tensor([[[11.0]]])
|
||||
out_1, _ = model(input_t1, encoder_hidden_states=enc_t0)
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(out_1, torch.tensor([[[21.0]]])), f"Tuple skip failed. Expected 21.0, got {out_1.item()}"
|
||||
)
|
||||
|
||||
def test_mag_cache_reset(self):
|
||||
"""Test that state resets correctly after num_inference_steps."""
|
||||
model = DummyTransformer()
|
||||
config = MagCacheConfig(
|
||||
threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=np.array([1.0, 1.0])
|
||||
)
|
||||
apply_mag_cache(model, config)
|
||||
self._set_context(model, "test_context")
|
||||
|
||||
input_t = torch.ones(1, 1, 1)
|
||||
|
||||
model(input_t) # Step 0
|
||||
model(input_t) # Step 1 (Skipped)
|
||||
|
||||
# Step 2 (Reset -> Step 0) -> Should Compute
|
||||
# Input 2.0 -> Output 8.0
|
||||
input_t2 = torch.tensor([[[2.0]]])
|
||||
output_t2 = model(input_t2)
|
||||
|
||||
self.assertTrue(torch.allclose(output_t2, torch.tensor([[[8.0]]])), "State did not reset correctly")
|
||||
|
||||
def test_mag_cache_calibration(self):
|
||||
"""Test that calibration mode records ratios."""
|
||||
model = DummyTransformer()
|
||||
config = MagCacheConfig(num_inference_steps=2, calibrate=True)
|
||||
apply_mag_cache(model, config)
|
||||
self._set_context(model, "test_context")
|
||||
|
||||
# Step 0
|
||||
# HeadInput = 10. Output = 40. Residual = 30.
|
||||
# Ratio 0 is placeholder 1.0
|
||||
model(torch.tensor([[[10.0]]]))
|
||||
|
||||
# Check intermediate state
|
||||
ratios = self._get_calibration_data(model)
|
||||
self.assertEqual(len(ratios), 1)
|
||||
self.assertEqual(ratios[0], 1.0)
|
||||
|
||||
# Step 1
|
||||
# HeadInput = 10. Output = 40. Residual = 30.
|
||||
# PrevResidual = 30. CurrResidual = 30.
|
||||
# Ratio = 30/30 = 1.0
|
||||
model(torch.tensor([[[10.0]]]))
|
||||
|
||||
# Verify it computes fully (no skip)
|
||||
# If it skipped, output would be 41.0. It should be 40.0
|
||||
# Actually in test setup, input is same (10.0) so output 40.0.
|
||||
# Let's ensure list is empty after reset (end of step 1)
|
||||
ratios_after = self._get_calibration_data(model)
|
||||
self.assertEqual(ratios_after, [])
|
||||
@@ -37,14 +37,9 @@ class TestFluxModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = FluxModularPipeline
|
||||
pipeline_blocks_class = FluxAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-modular"
|
||||
default_repo_id = "hf-internal-testing/tiny-flux-pipe"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
# should choose from the dict returned by `get_dummy_inputs`
|
||||
text_encoder_block_params = frozenset(["prompt", "max_sequence_length"])
|
||||
decode_block_params = frozenset(["output_type", "height", "width"])
|
||||
vae_encoder_block_params = None # None if vae_encoder is not supported
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
@@ -68,21 +63,10 @@ class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = FluxModularPipeline
|
||||
pipeline_blocks_class = FluxAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-modular"
|
||||
default_repo_id = "hf-internal-testing/tiny-flux-pipe"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
|
||||
batch_params = frozenset(["prompt", "image"])
|
||||
|
||||
# should choose from the dict returned by `get_dummy_inputs`
|
||||
text_encoder_block_params = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
"max_sequence_length",
|
||||
]
|
||||
)
|
||||
decode_block_params = frozenset(["output_type", "height", "width"])
|
||||
vae_encoder_block_params = frozenset(["image", "height", "width"])
|
||||
|
||||
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
|
||||
pipeline = super().get_pipeline(components_manager, torch_dtype)
|
||||
|
||||
@@ -145,13 +129,9 @@ class TestFluxKontextModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = FluxKontextModularPipeline
|
||||
pipeline_blocks_class = FluxKontextAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-kontext-pipe"
|
||||
default_repo_id = "hf-internal-testing/tiny-flux-kontext-pipe"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
|
||||
batch_params = frozenset(["prompt", "image"])
|
||||
text_encoder_block_params = frozenset(["prompt", "max_sequence_length"])
|
||||
decode_block_params = frozenset(["latents"])
|
||||
vae_encoder_block_params = frozenset(["image", "height", "width"])
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
|
||||
@@ -32,15 +32,9 @@ class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = Flux2ModularPipeline
|
||||
pipeline_blocks_class = Flux2AutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular"
|
||||
default_repo_id = "black-forest-labs/FLUX.2-dev"
|
||||
default_repo_id = "hf-internal-testing/tiny-flux2"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
# should choose from the dict returned by `get_dummy_inputs`
|
||||
text_encoder_block_params = frozenset(["prompt", "max_sequence_length", "text_encoder_out_layers"])
|
||||
decode_block_params = frozenset(["output_type"])
|
||||
vae_encoder_block_params = None
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
@@ -66,14 +60,9 @@ class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = Flux2ModularPipeline
|
||||
pipeline_blocks_class = Flux2AutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular"
|
||||
default_repo_id = "hf-internal-testing/tiny-flux2"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
|
||||
batch_params = frozenset(["prompt", "image"])
|
||||
# should choose from the dict returned by `get_dummy_inputs`
|
||||
text_encoder_block_params = frozenset(["prompt", "max_sequence_length", "text_encoder_out_layers"])
|
||||
decode_block_params = frozenset(["output_type"])
|
||||
vae_encoder_block_params = frozenset(["image", "height", "width"])
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
|
||||
@@ -32,15 +32,10 @@ class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = Flux2KleinModularPipeline
|
||||
pipeline_blocks_class = Flux2KleinAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-modular"
|
||||
default_repo_id = None # TODO
|
||||
|
||||
params = frozenset(["prompt", "height", "width"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
|
||||
text_encoder_block_params = frozenset(["prompt", "max_sequence_length", "text_encoder_out_layers"])
|
||||
decode_block_params = frozenset(["output_type"])
|
||||
vae_encoder_block_params = None
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
inputs = {
|
||||
@@ -64,15 +59,10 @@ class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = Flux2KleinModularPipeline
|
||||
pipeline_blocks_class = Flux2KleinAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-modular"
|
||||
default_repo_id = None # TODO
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "image"])
|
||||
batch_params = frozenset(["prompt", "image"])
|
||||
|
||||
text_encoder_block_params = frozenset(["prompt", "max_sequence_length", "text_encoder_out_layers"])
|
||||
decode_block_params = frozenset(["output_type"])
|
||||
vae_encoder_block_params = frozenset(["image", "height", "width"])
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
inputs = {
|
||||
|
||||
@@ -32,14 +32,10 @@ class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = Flux2KleinModularPipeline
|
||||
pipeline_blocks_class = Flux2KleinBaseAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-base-modular"
|
||||
default_repo_id = "hf-internal-testing/tiny-flux2-klein"
|
||||
|
||||
params = frozenset(["prompt", "height", "width"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
|
||||
text_encoder_block_params = frozenset(["prompt", "max_sequence_length", "text_encoder_out_layers"])
|
||||
decode_block_params = frozenset(["output_type"])
|
||||
vae_encoder_block_params = None
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
inputs = {
|
||||
@@ -63,15 +59,10 @@ class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = Flux2KleinModularPipeline
|
||||
pipeline_blocks_class = Flux2KleinBaseAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-base-modular"
|
||||
default_repo_id = "hf-internal-testing/tiny-flux2-klein"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "image"])
|
||||
batch_params = frozenset(["prompt", "image"])
|
||||
|
||||
text_encoder_block_params = frozenset(["prompt", "max_sequence_length", "text_encoder_out_layers"])
|
||||
decode_block_params = frozenset(["output_type"])
|
||||
vae_encoder_block_params = frozenset(["image", "height", "width"])
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
inputs = {
|
||||
|
||||
@@ -34,16 +34,10 @@ class TestQwenImageModularPipelineFast(ModularPipelineTesterMixin, ModularGuider
|
||||
pipeline_class = QwenImageModularPipeline
|
||||
pipeline_blocks_class = QwenImageAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-modular"
|
||||
default_repo_id = "Qwen/Qwen-Image"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
|
||||
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
|
||||
|
||||
# should choose from the dict returned by `get_dummy_inputs`
|
||||
text_encoder_block_params = frozenset(["prompt", "negative_prompt", "max_sequence_length"])
|
||||
decode_block_params = frozenset(["output_type"])
|
||||
vae_encoder_block_params = None # None if vae_encoder is not supported
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
generator = self.get_generator()
|
||||
inputs = {
|
||||
@@ -66,16 +60,10 @@ class TestQwenImageEditModularPipelineFast(ModularPipelineTesterMixin, ModularGu
|
||||
pipeline_class = QwenImageEditModularPipeline
|
||||
pipeline_blocks_class = QwenImageEditAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-edit-modular"
|
||||
default_repo_id = "Qwen/Qwen-Image-Edit"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
|
||||
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
|
||||
|
||||
# should choose from the dict returned by `get_dummy_inputs`
|
||||
text_encoder_block_params = frozenset(["image", "prompt", "negative_prompt"])
|
||||
decode_block_params = frozenset(["output_type"])
|
||||
vae_encoder_block_params = frozenset(["image", "generator"])
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
generator = self.get_generator()
|
||||
inputs = {
|
||||
@@ -98,17 +86,11 @@ class TestQwenImageEditPlusModularPipelineFast(ModularPipelineTesterMixin, Modul
|
||||
pipeline_class = QwenImageEditPlusModularPipeline
|
||||
pipeline_blocks_class = QwenImageEditPlusAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-edit-plus-modular"
|
||||
default_repo_id = "Qwen/Qwen-Image-Edit-2509"
|
||||
|
||||
# No `mask_image` yet.
|
||||
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image"])
|
||||
batch_params = frozenset(["prompt", "negative_prompt", "image"])
|
||||
|
||||
# should choose from the dict returned by `get_dummy_inputs`
|
||||
text_encoder_block_params = frozenset(["image", "prompt", "negative_prompt"])
|
||||
decode_block_params = frozenset(["output_type"])
|
||||
vae_encoder_block_params = frozenset(["image", "generator"])
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
generator = self.get_generator()
|
||||
inputs = {
|
||||
|
||||
@@ -279,8 +279,6 @@ class TestSDXLModularPipelineFast(
|
||||
pipeline_class = StableDiffusionXLModularPipeline
|
||||
pipeline_blocks_class = StableDiffusionXLAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-sdxl-modular"
|
||||
default_repo_id = "hf-internal-testing/tiny-sdxl-pipe"
|
||||
|
||||
params = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
@@ -293,11 +291,6 @@ class TestSDXLModularPipelineFast(
|
||||
batch_params = frozenset(["prompt", "negative_prompt"])
|
||||
expected_image_output_shape = (1, 3, 64, 64)
|
||||
|
||||
# should choose from the dict returned by `get_dummy_inputs`
|
||||
text_encoder_block_params = frozenset(["prompt"])
|
||||
decode_block_params = frozenset(["output_type"])
|
||||
vae_encoder_block_params = None # None if vae_encoder is not supported
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
inputs = {
|
||||
@@ -333,7 +326,6 @@ class TestSDXLImg2ImgModularPipelineFast(
    pipeline_class = StableDiffusionXLModularPipeline
    pipeline_blocks_class = StableDiffusionXLAutoBlocks
    pretrained_model_name_or_path = "hf-internal-testing/tiny-sdxl-modular"
    default_repo_id = "hf-internal-testing/tiny-sdxl-pipe"
    params = frozenset(
        [
            "prompt",
@@ -347,11 +339,6 @@ class TestSDXLImg2ImgModularPipelineFast(
    batch_params = frozenset(["prompt", "negative_prompt", "image"])
    expected_image_output_shape = (1, 3, 64, 64)

    # should choose from the dict returned by `get_dummy_inputs`
    text_encoder_block_params = frozenset(["prompt"])
    decode_block_params = frozenset(["output_type"])
    vae_encoder_block_params = frozenset(["image"])

    def get_dummy_inputs(self, seed=0):
        generator = self.get_generator(seed)
        inputs = {
@@ -392,7 +379,6 @@ class SDXLInpaintingModularPipelineFastTests(
    pipeline_class = StableDiffusionXLModularPipeline
    pipeline_blocks_class = StableDiffusionXLAutoBlocks
    pretrained_model_name_or_path = "hf-internal-testing/tiny-sdxl-modular"
    default_repo_id = "hf-internal-testing/tiny-sdxl-pipe"
    params = frozenset(
        [
            "prompt",

@@ -37,8 +37,6 @@ class ModularPipelineTesterMixin:
    optional_params = frozenset(["num_inference_steps", "num_images_per_prompt", "latents", "output_type"])
    # this is modular-specific: generator needs to be an intermediate input because it's mutable
    intermediate_params = frozenset(["generator"])
    # prompt is required for most pipelines, with exceptions like qwen-image layer
    required_params = frozenset(["prompt"])

    def get_generator(self, seed=0):
        generator = torch.Generator("cpu").manual_seed(seed)
@@ -57,12 +55,6 @@ class ModularPipelineTesterMixin:
            "You need to set the attribute `pretrained_model_name_or_path` in the child test class. See existing pipeline tests for reference."
        )

    @property
    def default_repo_id(self) -> str:
        raise NotImplementedError(
            "You need to set the attribute `default_repo_id` in the child test class. See existing pipeline tests for reference."
        )

    @property
    def pipeline_blocks_class(self) -> Union[Callable, ModularPipelineBlocks]:
        raise NotImplementedError(
@@ -105,33 +97,6 @@ class ModularPipelineTesterMixin:
            "See existing pipeline tests for reference."
        )

    @property
    def text_encoder_block_params(self) -> frozenset:
        raise NotImplementedError(
            "You need to set the attribute `text_encoder_block_params` in the child test class. "
            "`text_encoder_block_params` are the parameters required to be passed to the text encoder block. "
            "It should be a subset of the parameters returned by `get_dummy_inputs`. "
            "See existing pipeline tests for reference."
        )

    @property
    def decode_block_params(self) -> frozenset:
        raise NotImplementedError(
            "You need to set the attribute `decode_block_params` in the child test class. "
            "`decode_block_params` are the parameters required to be passed to the decode block. "
            "It should be a subset of the parameters returned by `get_dummy_inputs`. "
            "See existing pipeline tests for reference."
        )

    @property
    def vae_encoder_block_params(self) -> frozenset:
        raise NotImplementedError(
            "You need to set the attribute `vae_encoder_block_params` in the child test class. "
            "`vae_encoder_block_params` are the parameters required to be passed to the vae encoder block. "
            "It should be a subset of the parameters returned by `get_dummy_inputs`. "
            "See existing pipeline tests for reference."
        )

    def setup_method(self):
        # clean up the VRAM before each test
        torch.compiler.reset()
@@ -156,7 +121,6 @@ class ModularPipelineTesterMixin:
        pipe = self.get_pipeline()
        input_parameters = pipe.blocks.input_names
        optional_parameters = pipe.default_call_parameters
        required_parameters = pipe.blocks.required_inputs

        def _check_for_parameters(parameters, expected_parameters, param_type):
            remaining_parameters = {param for param in parameters if param not in expected_parameters}
@@ -166,101 +130,6 @@ class ModularPipelineTesterMixin:

        _check_for_parameters(self.params, input_parameters, "input")
        _check_for_parameters(self.optional_params, optional_parameters, "optional")
        _check_for_parameters(self.required_params, required_parameters, "required")

    def test_loading_from_default_repo(self):
        if self.default_repo_id is None:
            return

        try:
            pipe = ModularPipeline.from_pretrained(self.default_repo_id)
            assert pipe.blocks.__class__ == self.pipeline_blocks_class
        except Exception as e:
            assert False, f"Failed to load pipeline from default repo: {e}"

    def test_modular_inference(self):
        # run the pipeline to get the base output for comparison
        pipe = self.get_pipeline()
        pipe.to(torch_device, torch.float32)

        inputs = self.get_dummy_inputs()
        standard_output = pipe(**inputs, output="images")

        # create text, denoise, decoder (and optional vae encoder) nodes
        blocks = self.pipeline_blocks_class()

        assert "text_encoder" in blocks.sub_blocks, "`text_encoder` block is not present in the pipeline"
        assert "denoise" in blocks.sub_blocks, "`denoise` block is not present in the pipeline"
        assert "decode" in blocks.sub_blocks, "`decode` block is not present in the pipeline"
        if self.vae_encoder_block_params is not None:
            assert "vae_encoder" in blocks.sub_blocks, "`vae_encoder` block is not present in the pipeline"

        # manually set the components in the sub_pipe:
        # a hack to work around the fact that the default pipeline properties are often incorrect for testing cases,
        # e.g. vae_scale_factor is usually not 8 because the vae is configured to be smaller for testing
        def manually_set_all_components(pipe: ModularPipeline, sub_pipe: ModularPipeline):
            for n, comp in pipe.components.items():
                if not hasattr(sub_pipe, n):
                    setattr(sub_pipe, n, comp)

        # Initialize all nodes
        text_node = blocks.sub_blocks["text_encoder"].init_pipeline(self.pretrained_model_name_or_path)
        text_node.load_components(torch_dtype=torch.float32)
        text_node.to(torch_device)
        manually_set_all_components(pipe, text_node)

        denoise_node = blocks.sub_blocks["denoise"].init_pipeline(self.pretrained_model_name_or_path)
        denoise_node.load_components(torch_dtype=torch.float32)
        denoise_node.to(torch_device)
        manually_set_all_components(pipe, denoise_node)

        decoder_node = blocks.sub_blocks["decode"].init_pipeline(self.pretrained_model_name_or_path)
        decoder_node.load_components(torch_dtype=torch.float32)
        decoder_node.to(torch_device)
        manually_set_all_components(pipe, decoder_node)

        if self.vae_encoder_block_params is not None:
            vae_encoder_node = blocks.sub_blocks["vae_encoder"].init_pipeline(self.pretrained_model_name_or_path)
            vae_encoder_node.load_components(torch_dtype=torch.float32)
            vae_encoder_node.to(torch_device)
            manually_set_all_components(pipe, vae_encoder_node)
        else:
            vae_encoder_node = None

        def filter_inputs(available: dict, expected_keys) -> dict:
            return {k: v for k, v in available.items() if k in expected_keys}
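        # filter_inputs keeps only the values that each node's blocks declare as inputs, so every node
        # below receives exactly the arguments it expects and nothing else.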

        # prepare inputs for each node
        inputs = self.get_dummy_inputs()

        # 1. Text encoder: takes from inputs
        text_inputs = filter_inputs(inputs, self.text_encoder_block_params)
        text_output = text_node(**text_inputs)
        text_output_dict = text_output.get_by_kwargs("denoiser_input_fields")

        # 2. VAE encoder (optional): takes from inputs + text_output
        if vae_encoder_node is not None:
            vae_available = {**inputs, **text_output_dict}
            vae_encoder_inputs = filter_inputs(vae_available, vae_encoder_node.blocks.input_names)
            vae_encoder_output = vae_encoder_node(**vae_encoder_inputs)
            vae_output_dict = vae_encoder_output.values
        else:
            vae_output_dict = {}

        # 3. Denoise: takes from inputs + text_output + vae_output
        denoise_available = {**inputs, **text_output_dict, **vae_output_dict}
        denoise_inputs = filter_inputs(denoise_available, denoise_node.blocks.input_names)
        denoise_output = denoise_node(**denoise_inputs)
        latents = denoise_output.latents

        # 4. Decoder: takes from inputs + denoise_output
        decode_available = {**inputs, "latents": latents}
        decode_inputs = filter_inputs(decode_available, decoder_node.blocks.input_names)
        modular_output = decoder_node(**decode_inputs).images

        assert modular_output.shape == standard_output.shape, (
            f"Modular output should have same shape as standard output {standard_output.shape}, but got {modular_output.shape}"
        )
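        # NOTE: only the output shape is compared against the single-pipeline run above; pixel-level
        # equivalence between the modular and standard paths is not asserted by this test.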

    def test_inference_batch_consistent(self, batch_sizes=[2], batch_generator=True):
        pipe = self.get_pipeline().to(torch_device)

@@ -27,7 +27,6 @@ from ..test_pipelines_common import (
    FasterCacheTesterMixin,
    FirstBlockCacheTesterMixin,
    FluxIPAdapterTesterMixin,
    MagCacheTesterMixin,
    PipelineTesterMixin,
    PyramidAttentionBroadcastTesterMixin,
    TaylorSeerCacheTesterMixin,
@@ -42,7 +41,6 @@ class FluxPipelineFastTests(
    FasterCacheTesterMixin,
    FirstBlockCacheTesterMixin,
    TaylorSeerCacheTesterMixin,
    MagCacheTesterMixin,
    unittest.TestCase,
):
    pipeline_class = FluxPipeline

@@ -35,7 +35,6 @@ from diffusers import (
from diffusers.hooks import apply_group_offloading
from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook
from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
from diffusers.hooks.mag_cache import MagCacheConfig
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
from diffusers.hooks.taylorseer_cache import TaylorSeerCacheConfig
from diffusers.image_processor import VaeImageProcessor
@@ -2977,59 +2976,6 @@ class TaylorSeerCacheTesterMixin:
        )


class MagCacheTesterMixin:
    mag_cache_config = MagCacheConfig(
        threshold=0.06,
        max_skip_steps=3,
        retention_ratio=0.2,
        num_inference_steps=50,
        mag_ratios=torch.ones(50),
    )
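    # `mag_ratios=torch.ones(50)` presumably supplies a flat magnitude curve, so whether a step is skipped
    # in this test is driven by the accumulated-error `threshold` rather than by model-specific calibration.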

    def test_mag_cache_inference(self, expected_atol: float = 0.1):
        device = "cpu"

        def create_pipe():
            torch.manual_seed(0)
            num_layers = 2
            components = self.get_dummy_components(num_layers=num_layers)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(device)
            pipe.set_progress_bar_config(disable=None)
            return pipe

        def run_forward(pipe):
            torch.manual_seed(0)
            inputs = self.get_dummy_inputs(device)
            # Match the config steps
            inputs["num_inference_steps"] = 50
            return pipe(**inputs)[0]
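        # Both helpers reseed with torch.manual_seed(0), so the baseline, cache-enabled, and cache-disabled
        # runs below start from identical components and inputs and are directly comparable.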

        # 1. Run inference without MagCache (Baseline)
        pipe = create_pipe()
        output = run_forward(pipe).flatten()
        original_image_slice = np.concatenate((output[:8], output[-8:]))

        # 2. Run inference with MagCache ENABLED
        pipe = create_pipe()
        pipe.transformer.enable_cache(self.mag_cache_config)
        output = run_forward(pipe).flatten()
        image_slice_enabled = np.concatenate((output[:8], output[-8:]))

        # 3. Run inference with MagCache DISABLED
        pipe.transformer.disable_cache()
        output = run_forward(pipe).flatten()
        image_slice_disabled = np.concatenate((output[:8], output[-8:]))

        assert np.allclose(original_image_slice, image_slice_enabled, atol=expected_atol), (
            "MagCache outputs should not differ too much from baseline."
        )

        assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-4), (
            "Outputs after disabling cache should match original inference exactly."
        )


# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
# reference image.