Compare commits

...

3 Commits

Author SHA1 Message Date
Daniel Gu
e27c4d3acc Modify AttentionProcessorSkipHook to support _skip_attn_scores flag on attn processors to allow custom STG-style logic 2026-03-07 01:09:59 +01:00
tcaimm
e747fe4a94 Fix wrapped transformer config access in Flux2 Klein training (#13219) 2026-03-06 19:47:51 +05:30
Shenghai Yuan
46bd005730 Convert tensors to float in Helios's optimized_scale function (#13214)
Convert tensors to float in optimized_scale function
2026-03-05 17:08:44 -10:00
4 changed files with 30 additions and 5 deletions

View File

@@ -1715,7 +1715,7 @@ def main(args):
packed_noisy_model_input = Flux2KleinPipeline._pack_latents(noisy_model_input)
# handle guidance
if transformer.config.guidance_embeds:
if unwrap_model(transformer).config.guidance_embeds:
guidance = torch.full([1], args.guidance_scale, device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
else:

View File

@@ -1682,7 +1682,7 @@ def main(args):
model_input_ids = torch.cat([model_input_ids, cond_model_input_ids], dim=1)
# handle guidance
if transformer.config.guidance_embeds:
if unwrap_model(transformer).config.guidance_embeds:
guidance = torch.full([1], args.guidance_scale, device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
else:

View File

@@ -108,8 +108,17 @@ class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
class AttentionProcessorSkipHook(ModelHook):
def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False, dropout: float = 1.0):
def __init__(
self,
skip_processor_output_fn: Callable,
skip_attn_scores_fn: Callable | None = None,
skip_attention_scores: bool = False,
dropout: float = 1.0,
):
super().__init__()
self.skip_processor_output_fn = skip_processor_output_fn
# STG default: return the values as attention output
self.skip_attn_scores_fn = skip_attn_scores_fn or (lambda attn, q, k, v: v)
self.skip_attention_scores = skip_attention_scores
self.dropout = dropout
@@ -119,8 +128,22 @@ class AttentionProcessorSkipHook(ModelHook):
raise ValueError(
"Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
)
with AttentionScoreSkipFunctionMode():
output = self.fn_ref.original_forward(*args, **kwargs)
processor_supports_skip_fn = hasattr(module.processor, "_skip_attn_scores")
if processor_supports_skip_fn:
module.processor._skip_attn_scores = True
module.processor._skip_attn_scores_fn = self.skip_attn_scores_fn
# Use try block in case attn processor raises an exception
try:
if processor_supports_skip_fn:
output = self.fn_ref.original_forward(*args, **kwargs)
else:
# Fallback to torch native SDPA intercept approach
with AttentionScoreSkipFunctionMode():
output = self.fn_ref.original_forward(*args, **kwargs)
finally:
if processor_supports_skip_fn:
module.processor._skip_attn_scores = False
module.processor._skip_attn_scores_fn = None
else:
if math.isclose(self.dropout, 1.0):
output = self.skip_processor_output_fn(module, *args, **kwargs)

View File

@@ -76,6 +76,8 @@ EXAMPLE_DOC_STRING = """
def optimized_scale(positive_flat, negative_flat):
positive_flat = positive_flat.float()
negative_flat = negative_flat.float()
# Calculate dot product
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
# Squared norm of the unconditional branch