mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-08 01:31:50 +08:00
Compare commits
3 Commits
v0.37.0
...
skip-layer
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e27c4d3acc | ||
|
|
e747fe4a94 | ||
|
|
46bd005730 |
@@ -1715,7 +1715,7 @@ def main(args):
|
||||
packed_noisy_model_input = Flux2KleinPipeline._pack_latents(noisy_model_input)
|
||||
|
||||
# handle guidance
|
||||
if transformer.config.guidance_embeds:
|
||||
if unwrap_model(transformer).config.guidance_embeds:
|
||||
guidance = torch.full([1], args.guidance_scale, device=accelerator.device)
|
||||
guidance = guidance.expand(model_input.shape[0])
|
||||
else:
|
||||
|
||||
@@ -1682,7 +1682,7 @@ def main(args):
|
||||
model_input_ids = torch.cat([model_input_ids, cond_model_input_ids], dim=1)
|
||||
|
||||
# handle guidance
|
||||
if transformer.config.guidance_embeds:
|
||||
if unwrap_model(transformer).config.guidance_embeds:
|
||||
guidance = torch.full([1], args.guidance_scale, device=accelerator.device)
|
||||
guidance = guidance.expand(model_input.shape[0])
|
||||
else:
|
||||
|
||||
@@ -108,8 +108,17 @@ class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
|
||||
|
||||
|
||||
class AttentionProcessorSkipHook(ModelHook):
|
||||
def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False, dropout: float = 1.0):
|
||||
def __init__(
|
||||
self,
|
||||
skip_processor_output_fn: Callable,
|
||||
skip_attn_scores_fn: Callable | None = None,
|
||||
skip_attention_scores: bool = False,
|
||||
dropout: float = 1.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.skip_processor_output_fn = skip_processor_output_fn
|
||||
# STG default: return the values as attention output
|
||||
self.skip_attn_scores_fn = skip_attn_scores_fn or (lambda attn, q, k, v: v)
|
||||
self.skip_attention_scores = skip_attention_scores
|
||||
self.dropout = dropout
|
||||
|
||||
@@ -119,8 +128,22 @@ class AttentionProcessorSkipHook(ModelHook):
|
||||
raise ValueError(
|
||||
"Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
|
||||
)
|
||||
with AttentionScoreSkipFunctionMode():
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
processor_supports_skip_fn = hasattr(module.processor, "_skip_attn_scores")
|
||||
if processor_supports_skip_fn:
|
||||
module.processor._skip_attn_scores = True
|
||||
module.processor._skip_attn_scores_fn = self.skip_attn_scores_fn
|
||||
# Use try block in case attn processor raises an exception
|
||||
try:
|
||||
if processor_supports_skip_fn:
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
else:
|
||||
# Fallback to torch native SDPA intercept approach
|
||||
with AttentionScoreSkipFunctionMode():
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
finally:
|
||||
if processor_supports_skip_fn:
|
||||
module.processor._skip_attn_scores = False
|
||||
module.processor._skip_attn_scores_fn = None
|
||||
else:
|
||||
if math.isclose(self.dropout, 1.0):
|
||||
output = self.skip_processor_output_fn(module, *args, **kwargs)
|
||||
|
||||
@@ -76,6 +76,8 @@ EXAMPLE_DOC_STRING = """
|
||||
|
||||
|
||||
def optimized_scale(positive_flat, negative_flat):
|
||||
positive_flat = positive_flat.float()
|
||||
negative_flat = negative_flat.float()
|
||||
# Calculate dot production
|
||||
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
|
||||
# Squared norm of uncondition
|
||||
|
||||
Reference in New Issue
Block a user