Compare commits


1 Commit

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| Sayak Paul | 8da128067c | Fix syntax error in quantization configuration | 2026-02-04 10:10:14 +05:30 |
25 changed files with 18 additions and 1112 deletions

View File

@@ -111,57 +111,3 @@ config = TaylorSeerCacheConfig(
)
pipe.transformer.enable_cache(config)
```
## MagCache
[MagCache](https://github.com/Zehong-Ma/MagCache) accelerates inference by skipping transformer blocks based on the magnitude of the residual update. It observes that the magnitude of updates (Output - Input) decays predictably over the diffusion process. By accumulating an "error budget" based on pre-computed magnitude ratios, it dynamically decides when to skip computation and reuse the previous residual.
MagCache relies on **Magnitude Ratios** (`mag_ratios`), which describe this decay curve. These ratios are specific to the model checkpoint and scheduler.
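Conceptually, the skip decision is a running product of these per-step ratios checked against an error budget. The following standalone sketch illustrates the idea; names, defaults, and the reset policy are illustrative simplifications, and the hook implementation further down in this diff is the authoritative version.
```python
from dataclasses import dataclass

@dataclass
class BudgetState:
    ratio: float = 1.0   # running product of magnitude ratios
    err: float = 0.0     # accumulated "error budget"
    skips: int = 0       # consecutive skipped steps

def should_skip(step, mag_ratios, state, threshold=0.06, max_skip_steps=3, retention_steps=6):
    """Return True if the transformer blocks can be skipped at this step."""
    if step < retention_steps:
        # The earliest steps are always computed for stability.
        state.ratio, state.err, state.skips = 1.0, 0.0, 0
        return False
    state.ratio *= mag_ratios[step]
    state.skips += 1
    state.err += abs(1.0 - state.ratio)
    if state.err <= threshold and state.skips <= max_skip_steps:
        return True  # reuse the cached residual instead of running the blocks
    # Error budget exhausted: compute this step and reset the accumulators.
    state.ratio, state.err, state.skips = 1.0, 0.0, 0
    return False
```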
### Usage
To use MagCache, you typically follow a two-step process: **Calibration** and **Inference**.
1. **Calibration**: Run inference once with `calibrate=True`. The hook will measure the residual magnitudes and print the calculated ratios to the console.
2. **Inference**: Pass these ratios to `MagCacheConfig` to enable acceleration.
```python
import torch
from diffusers import FluxPipeline, MagCacheConfig
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=torch.bfloat16
).to("cuda")
# 1. Calibration Step
# Run full inference to measure model behavior.
calib_config = MagCacheConfig(calibrate=True, num_inference_steps=4)
pipe.transformer.enable_cache(calib_config)
# Run a prompt to trigger calibration
pipe("A cat playing chess", num_inference_steps=4)
# Logs will print something like: "MagCache Calibration Results: [1.0, 1.37, 0.97, 0.87]"
# 2. Inference Step
# Apply the specific ratios obtained from calibration for optimized speed.
# Note: For Flux models, you can also import defaults:
# from diffusers.hooks.mag_cache import FLUX_MAG_RATIOS
mag_config = MagCacheConfig(
mag_ratios=[1.0, 1.37, 0.97, 0.87],
num_inference_steps=4
)
pipe.transformer.enable_cache(mag_config)
image = pipe("A cat playing chess", num_inference_steps=4).images[0]
```
> [!NOTE]
> `mag_ratios` represent the model's intrinsic magnitude decay curve. Ratios calibrated for a high number of steps (e.g., 50) can be reused for lower step counts (e.g., 20); the implementation uses nearest-neighbor interpolation to map the curve to the current number of inference steps.
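A minimal sketch of that re-sampling step, mirroring the `nearest_interp` helper in the deleted module further down in this diff (the placeholder curve here is illustrative; substitute your calibrated values):
```python
import torch

def resample_ratios(ratios: torch.Tensor, target_length: int) -> torch.Tensor:
    # Nearest-neighbor re-sampling of a calibrated ratio curve to a new step count.
    if target_length == 1:
        return ratios[-1:]
    scale = (len(ratios) - 1) / (target_length - 1)
    idx = torch.round(torch.arange(target_length, dtype=torch.float32) * scale).long()
    return ratios[idx]

ratios_50 = torch.linspace(1.0, 0.9, 50)    # placeholder curve calibrated at 50 steps
ratios_20 = resample_ratios(ratios_50, 20)  # reuse it for a 20-step run
```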
> [!TIP]
> For pipelines that run Classifier-Free Guidance sequentially (like Kandinsky 5.0), the calibration log might print two arrays: one for the Conditional pass and one for the Unconditional pass. In most cases, you should use the first array (Conditional).
> [!TIP]
> For pipelines that run Classifier-Free Guidance in a **batched** manner (like SDXL or Flux), the `hidden_states` processed by the model contain both conditional and unconditional branches concatenated together. The calibration process automatically accounts for this, producing a single array of ratios that represents the joint behavior. You can use this resulting array directly without modification.

View File

@@ -66,7 +66,7 @@ from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig
from torchao.quantization import Int4WeightOnlyConfig
pipeline_quant_config = PipelineQuantizationConfig(
quant_mapping={"transformer": TorchAoConfig(Int4WeightOnlyConfig(group_size=128)))}
quant_mapping={"transformer": TorchAoConfig(Int4WeightOnlyConfig(group_size=128))}
)
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",

View File

@@ -168,14 +168,12 @@ else:
"FirstBlockCacheConfig",
"HookRegistry",
"LayerSkipConfig",
"MagCacheConfig",
"PyramidAttentionBroadcastConfig",
"SmoothedEnergyGuidanceConfig",
"TaylorSeerCacheConfig",
"apply_faster_cache",
"apply_first_block_cache",
"apply_layer_skip",
"apply_mag_cache",
"apply_pyramid_attention_broadcast",
"apply_taylorseer_cache",
]
@@ -934,14 +932,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FirstBlockCacheConfig,
HookRegistry,
LayerSkipConfig,
MagCacheConfig,
PyramidAttentionBroadcastConfig,
SmoothedEnergyGuidanceConfig,
TaylorSeerCacheConfig,
apply_faster_cache,
apply_first_block_cache,
apply_layer_skip,
apply_mag_cache,
apply_pyramid_attention_broadcast,
apply_taylorseer_cache,
)

View File

@@ -23,7 +23,6 @@ if is_torch_available():
from .hooks import HookRegistry, ModelHook
from .layer_skip import LayerSkipConfig, apply_layer_skip
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
from .mag_cache import MagCacheConfig, apply_mag_cache
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache

View File

@@ -23,13 +23,7 @@ from ..models.attention_processor import Attention, MochiAttention
_ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin)
_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = (
"blocks",
"transformer_blocks",
"single_transformer_blocks",
"layers",
"visual_transformer_blocks",
)
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")

View File

@@ -26,7 +26,6 @@ class AttentionProcessorMetadata:
class TransformerBlockMetadata:
return_hidden_states_index: int = None
return_encoder_hidden_states_index: int = None
hidden_states_argument_name: str = "hidden_states"
_cls: Type = None
_cached_parameter_indices: Dict[str, int] = None
@@ -170,7 +169,7 @@ def _register_attention_processors_metadata():
def _register_transformer_blocks_metadata():
from ..models.attention import BasicTransformerBlock, JointTransformerBlock
from ..models.attention import BasicTransformerBlock
from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
from ..models.transformers.transformer_bria import BriaTransformerBlock
from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
@@ -185,7 +184,6 @@ def _register_transformer_blocks_metadata():
HunyuanImageSingleTransformerBlock,
HunyuanImageTransformerBlock,
)
from ..models.transformers.transformer_kandinsky import Kandinsky5TransformerDecoderBlock
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
from ..models.transformers.transformer_mochi import MochiTransformerBlock
from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
@@ -333,24 +331,6 @@ def _register_transformer_blocks_metadata():
),
)
TransformerBlockRegistry.register(
model_class=JointTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=1,
return_encoder_hidden_states_index=0,
),
)
# Kandinsky 5.0 (Kandinsky5TransformerDecoderBlock)
TransformerBlockRegistry.register(
model_class=Kandinsky5TransformerDecoderBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
hidden_states_argument_name="visual_embed",
),
)
# fmt: off
def _skip_attention___ret___hidden_states(self, *args, **kwargs):

View File

@@ -1,468 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
from ._helpers import TransformerBlockRegistry
from .hooks import BaseState, HookRegistry, ModelHook, StateManager
logger = get_logger(__name__) # pylint: disable=invalid-name
_MAG_CACHE_LEADER_BLOCK_HOOK = "mag_cache_leader_block_hook"
_MAG_CACHE_BLOCK_HOOK = "mag_cache_block_hook"
# Default Mag Ratios for Flux models (Dev/Schnell) are provided for convenience.
# Users must explicitly pass these to the config if using Flux.
# Reference: https://github.com/Zehong-Ma/MagCache
FLUX_MAG_RATIOS = torch.tensor(
[1.0]
+ [
1.21094,
1.11719,
1.07812,
1.0625,
1.03906,
1.03125,
1.03906,
1.02344,
1.03125,
1.02344,
0.98047,
1.01562,
1.00781,
1.0,
1.00781,
1.0,
1.00781,
1.0,
1.0,
0.99609,
0.99609,
0.98047,
0.98828,
0.96484,
0.95703,
0.93359,
0.89062,
]
)
def nearest_interp(src_array: torch.Tensor, target_length: int) -> torch.Tensor:
"""
Interpolate the source array to the target length using nearest neighbor interpolation.
"""
src_length = len(src_array)
if target_length == 1:
return src_array[-1:]
scale = (src_length - 1) / (target_length - 1)
grid = torch.arange(target_length, device=src_array.device, dtype=torch.float32)
mapped_indices = torch.round(grid * scale).long()
return src_array[mapped_indices]
@dataclass
class MagCacheConfig:
r"""
Configuration for [MagCache](https://github.com/Zehong-Ma/MagCache).
Args:
threshold (`float`, defaults to `0.06`):
The threshold for the accumulated error. If the accumulated error is below this threshold, the block
computation is skipped. A higher threshold allows for more aggressive skipping (faster) but may degrade
quality.
max_skip_steps (`int`, defaults to `3`):
The maximum number of consecutive steps that can be skipped (K in the paper).
retention_ratio (`float`, defaults to `0.2`):
The fraction of initial steps during which skipping is disabled to ensure stability. For example, if
`num_inference_steps` is 28 and `retention_ratio` is 0.2, the first 6 steps will never be skipped.
num_inference_steps (`int`, defaults to `28`):
The number of inference steps used in the pipeline. This is required to interpolate `mag_ratios` correctly.
mag_ratios (`torch.Tensor`, *optional*):
The pre-computed magnitude ratios for the model. These are checkpoint-dependent. If not provided, you must
set `calibrate=True` to calculate them for your specific model. For Flux models, you can use
`diffusers.hooks.mag_cache.FLUX_MAG_RATIOS`.
calibrate (`bool`, defaults to `False`):
If True, enables calibration mode. In this mode, no blocks are skipped. Instead, the hook calculates the
magnitude ratios for the current run and logs them at the end. Use this to obtain `mag_ratios` for new
models or schedulers.
"""
threshold: float = 0.06
max_skip_steps: int = 3
retention_ratio: float = 0.2
num_inference_steps: int = 28
mag_ratios: Optional[Union[torch.Tensor, List[float]]] = None
calibrate: bool = False
def __post_init__(self):
# User MUST provide ratios OR enable calibration.
if self.mag_ratios is None and not self.calibrate:
raise ValueError(
" `mag_ratios` must be provided for MagCache inference because these ratios are model-dependent.\n"
"To get them for your model:\n"
"1. Initialize `MagCacheConfig(calibrate=True, ...)`\n"
"2. Run inference on your model once.\n"
"3. Copy the printed ratios array and pass it to `mag_ratios` in the config.\n"
"For Flux models, you can import `FLUX_MAG_RATIOS` from `diffusers.hooks.mag_cache`."
)
if not self.calibrate and self.mag_ratios is not None:
if not torch.is_tensor(self.mag_ratios):
self.mag_ratios = torch.tensor(self.mag_ratios)
if len(self.mag_ratios) != self.num_inference_steps:
logger.debug(
f"Interpolating mag_ratios from length {len(self.mag_ratios)} to {self.num_inference_steps}"
)
self.mag_ratios = nearest_interp(self.mag_ratios, self.num_inference_steps)
class MagCacheState(BaseState):
def __init__(self) -> None:
super().__init__()
# Cache for the residual (output - input) from the *previous* timestep
self.previous_residual: torch.Tensor = None
# State inputs/outputs for the current forward pass
self.head_block_input: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
self.should_compute: bool = True
# MagCache accumulators
self.accumulated_ratio: float = 1.0
self.accumulated_err: float = 0.0
self.accumulated_steps: int = 0
# Current step counter (timestep index)
self.step_index: int = 0
# Calibration storage
self.calibration_ratios: List[float] = []
def reset(self):
self.previous_residual = None
self.should_compute = True
self.accumulated_ratio = 1.0
self.accumulated_err = 0.0
self.accumulated_steps = 0
self.step_index = 0
self.calibration_ratios = []
class MagCacheHeadHook(ModelHook):
_is_stateful = True
def __init__(self, state_manager: StateManager, config: MagCacheConfig):
self.state_manager = state_manager
self.config = config
self._metadata = None
def initialize_hook(self, module):
unwrapped_module = unwrap_module(module)
self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
return module
@torch.compiler.disable
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
if self.state_manager._current_context is None:
self.state_manager.set_context("inference")
arg_name = self._metadata.hidden_states_argument_name
hidden_states = self._metadata._get_parameter_from_args_kwargs(arg_name, args, kwargs)
state: MagCacheState = self.state_manager.get_state()
state.head_block_input = hidden_states
should_compute = True
if self.config.calibrate:
# Never skip during calibration
should_compute = True
else:
# MagCache Logic
current_step = state.step_index
if current_step >= len(self.config.mag_ratios):
current_scale = 1.0
else:
current_scale = self.config.mag_ratios[current_step]
retention_step = int(self.config.retention_ratio * self.config.num_inference_steps + 0.5)
if current_step >= retention_step:
state.accumulated_ratio *= current_scale
state.accumulated_steps += 1
state.accumulated_err += abs(1.0 - state.accumulated_ratio)
if (
state.previous_residual is not None
and state.accumulated_err <= self.config.threshold
and state.accumulated_steps <= self.config.max_skip_steps
):
should_compute = False
else:
state.accumulated_ratio = 1.0
state.accumulated_steps = 0
state.accumulated_err = 0.0
state.should_compute = should_compute
if not should_compute:
logger.debug(f"MagCache: Skipping step {state.step_index}")
# Apply MagCache: Output = Input + Previous Residual
output = hidden_states
res = state.previous_residual
if res.device != output.device:
res = res.to(output.device)
# Attempt to apply residual handling shape mismatches (e.g., text+image vs image only)
if res.shape == output.shape:
output = output + res
elif (
output.ndim == 3
and res.ndim == 3
and output.shape[0] == res.shape[0]
and output.shape[2] == res.shape[2]
):
# Assuming concatenation where image part is at the end (standard in Flux/SD3)
diff = output.shape[1] - res.shape[1]
if diff > 0:
output = output.clone()
output[:, diff:, :] = output[:, diff:, :] + res
else:
logger.warning(
f"MagCache: Dimension mismatch. Input {output.shape}, Residual {res.shape}. "
"Cannot apply residual safely. Returning input without residual."
)
else:
logger.warning(
f"MagCache: Dimension mismatch. Input {output.shape}, Residual {res.shape}. "
"Cannot apply residual safely. Returning input without residual."
)
if self._metadata.return_encoder_hidden_states_index is not None:
original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
"encoder_hidden_states", args, kwargs
)
max_idx = max(
self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index
)
ret_list = [None] * (max_idx + 1)
ret_list[self._metadata.return_hidden_states_index] = output
ret_list[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states
return tuple(ret_list)
else:
return output
else:
# Compute original forward
output = self.fn_ref.original_forward(*args, **kwargs)
return output
def reset_state(self, module):
self.state_manager.reset()
return module
class MagCacheBlockHook(ModelHook):
def __init__(self, state_manager: StateManager, is_tail: bool = False, config: MagCacheConfig = None):
super().__init__()
self.state_manager = state_manager
self.is_tail = is_tail
self.config = config
self._metadata = None
def initialize_hook(self, module):
unwrapped_module = unwrap_module(module)
self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
return module
@torch.compiler.disable
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
if self.state_manager._current_context is None:
self.state_manager.set_context("inference")
state: MagCacheState = self.state_manager.get_state()
if not state.should_compute:
arg_name = self._metadata.hidden_states_argument_name
hidden_states = self._metadata._get_parameter_from_args_kwargs(arg_name, args, kwargs)
if self.is_tail:
# Still need to advance step index even if we skip
self._advance_step(state)
if self._metadata.return_encoder_hidden_states_index is not None:
encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
"encoder_hidden_states", args, kwargs
)
max_idx = max(
self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index
)
ret_list = [None] * (max_idx + 1)
ret_list[self._metadata.return_hidden_states_index] = hidden_states
ret_list[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states
return tuple(ret_list)
return hidden_states
output = self.fn_ref.original_forward(*args, **kwargs)
if self.is_tail:
# Calculate residual for next steps
if isinstance(output, tuple):
out_hidden = output[self._metadata.return_hidden_states_index]
else:
out_hidden = output
in_hidden = state.head_block_input
if in_hidden is None:
return output
# Determine residual
if out_hidden.shape == in_hidden.shape:
residual = out_hidden - in_hidden
elif out_hidden.ndim == 3 and in_hidden.ndim == 3 and out_hidden.shape[2] == in_hidden.shape[2]:
diff = in_hidden.shape[1] - out_hidden.shape[1]
if diff == 0:
residual = out_hidden - in_hidden
else:
residual = out_hidden - in_hidden # Fallback to matching tail
else:
# Fallback for completely mismatched shapes
residual = out_hidden
if self.config.calibrate:
self._perform_calibration_step(state, residual)
state.previous_residual = residual
self._advance_step(state)
return output
def _perform_calibration_step(self, state: MagCacheState, current_residual: torch.Tensor):
if state.previous_residual is None:
# First step has no previous residual to compare against.
# log 1.0 as a neutral starting point.
ratio = 1.0
else:
# MagCache Calibration Formula: mean(norm(curr) / norm(prev))
# norm(dim=-1) gives magnitude of each token vector
curr_norm = torch.linalg.norm(current_residual.float(), dim=-1)
prev_norm = torch.linalg.norm(state.previous_residual.float(), dim=-1)
# Avoid division by zero
ratio = (curr_norm / (prev_norm + 1e-8)).mean().item()
state.calibration_ratios.append(ratio)
def _advance_step(self, state: MagCacheState):
state.step_index += 1
if state.step_index >= self.config.num_inference_steps:
# End of inference loop
if self.config.calibrate:
print("\n[MagCache] Calibration Complete. Copy these values to MagCacheConfig(mag_ratios=...):")
print(f"{state.calibration_ratios}\n")
logger.info(f"MagCache Calibration Results: {state.calibration_ratios}")
# Reset state
state.step_index = 0
state.accumulated_ratio = 1.0
state.accumulated_steps = 0
state.accumulated_err = 0.0
state.previous_residual = None
state.calibration_ratios = []
def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None:
"""
Applies MagCache to a given module (typically a Transformer).
Args:
module (`torch.nn.Module`):
The module to apply MagCache to.
config (`MagCacheConfig`):
The configuration for MagCache.
"""
# Initialize registry on the root module so the Pipeline can set context.
HookRegistry.check_if_exists_or_initialize(module)
state_manager = StateManager(MagCacheState, (), {})
remaining_blocks = []
for name, submodule in module.named_children():
if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
continue
for index, block in enumerate(submodule):
remaining_blocks.append((f"{name}.{index}", block))
if not remaining_blocks:
logger.warning("MagCache: No transformer blocks found to apply hooks.")
return
# Handle single-block models
if len(remaining_blocks) == 1:
name, block = remaining_blocks[0]
logger.info(f"MagCache: Applying Head+Tail Hooks to single block '{name}'")
_apply_mag_cache_block_hook(block, state_manager, config, is_tail=True)
_apply_mag_cache_head_hook(block, state_manager, config)
return
head_block_name, head_block = remaining_blocks.pop(0)
tail_block_name, tail_block = remaining_blocks.pop(-1)
logger.info(f"MagCache: Applying Head Hook to {head_block_name}")
_apply_mag_cache_head_hook(head_block, state_manager, config)
for name, block in remaining_blocks:
_apply_mag_cache_block_hook(block, state_manager, config)
logger.info(f"MagCache: Applying Tail Hook to {tail_block_name}")
_apply_mag_cache_block_hook(tail_block, state_manager, config, is_tail=True)
def _apply_mag_cache_head_hook(block: torch.nn.Module, state_manager: StateManager, config: MagCacheConfig) -> None:
registry = HookRegistry.check_if_exists_or_initialize(block)
# Automatically remove existing hook to allow re-application (e.g. switching modes)
if registry.get_hook(_MAG_CACHE_LEADER_BLOCK_HOOK) is not None:
registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK)
hook = MagCacheHeadHook(state_manager, config)
registry.register_hook(hook, _MAG_CACHE_LEADER_BLOCK_HOOK)
def _apply_mag_cache_block_hook(
block: torch.nn.Module,
state_manager: StateManager,
config: MagCacheConfig,
is_tail: bool = False,
) -> None:
registry = HookRegistry.check_if_exists_or_initialize(block)
# Automatically remove existing hook to allow re-application
if registry.get_hook(_MAG_CACHE_BLOCK_HOOK) is not None:
registry.remove_hook(_MAG_CACHE_BLOCK_HOOK)
hook = MagCacheBlockHook(state_manager, is_tail, config)
registry.register_hook(hook, _MAG_CACHE_BLOCK_HOOK)

View File

@@ -68,12 +68,10 @@ class CacheMixin:
from ..hooks import (
FasterCacheConfig,
FirstBlockCacheConfig,
MagCacheConfig,
PyramidAttentionBroadcastConfig,
TaylorSeerCacheConfig,
apply_faster_cache,
apply_first_block_cache,
apply_mag_cache,
apply_pyramid_attention_broadcast,
apply_taylorseer_cache,
)
@@ -87,8 +85,6 @@ class CacheMixin:
apply_faster_cache(self, config)
elif isinstance(config, FirstBlockCacheConfig):
apply_first_block_cache(self, config)
elif isinstance(config, MagCacheConfig):
apply_mag_cache(self, config)
elif isinstance(config, PyramidAttentionBroadcastConfig):
apply_pyramid_attention_broadcast(self, config)
elif isinstance(config, TaylorSeerCacheConfig):
@@ -103,13 +99,11 @@ class CacheMixin:
FasterCacheConfig,
FirstBlockCacheConfig,
HookRegistry,
MagCacheConfig,
PyramidAttentionBroadcastConfig,
TaylorSeerCacheConfig,
)
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
from ..hooks.mag_cache import _MAG_CACHE_BLOCK_HOOK, _MAG_CACHE_LEADER_BLOCK_HOOK
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
from ..hooks.taylorseer_cache import _TAYLORSEER_CACHE_HOOK
@@ -124,9 +118,6 @@ class CacheMixin:
elif isinstance(self._cache_config, FirstBlockCacheConfig):
registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True)
registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True)
elif isinstance(self._cache_config, MagCacheConfig):
registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK, recurse=True)
registry.remove_hook(_MAG_CACHE_BLOCK_HOOK, recurse=True)
elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
elif isinstance(self._cache_config, TaylorSeerCacheConfig):

View File

@@ -302,7 +302,7 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt", required=True),
InputParam("prompt"),
InputParam("prompt_2"),
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
InputParam("joint_attention_kwargs"),

View File

@@ -80,7 +80,7 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt", required=True),
InputParam("prompt"),
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(10, 20, 30), required=False),
]
@@ -99,7 +99,7 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
@staticmethod
def check_inputs(block_state):
prompt = block_state.prompt
if not isinstance(prompt, str) and not isinstance(prompt, list):
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@staticmethod
@@ -193,7 +193,7 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt", required=True),
InputParam("prompt"),
]
@property
@@ -210,7 +210,7 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
@staticmethod
def check_inputs(block_state):
prompt = block_state.prompt
if not isinstance(prompt, str) and not isinstance(prompt, list):
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
@torch.no_grad()
@@ -270,7 +270,7 @@ class Flux2KleinTextEncoderStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt", required=True),
InputParam("prompt"),
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(9, 18, 27), required=False),
]
@@ -290,7 +290,7 @@ class Flux2KleinTextEncoderStep(ModularPipelineBlocks):
def check_inputs(block_state):
prompt = block_state.prompt
if not isinstance(prompt, str) and not isinstance(prompt, list):
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@staticmethod
@@ -405,7 +405,7 @@ class Flux2KleinBaseTextEncoderStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt", required=True),
InputParam("prompt"),
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(9, 18, 27), required=False),
]
@@ -431,7 +431,7 @@ class Flux2KleinBaseTextEncoderStep(ModularPipelineBlocks):
def check_inputs(block_state):
prompt = block_state.prompt
if not isinstance(prompt, str) and not isinstance(prompt, list):
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@staticmethod

View File

@@ -715,7 +715,7 @@ class QwenImageTextEncoderStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam.template("prompt", required=True),
InputParam.template("prompt"),
InputParam.template("negative_prompt"),
InputParam.template("max_sequence_length", default=1024),
]
@@ -844,7 +844,7 @@ class QwenImageEditTextEncoderStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam.template("prompt", required=True),
InputParam.template("prompt"),
InputParam.template("negative_prompt"),
InputParam(
name="resized_image",

View File

@@ -244,7 +244,7 @@ class StableDiffusionXLTextEncoderStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt", required=True),
InputParam("prompt"),
InputParam("prompt_2"),
InputParam("negative_prompt"),
InputParam("negative_prompt_2"),

View File

@@ -179,7 +179,7 @@ class WanTextEncoderStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt", required=True),
InputParam("prompt"),
InputParam("negative_prompt"),
InputParam("max_sequence_length", default=512),
]

View File

@@ -149,7 +149,7 @@ class ZImageTextEncoderStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt", required=True),
InputParam("prompt"),
InputParam("negative_prompt"),
InputParam("max_sequence_length", default=512),
]

View File

@@ -227,21 +227,6 @@ class LayerSkipConfig(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class MagCacheConfig(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class PyramidAttentionBroadcastConfig(metaclass=DummyObject):
_backends = ["torch"]
@@ -299,10 +284,6 @@ def apply_layer_skip(*args, **kwargs):
requires_backends(apply_layer_skip, ["torch"])
def apply_mag_cache(*args, **kwargs):
requires_backends(apply_mag_cache, ["torch"])
def apply_pyramid_attention_broadcast(*args, **kwargs):
requires_backends(apply_pyramid_attention_broadcast, ["torch"])

View File

@@ -1,244 +0,0 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from diffusers import MagCacheConfig, apply_mag_cache
from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry
from diffusers.models import ModelMixin
from diffusers.utils import logging
logger = logging.get_logger(__name__)
class DummyBlock(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
# Output is double input
# This ensures Residual = 2*Input - Input = Input
return hidden_states * 2.0
class DummyTransformer(ModelMixin):
def __init__(self):
super().__init__()
self.transformer_blocks = torch.nn.ModuleList([DummyBlock(), DummyBlock()])
def forward(self, hidden_states, encoder_hidden_states=None):
for block in self.transformer_blocks:
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
return hidden_states
class TupleOutputBlock(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
# Returns a tuple
return hidden_states * 2.0, encoder_hidden_states
class TupleTransformer(ModelMixin):
def __init__(self):
super().__init__()
self.transformer_blocks = torch.nn.ModuleList([TupleOutputBlock()])
def forward(self, hidden_states, encoder_hidden_states=None):
for block in self.transformer_blocks:
# Emulate Flux-like behavior
output = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states = output[0]
encoder_hidden_states = output[1]
return hidden_states, encoder_hidden_states
class MagCacheTests(unittest.TestCase):
def setUp(self):
# Register standard dummy block
TransformerBlockRegistry.register(
DummyBlock,
TransformerBlockMetadata(return_hidden_states_index=None, return_encoder_hidden_states_index=None),
)
# Register tuple block (Flux style)
TransformerBlockRegistry.register(
TupleOutputBlock,
TransformerBlockMetadata(return_hidden_states_index=0, return_encoder_hidden_states_index=1),
)
def _set_context(self, model, context_name):
"""Helper to set context on all hooks in the model."""
for module in model.modules():
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook._set_context(context_name)
def _get_calibration_data(self, model):
for module in model.modules():
if hasattr(module, "_diffusers_hook"):
hook = module._diffusers_hook.get_hook("mag_cache_block_hook")
if hook:
return hook.state_manager.get_state().calibration_ratios
return []
def test_mag_cache_validation(self):
"""Test that missing mag_ratios raises ValueError."""
with self.assertRaises(ValueError):
MagCacheConfig(num_inference_steps=10, calibrate=False)
def test_mag_cache_skipping_logic(self):
"""
Tests that MagCache correctly calculates residuals and skips blocks when conditions are met.
"""
model = DummyTransformer()
# Dummy ratios: [1.0, 1.0] implies 0 accumulated error if we skip
ratios = np.array([1.0, 1.0])
config = MagCacheConfig(
threshold=100.0,
num_inference_steps=2,
retention_ratio=0.0, # Enable immediate skipping
max_skip_steps=5,
mag_ratios=ratios,
)
apply_mag_cache(model, config)
self._set_context(model, "test_context")
# Step 0: Input 10.0 -> Output 40.0 (2 blocks * 2x each)
# HeadInput=10. Output=40. Residual=30.
input_t0 = torch.tensor([[[10.0]]])
output_t0 = model(input_t0)
self.assertTrue(torch.allclose(output_t0, torch.tensor([[[40.0]]])), "Step 0 failed")
# Step 1: Input 11.0.
# If Skipped: Output = Input(11) + Residual(30) = 41.0
# If Computed: Output = 11 * 4 = 44.0
input_t1 = torch.tensor([[[11.0]]])
output_t1 = model(input_t1)
self.assertTrue(
torch.allclose(output_t1, torch.tensor([[[41.0]]])), f"Expected Skip (41.0), got {output_t1.item()}"
)
def test_mag_cache_retention(self):
"""Test that retention_ratio prevents skipping even if error is low."""
model = DummyTransformer()
# Ratios that imply 0 error, so it *would* skip if retention allowed it
ratios = np.array([1.0, 1.0])
config = MagCacheConfig(
threshold=100.0,
num_inference_steps=2,
retention_ratio=1.0, # Force retention for ALL steps
mag_ratios=ratios,
)
apply_mag_cache(model, config)
self._set_context(model, "test_context")
# Step 0
model(torch.tensor([[[10.0]]]))
# Step 1: Should COMPUTE (44.0) not SKIP (41.0) because of retention
input_t1 = torch.tensor([[[11.0]]])
output_t1 = model(input_t1)
self.assertTrue(
torch.allclose(output_t1, torch.tensor([[[44.0]]])),
f"Expected Compute (44.0) due to retention, got {output_t1.item()}",
)
def test_mag_cache_tuple_outputs(self):
"""Test compatibility with models returning (hidden, encoder_hidden) like Flux."""
model = TupleTransformer()
ratios = np.array([1.0, 1.0])
config = MagCacheConfig(threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=ratios)
apply_mag_cache(model, config)
self._set_context(model, "test_context")
# Step 0: Compute. Input 10.0 -> Output 20.0 (1 block * 2x)
# Residual = 10.0
input_t0 = torch.tensor([[[10.0]]])
enc_t0 = torch.tensor([[[1.0]]])
out_0, _ = model(input_t0, encoder_hidden_states=enc_t0)
self.assertTrue(torch.allclose(out_0, torch.tensor([[[20.0]]])))
# Step 1: Skip. Input 11.0.
# Skipped Output = 11 + 10 = 21.0
input_t1 = torch.tensor([[[11.0]]])
out_1, _ = model(input_t1, encoder_hidden_states=enc_t0)
self.assertTrue(
torch.allclose(out_1, torch.tensor([[[21.0]]])), f"Tuple skip failed. Expected 21.0, got {out_1.item()}"
)
def test_mag_cache_reset(self):
"""Test that state resets correctly after num_inference_steps."""
model = DummyTransformer()
config = MagCacheConfig(
threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=np.array([1.0, 1.0])
)
apply_mag_cache(model, config)
self._set_context(model, "test_context")
input_t = torch.ones(1, 1, 1)
model(input_t) # Step 0
model(input_t) # Step 1 (Skipped)
# Step 2 (Reset -> Step 0) -> Should Compute
# Input 2.0 -> Output 8.0
input_t2 = torch.tensor([[[2.0]]])
output_t2 = model(input_t2)
self.assertTrue(torch.allclose(output_t2, torch.tensor([[[8.0]]])), "State did not reset correctly")
def test_mag_cache_calibration(self):
"""Test that calibration mode records ratios."""
model = DummyTransformer()
config = MagCacheConfig(num_inference_steps=2, calibrate=True)
apply_mag_cache(model, config)
self._set_context(model, "test_context")
# Step 0
# HeadInput = 10. Output = 40. Residual = 30.
# Ratio 0 is placeholder 1.0
model(torch.tensor([[[10.0]]]))
# Check intermediate state
ratios = self._get_calibration_data(model)
self.assertEqual(len(ratios), 1)
self.assertEqual(ratios[0], 1.0)
# Step 1
# HeadInput = 10. Output = 40. Residual = 30.
# PrevResidual = 30. CurrResidual = 30.
# Ratio = 30/30 = 1.0
model(torch.tensor([[[10.0]]]))
# Verify it computes fully (no skip)
# If it skipped, output would be 41.0. It should be 40.0
# Actually in test setup, input is same (10.0) so output 40.0.
# Let's ensure list is empty after reset (end of step 1)
ratios_after = self._get_calibration_data(model)
self.assertEqual(ratios_after, [])

View File

@@ -37,14 +37,9 @@ class TestFluxModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = FluxModularPipeline
pipeline_blocks_class = FluxAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-modular"
default_repo_id = "hf-internal-testing/tiny-flux-pipe"
params = frozenset(["prompt", "height", "width", "guidance_scale"])
batch_params = frozenset(["prompt"])
# should choose from the dict returned by `get_dummy_inputs`
text_encoder_block_params = frozenset(["prompt", "max_sequence_length"])
decode_block_params = frozenset(["output_type", "height", "width"])
vae_encoder_block_params = None # None if vae_encoder is not supported
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
@@ -68,21 +63,10 @@ class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = FluxModularPipeline
pipeline_blocks_class = FluxAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-modular"
default_repo_id = "hf-internal-testing/tiny-flux-pipe"
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
batch_params = frozenset(["prompt", "image"])
# should choose from the dict returned by `get_dummy_inputs`
text_encoder_block_params = frozenset(
[
"prompt",
"max_sequence_length",
]
)
decode_block_params = frozenset(["output_type", "height", "width"])
vae_encoder_block_params = frozenset(["image", "height", "width"])
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
pipeline = super().get_pipeline(components_manager, torch_dtype)
@@ -145,13 +129,9 @@ class TestFluxKontextModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = FluxKontextModularPipeline
pipeline_blocks_class = FluxKontextAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-kontext-pipe"
default_repo_id = "hf-internal-testing/tiny-flux-kontext-pipe"
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
batch_params = frozenset(["prompt", "image"])
text_encoder_block_params = frozenset(["prompt", "max_sequence_length"])
decode_block_params = frozenset(["latents"])
vae_encoder_block_params = frozenset(["image", "height", "width"])
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)

View File

@@ -32,15 +32,9 @@ class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = Flux2ModularPipeline
pipeline_blocks_class = Flux2AutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular"
default_repo_id = "black-forest-labs/FLUX.2-dev"
default_repo_id = "hf-internal-testing/tiny-flux2"
params = frozenset(["prompt", "height", "width", "guidance_scale"])
batch_params = frozenset(["prompt"])
# should choose from the dict returned by `get_dummy_inputs`
text_encoder_block_params = frozenset(["prompt", "max_sequence_length", "text_encoder_out_layers"])
decode_block_params = frozenset(["output_type"])
vae_encoder_block_params = None
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
@@ -66,14 +60,9 @@ class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = Flux2ModularPipeline
pipeline_blocks_class = Flux2AutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular"
default_repo_id = "hf-internal-testing/tiny-flux2"
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
batch_params = frozenset(["prompt", "image"])
# should choose from the dict returned by `get_dummy_inputs`
text_encoder_block_params = frozenset(["prompt", "max_sequence_length", "text_encoder_out_layers"])
decode_block_params = frozenset(["output_type"])
vae_encoder_block_params = frozenset(["image", "height", "width"])
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)

View File

@@ -32,15 +32,10 @@ class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = Flux2KleinModularPipeline
pipeline_blocks_class = Flux2KleinAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-modular"
default_repo_id = None # TODO
params = frozenset(["prompt", "height", "width"])
batch_params = frozenset(["prompt"])
text_encoder_block_params = frozenset(["prompt", "max_sequence_length", "text_encoder_out_layers"])
decode_block_params = frozenset(["output_type"])
vae_encoder_block_params = None
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
@@ -64,15 +59,10 @@ class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = Flux2KleinModularPipeline
pipeline_blocks_class = Flux2KleinAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-modular"
default_repo_id = None # TODO
params = frozenset(["prompt", "height", "width", "image"])
batch_params = frozenset(["prompt", "image"])
text_encoder_block_params = frozenset(["prompt", "max_sequence_length", "text_encoder_out_layers"])
decode_block_params = frozenset(["output_type"])
vae_encoder_block_params = frozenset(["image", "height", "width"])
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {

View File

@@ -32,14 +32,10 @@ class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = Flux2KleinModularPipeline
pipeline_blocks_class = Flux2KleinBaseAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-base-modular"
default_repo_id = "hf-internal-testing/tiny-flux2-klein"
params = frozenset(["prompt", "height", "width"])
batch_params = frozenset(["prompt"])
text_encoder_block_params = frozenset(["prompt", "max_sequence_length", "text_encoder_out_layers"])
decode_block_params = frozenset(["output_type"])
vae_encoder_block_params = None
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
@@ -63,15 +59,10 @@ class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = Flux2KleinModularPipeline
pipeline_blocks_class = Flux2KleinBaseAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-base-modular"
default_repo_id = "hf-internal-testing/tiny-flux2-klein"
params = frozenset(["prompt", "height", "width", "image"])
batch_params = frozenset(["prompt", "image"])
text_encoder_block_params = frozenset(["prompt", "max_sequence_length", "text_encoder_out_layers"])
decode_block_params = frozenset(["output_type"])
vae_encoder_block_params = frozenset(["image", "height", "width"])
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {

View File

@@ -34,16 +34,10 @@ class TestQwenImageModularPipelineFast(ModularPipelineTesterMixin, ModularGuider
pipeline_class = QwenImageModularPipeline
pipeline_blocks_class = QwenImageAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-modular"
default_repo_id = "Qwen/Qwen-Image"
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
# should choose from the dict returned by `get_dummy_inputs`
text_encoder_block_params = frozenset(["prompt", "negative_prompt", "max_sequence_length"])
decode_block_params = frozenset(["output_type"])
vae_encoder_block_params = None # None if vae_encoder is not supported
def get_dummy_inputs(self):
generator = self.get_generator()
inputs = {
@@ -66,16 +60,10 @@ class TestQwenImageEditModularPipelineFast(ModularPipelineTesterMixin, ModularGu
pipeline_class = QwenImageEditModularPipeline
pipeline_blocks_class = QwenImageEditAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-edit-modular"
default_repo_id = "Qwen/Qwen-Image-Edit"
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
# should choose from the dict returned by `get_dummy_inputs`
text_encoder_block_params = frozenset(["image", "prompt", "negative_prompt"])
decode_block_params = frozenset(["output_type"])
vae_encoder_block_params = frozenset(["image", "generator"])
def get_dummy_inputs(self):
generator = self.get_generator()
inputs = {
@@ -98,17 +86,11 @@ class TestQwenImageEditPlusModularPipelineFast(ModularPipelineTesterMixin, Modul
pipeline_class = QwenImageEditPlusModularPipeline
pipeline_blocks_class = QwenImageEditPlusAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-edit-plus-modular"
default_repo_id = "Qwen/Qwen-Image-Edit-2509"
# No `mask_image` yet.
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image"])
batch_params = frozenset(["prompt", "negative_prompt", "image"])
# should choose from the dict returned by `get_dummy_inputs`
text_encoder_block_params = frozenset(["image", "prompt", "negative_prompt"])
decode_block_params = frozenset(["output_type"])
vae_encoder_block_params = frozenset(["image", "generator"])
def get_dummy_inputs(self):
generator = self.get_generator()
inputs = {

View File

@@ -279,8 +279,6 @@ class TestSDXLModularPipelineFast(
pipeline_class = StableDiffusionXLModularPipeline
pipeline_blocks_class = StableDiffusionXLAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-sdxl-modular"
default_repo_id = "hf-internal-testing/tiny-sdxl-pipe"
params = frozenset(
[
"prompt",
@@ -293,11 +291,6 @@ class TestSDXLModularPipelineFast(
batch_params = frozenset(["prompt", "negative_prompt"])
expected_image_output_shape = (1, 3, 64, 64)
# should choose from the dict returned by `get_dummy_inputs`
text_encoder_block_params = frozenset(["prompt"])
decode_block_params = frozenset(["output_type"])
vae_encoder_block_params = None # None if vae_encoder is not supported
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
@@ -333,7 +326,6 @@ class TestSDXLImg2ImgModularPipelineFast(
pipeline_class = StableDiffusionXLModularPipeline
pipeline_blocks_class = StableDiffusionXLAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-sdxl-modular"
default_repo_id = "hf-internal-testing/tiny-sdxl-pipe"
params = frozenset(
[
"prompt",
@@ -347,11 +339,6 @@ class TestSDXLImg2ImgModularPipelineFast(
batch_params = frozenset(["prompt", "negative_prompt", "image"])
expected_image_output_shape = (1, 3, 64, 64)
# should choose from the dict returned by `get_dummy_inputs`
text_encoder_block_params = frozenset(["prompt"])
decode_block_params = frozenset(["output_type"])
vae_encoder_block_params = frozenset(["image"])
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
@@ -392,7 +379,6 @@ class SDXLInpaintingModularPipelineFastTests(
pipeline_class = StableDiffusionXLModularPipeline
pipeline_blocks_class = StableDiffusionXLAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-sdxl-modular"
default_repo_id = "hf-internal-testing/tiny-sdxl-pipe"
params = frozenset(
[
"prompt",

View File

@@ -37,8 +37,6 @@ class ModularPipelineTesterMixin:
optional_params = frozenset(["num_inference_steps", "num_images_per_prompt", "latents", "output_type"])
# this is modular specific: generator needs to be a intermediate input because it's mutable
intermediate_params = frozenset(["generator"])
# prompt is required for most pipeline, with exceptions like qwen-image layer
required_params = frozenset(["prompt"])
def get_generator(self, seed=0):
generator = torch.Generator("cpu").manual_seed(seed)
@@ -57,12 +55,6 @@ class ModularPipelineTesterMixin:
"You need to set the attribute `pretrained_model_name_or_path` in the child test class. See existing pipeline tests for reference."
)
@property
def default_repo_id(self) -> str:
raise NotImplementedError(
"You need to set the attribute `default_repo_id` in the child test class. See existing pipeline tests for reference."
)
@property
def pipeline_blocks_class(self) -> Union[Callable, ModularPipelineBlocks]:
raise NotImplementedError(
@@ -105,33 +97,6 @@ class ModularPipelineTesterMixin:
"See existing pipeline tests for reference."
)
@property
def text_encoder_block_params(self) -> frozenset:
raise NotImplementedError(
"You need to set the attribute `text_encoder_block_params` in the child test class. "
"`text_encoder_block_params` are the parameters required to be passed to the text encoder block. "
" if should be a subset of the parameters returned by `get_dummy_inputs`"
"See existing pipeline tests for reference."
)
@property
def decode_block_params(self) -> frozenset:
raise NotImplementedError(
"You need to set the attribute `decode_block_params` in the child test class. "
"`decode_block_params` are the parameters required to be passed to the decode block. "
" if should be a subset of the parameters returned by `get_dummy_inputs`"
"See existing pipeline tests for reference."
)
@property
def vae_encoder_block_params(self) -> frozenset:
raise NotImplementedError(
"You need to set the attribute `vae_encoder_block_params` in the child test class. "
"`vae_encoder_block_params` are the parameters required to be passed to the vae encoder block. "
" if should be a subset of the parameters returned by `get_dummy_inputs`"
"See existing pipeline tests for reference."
)
def setup_method(self):
# clean up the VRAM before each test
torch.compiler.reset()
@@ -156,7 +121,6 @@ class ModularPipelineTesterMixin:
pipe = self.get_pipeline()
input_parameters = pipe.blocks.input_names
optional_parameters = pipe.default_call_parameters
required_parameters = pipe.blocks.required_inputs
def _check_for_parameters(parameters, expected_parameters, param_type):
remaining_parameters = {param for param in parameters if param not in expected_parameters}
@@ -166,101 +130,6 @@ class ModularPipelineTesterMixin:
_check_for_parameters(self.params, input_parameters, "input")
_check_for_parameters(self.optional_params, optional_parameters, "optional")
_check_for_parameters(self.required_params, required_parameters, "required")
def test_loading_from_default_repo(self):
if self.default_repo_id is None:
return
try:
pipe = ModularPipeline.from_pretrained(self.default_repo_id)
assert pipe.blocks.__class__ == self.pipeline_blocks_class
except Exception as e:
assert False, f"Failed to load pipeline from default repo: {e}"
def test_modular_inference(self):
# run the pipeline to get the base output for comparison
pipe = self.get_pipeline()
pipe.to(torch_device, torch.float32)
inputs = self.get_dummy_inputs()
standard_output = pipe(**inputs, output="images")
# create text, denoise, decoder (and optional vae encoder) nodes
blocks = self.pipeline_blocks_class()
assert "text_encoder" in blocks.sub_blocks, "`text_encoder` block is not present in the pipeline"
assert "denoise" in blocks.sub_blocks, "`denoise` block is not present in the pipeline"
assert "decode" in blocks.sub_blocks, "`decode` block is not present in the pipeline"
if self.vae_encoder_block_params is not None:
assert "vae_encoder" in blocks.sub_blocks, "`vae_encoder` block is not present in the pipeline"
# manually set the components in the sub_pipe
# a hack to workaround the fact the default pipeline properties are often incorrect for testing cases,
# #e.g. vae_scale_factor is ususally not 8 because vae is configured to be smaller for testing
def manually_set_all_components(pipe: ModularPipeline, sub_pipe: ModularPipeline):
for n, comp in pipe.components.items():
if not hasattr(sub_pipe, n):
setattr(sub_pipe, n, comp)
# Initialize all nodes
text_node = blocks.sub_blocks["text_encoder"].init_pipeline(self.pretrained_model_name_or_path)
text_node.load_components(torch_dtype=torch.float32)
text_node.to(torch_device)
manually_set_all_components(pipe, text_node)
denoise_node = blocks.sub_blocks["denoise"].init_pipeline(self.pretrained_model_name_or_path)
denoise_node.load_components(torch_dtype=torch.float32)
denoise_node.to(torch_device)
manually_set_all_components(pipe, denoise_node)
decoder_node = blocks.sub_blocks["decode"].init_pipeline(self.pretrained_model_name_or_path)
decoder_node.load_components(torch_dtype=torch.float32)
decoder_node.to(torch_device)
manually_set_all_components(pipe, decoder_node)
if self.vae_encoder_block_params is not None:
vae_encoder_node = blocks.sub_blocks["vae_encoder"].init_pipeline(self.pretrained_model_name_or_path)
vae_encoder_node.load_components(torch_dtype=torch.float32)
vae_encoder_node.to(torch_device)
manually_set_all_components(pipe, vae_encoder_node)
else:
vae_encoder_node = None
def filter_inputs(available: dict, expected_keys) -> dict:
return {k: v for k, v in available.items() if k in expected_keys}
# prepare inputs for each node
inputs = self.get_dummy_inputs()
# 1. Text encoder: takes from inputs
text_inputs = filter_inputs(inputs, self.text_encoder_block_params)
text_output = text_node(**text_inputs)
text_output_dict = text_output.get_by_kwargs("denoiser_input_fields")
# 2. VAE encoder (optional): takes from inputs + text_output
if vae_encoder_node is not None:
vae_available = {**inputs, **text_output_dict}
vae_encoder_inputs = filter_inputs(vae_available, vae_encoder_node.blocks.input_names)
vae_encoder_output = vae_encoder_node(**vae_encoder_inputs)
vae_output_dict = vae_encoder_output.values
else:
vae_output_dict = {}
# 3. Denoise: takes from inputs + text_output + vae_output
denoise_available = {**inputs, **text_output_dict, **vae_output_dict}
denoise_inputs = filter_inputs(denoise_available, denoise_node.blocks.input_names)
denoise_output = denoise_node(**denoise_inputs)
latents = denoise_output.latents
# 4. Decoder: takes from inputs + denoise_output
decode_available = {**inputs, "latents": latents}
decode_inputs = filter_inputs(decode_available, decoder_node.blocks.input_names)
modular_output = decoder_node(**decode_inputs).images
assert modular_output.shape == standard_output.shape, (
f"Modular output should have same shape as standard output {standard_output.shape}, but got {modular_output.shape}"
)
def test_inference_batch_consistent(self, batch_sizes=[2], batch_generator=True):
pipe = self.get_pipeline().to(torch_device)

View File

@@ -27,7 +27,6 @@ from ..test_pipelines_common import (
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
FluxIPAdapterTesterMixin,
MagCacheTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
TaylorSeerCacheTesterMixin,
@@ -42,7 +41,6 @@ class FluxPipelineFastTests(
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
TaylorSeerCacheTesterMixin,
MagCacheTesterMixin,
unittest.TestCase,
):
pipeline_class = FluxPipeline

View File

@@ -35,7 +35,6 @@ from diffusers import (
from diffusers.hooks import apply_group_offloading
from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook
from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
from diffusers.hooks.mag_cache import MagCacheConfig
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
from diffusers.hooks.taylorseer_cache import TaylorSeerCacheConfig
from diffusers.image_processor import VaeImageProcessor
@@ -2977,59 +2976,6 @@ class TaylorSeerCacheTesterMixin:
)
class MagCacheTesterMixin:
mag_cache_config = MagCacheConfig(
threshold=0.06,
max_skip_steps=3,
retention_ratio=0.2,
num_inference_steps=50,
mag_ratios=torch.ones(50),
)
def test_mag_cache_inference(self, expected_atol: float = 0.1):
device = "cpu"
def create_pipe():
torch.manual_seed(0)
num_layers = 2
components = self.get_dummy_components(num_layers=num_layers)
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
return pipe
def run_forward(pipe):
torch.manual_seed(0)
inputs = self.get_dummy_inputs(device)
# Match the config steps
inputs["num_inference_steps"] = 50
return pipe(**inputs)[0]
# 1. Run inference without MagCache (Baseline)
pipe = create_pipe()
output = run_forward(pipe).flatten()
original_image_slice = np.concatenate((output[:8], output[-8:]))
# 2. Run inference with MagCache ENABLED
pipe = create_pipe()
pipe.transformer.enable_cache(self.mag_cache_config)
output = run_forward(pipe).flatten()
image_slice_enabled = np.concatenate((output[:8], output[-8:]))
# 3. Run inference with MagCache DISABLED
pipe.transformer.disable_cache()
output = run_forward(pipe).flatten()
image_slice_disabled = np.concatenate((output[:8], output[-8:]))
assert np.allclose(original_image_slice, image_slice_enabled, atol=expected_atol), (
"MagCache outputs should not differ too much from baseline."
)
assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-4), (
"Outputs after disabling cache should match original inference exactly."
)
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
# reference image.