Compare commits

...

1 Commit

View File

@@ -108,8 +108,17 @@ class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
class AttentionProcessorSkipHook(ModelHook):
def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False, dropout: float = 1.0):
def __init__(
    self,
    skip_processor_output_fn: Callable,
    skip_attn_scores_fn: Callable | None = None,
    skip_attention_scores: bool = False,
    dropout: float = 1.0,
):
    """Configure the attention-skip hook.

    Args:
        skip_processor_output_fn: Callable producing the replacement output
            when the whole processor invocation is skipped.
        skip_attn_scores_fn: Optional ``(attn, q, k, v) -> tensor`` override
            used by processors that support score skipping. Defaults to the
            STG behavior of returning the values as the attention output.
        skip_attention_scores: When True, bypass the score computation.
        dropout: Skip-dropout factor; 1.0 means no dropout is applied.
    """
    super().__init__()
    self.skip_processor_output_fn = skip_processor_output_fn
    if not skip_attn_scores_fn:
        # STG default: return the values as attention output
        skip_attn_scores_fn = lambda attn, q, k, v: v
    self.skip_attn_scores_fn = skip_attn_scores_fn
    self.skip_attention_scores = skip_attention_scores
    self.dropout = dropout
@@ -119,8 +128,22 @@ class AttentionProcessorSkipHook(ModelHook):
raise ValueError(
"Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
)
with AttentionScoreSkipFunctionMode():
output = self.fn_ref.original_forward(*args, **kwargs)
processor_supports_skip_fn = hasattr(module.processor, "_skip_attn_scores")
if processor_supports_skip_fn:
module.processor._skip_attn_scores = True
module.processor._skip_attn_scores_fn = self.skip_attn_scores_fn
# Use try block in case attn processor raises an exception
try:
if processor_supports_skip_fn:
output = self.fn_ref.original_forward(*args, **kwargs)
else:
# Fallback to torch native SDPA intercept approach
with AttentionScoreSkipFunctionMode():
output = self.fn_ref.original_forward(*args, **kwargs)
finally:
if processor_supports_skip_fn:
module.processor._skip_attn_scores = False
module.processor._skip_attn_scores_fn = None
else:
if math.isclose(self.dropout, 1.0):
output = self.skip_processor_output_fn(module, *args, **kwargs)