Compare commits

..

3 Commits

4 changed files with 30 additions and 14 deletions

View File

@@ -1715,7 +1715,7 @@ def main(args):
packed_noisy_model_input = Flux2KleinPipeline._pack_latents(noisy_model_input)
# handle guidance
if transformer.config.guidance_embeds:
if unwrap_model(transformer).config.guidance_embeds:
guidance = torch.full([1], args.guidance_scale, device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
else:

View File

@@ -1682,7 +1682,7 @@ def main(args):
model_input_ids = torch.cat([model_input_ids, cond_model_input_ids], dim=1)
# handle guidance
if transformer.config.guidance_embeds:
if unwrap_model(transformer).config.guidance_embeds:
guidance = torch.full([1], args.guidance_scale, device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
else:

View File

@@ -108,8 +108,17 @@ class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
class AttentionProcessorSkipHook(ModelHook):
def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False, dropout: float = 1.0):
def __init__(
self,
skip_processor_output_fn: Callable,
skip_attention_scores: bool = False,
dropout: float = 1.0,
skip_attn_scores_fn: Callable | None = None,
):
super().__init__()
self.skip_processor_output_fn = skip_processor_output_fn
# STG default: return the values as attention output
self.skip_attn_scores_fn = skip_attn_scores_fn or (lambda attn, q, k, v: v)
self.skip_attention_scores = skip_attention_scores
self.dropout = dropout
@@ -119,8 +128,22 @@ class AttentionProcessorSkipHook(ModelHook):
raise ValueError(
"Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
)
with AttentionScoreSkipFunctionMode():
output = self.fn_ref.original_forward(*args, **kwargs)
processor_supports_skip_fn = hasattr(module.processor, "_skip_attn_scores")
if processor_supports_skip_fn:
module.processor._skip_attn_scores = True
module.processor._skip_attn_scores_fn = self.skip_attn_scores_fn
# Use try block in case attn processor raises an exception
try:
if processor_supports_skip_fn:
output = self.fn_ref.original_forward(*args, **kwargs)
else:
# Fallback to torch native SDPA intercept approach
with AttentionScoreSkipFunctionMode():
output = self.fn_ref.original_forward(*args, **kwargs)
finally:
if processor_supports_skip_fn:
module.processor._skip_attn_scores = False
module.processor._skip_attn_scores_fn = None
else:
if math.isclose(self.dropout, 1.0):
output = self.skip_processor_output_fn(module, *args, **kwargs)

View File

@@ -703,16 +703,9 @@ def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
try:
from kernels import get_kernel
is_blackwell = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 10 # sm_100+
version = 2 if is_blackwell and backend == AttentionBackendName.SAGE_HUB else config.version
kernel_module = get_kernel(config.repo_id, revision=config.revision, version=version)
kernel_module = get_kernel(config.repo_id, revision=config.revision, version=config.version)
if needs_kernel:
function_attr = (
"sageattn3_blackwell"
if is_blackwell and backend == AttentionBackendName.SAGE_HUB
else config.function_attr
)
config.kernel_fn = _resolve_kernel_attr(kernel_module, function_attr)
config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr)
if needs_wrapped_forward:
config.wrapped_forward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_forward_attr)