Compare commits

..

1 Commit

Author SHA1 Message Date
teith
4bc1c59a67 fix: correct invalid type annotation for image in Flux2Pipeline.__call__ (#13205)
fix: correct invalid type annotation for image in Flux2Pipeline.__call__
2026-03-13 15:56:38 -03:00
2 changed files with 3 additions and 12 deletions

View File

@@ -2559,9 +2559,7 @@ def _flash_attention_3_hub(
_parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
if attn_mask is not None:
-        raise ValueError(
-            "`attn_mask` is not supported for flash-attn 3. Please use the `_flash_3_varlen_hub` backend instead."
-        )
+        raise ValueError("`attn_mask` is not supported for flash-attn 3.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
if _parallel_config is None:
@@ -2643,8 +2641,6 @@ def _flash_attention_3_varlen_hub(
_, seq_len_kv, _, _ = key.shape
if attn_mask is not None:
-        if attn_mask.dtype != torch.bool:
-            attn_mask = attn_mask > -1
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
@@ -2664,7 +2660,7 @@ def _flash_attention_3_varlen_hub(
value_packed = torch.cat(value_valid, dim=0)
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB].kernel_fn
-    result = func(
+    out, lse, *_ = func(
q=query_packed,
k=key_packed,
v=value_packed,
@@ -2675,11 +2671,6 @@ def _flash_attention_3_varlen_hub(
softmax_scale=scale,
causal=is_causal,
)
-    if isinstance(result, tuple):
-        out, lse, *_ = result
-    else:
-        out = result
-        lse = None
out = out.unflatten(0, (batch_size, -1))
return (out, lse) if return_lse else out

View File

@@ -744,7 +744,7 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
-        image: list[PIL.Image.Image, PIL.Image.Image] | None = None,
+        image: PIL.Image.Image | list[PIL.Image.Image] | None = None,
prompt: str | list[str] = None,
height: int | None = None,
width: int | None = None,