mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-09 10:11:43 +08:00
Compare commits
3 Commits
sage-black
...
skip-layer
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
82da133a38 | ||
|
|
e27c4d3acc | ||
|
|
e747fe4a94 |
@@ -1715,7 +1715,7 @@ def main(args):
|
||||
packed_noisy_model_input = Flux2KleinPipeline._pack_latents(noisy_model_input)
|
||||
|
||||
# handle guidance
|
||||
if transformer.config.guidance_embeds:
|
||||
if unwrap_model(transformer).config.guidance_embeds:
|
||||
guidance = torch.full([1], args.guidance_scale, device=accelerator.device)
|
||||
guidance = guidance.expand(model_input.shape[0])
|
||||
else:
|
||||
|
||||
@@ -1682,7 +1682,7 @@ def main(args):
|
||||
model_input_ids = torch.cat([model_input_ids, cond_model_input_ids], dim=1)
|
||||
|
||||
# handle guidance
|
||||
if transformer.config.guidance_embeds:
|
||||
if unwrap_model(transformer).config.guidance_embeds:
|
||||
guidance = torch.full([1], args.guidance_scale, device=accelerator.device)
|
||||
guidance = guidance.expand(model_input.shape[0])
|
||||
else:
|
||||
|
||||
@@ -108,8 +108,17 @@ class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
|
||||
|
||||
|
||||
class AttentionProcessorSkipHook(ModelHook):
|
||||
def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False, dropout: float = 1.0):
|
||||
def __init__(
|
||||
self,
|
||||
skip_processor_output_fn: Callable,
|
||||
skip_attention_scores: bool = False,
|
||||
dropout: float = 1.0,
|
||||
skip_attn_scores_fn: Callable | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.skip_processor_output_fn = skip_processor_output_fn
|
||||
# STG default: return the values as attention output
|
||||
self.skip_attn_scores_fn = skip_attn_scores_fn or (lambda attn, q, k, v: v)
|
||||
self.skip_attention_scores = skip_attention_scores
|
||||
self.dropout = dropout
|
||||
|
||||
@@ -119,8 +128,22 @@ class AttentionProcessorSkipHook(ModelHook):
|
||||
raise ValueError(
|
||||
"Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
|
||||
)
|
||||
with AttentionScoreSkipFunctionMode():
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
processor_supports_skip_fn = hasattr(module.processor, "_skip_attn_scores")
|
||||
if processor_supports_skip_fn:
|
||||
module.processor._skip_attn_scores = True
|
||||
module.processor._skip_attn_scores_fn = self.skip_attn_scores_fn
|
||||
# Use try block in case attn processor raises an exception
|
||||
try:
|
||||
if processor_supports_skip_fn:
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
else:
|
||||
# Fallback to torch native SDPA intercept approach
|
||||
with AttentionScoreSkipFunctionMode():
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
finally:
|
||||
if processor_supports_skip_fn:
|
||||
module.processor._skip_attn_scores = False
|
||||
module.processor._skip_attn_scores_fn = None
|
||||
else:
|
||||
if math.isclose(self.dropout, 1.0):
|
||||
output = self.skip_processor_output_fn(module, *args, **kwargs)
|
||||
|
||||
@@ -703,16 +703,9 @@ def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
|
||||
try:
|
||||
from kernels import get_kernel
|
||||
|
||||
is_blackwell = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 10 # sm_100+
|
||||
version = 2 if is_blackwell and backend == AttentionBackendName.SAGE_HUB else config.version
|
||||
kernel_module = get_kernel(config.repo_id, revision=config.revision, version=version)
|
||||
kernel_module = get_kernel(config.repo_id, revision=config.revision, version=config.version)
|
||||
if needs_kernel:
|
||||
function_attr = (
|
||||
"sageattn3_blackwell"
|
||||
if is_blackwell and backend == AttentionBackendName.SAGE_HUB
|
||||
else config.function_attr
|
||||
)
|
||||
config.kernel_fn = _resolve_kernel_attr(kernel_module, function_attr)
|
||||
config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr)
|
||||
|
||||
if needs_wrapped_forward:
|
||||
config.wrapped_forward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_forward_attr)
|
||||
|
||||
Reference in New Issue
Block a user