Compare commits

..

1 Commits

Author SHA1 Message Date
cbensimon
bc8817604a [core] Add export-safe LRU cache helper 2026-03-19 17:30:58 +00:00
5 changed files with 29 additions and 49 deletions

View File

@@ -143,7 +143,6 @@ Refer to the table below for a complete list of available attention backends and
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
| `flash_varlen_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention from kernels |
| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
| `flash_4_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-4 |
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import functools
import inspect
from dataclasses import dataclass
from typing import Type
@@ -32,7 +31,7 @@ from ..models._modeling_parallel import (
gather_size_by_comm,
)
from ..utils import get_logger
from ..utils.torch_utils import maybe_allow_in_graph, unwrap_module
from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph, unwrap_module
from .hooks import HookRegistry, ModelHook
@@ -327,7 +326,7 @@ class PartitionAnythingSharder:
return tensor
@functools.lru_cache(maxsize=64)
@lru_cache_unless_export(maxsize=64)
def _fill_gather_shapes(shape: tuple[int], gather_dims: tuple[int], dim: int, world_size: int) -> list[list[int]]:
gather_shapes = []
for i in range(world_size):

View File

@@ -49,7 +49,7 @@ from ..utils import (
is_xformers_version,
)
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
from ..utils.torch_utils import maybe_allow_in_graph
from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
from ._modeling_parallel import gather_size_by_comm
@@ -229,7 +229,6 @@ class AttentionBackendName(str, Enum):
FLASH_HUB = "flash_hub"
FLASH_VARLEN = "flash_varlen"
FLASH_VARLEN_HUB = "flash_varlen_hub"
FLASH_4_HUB = "flash_4_hub"
_FLASH_3 = "_flash_3"
_FLASH_VARLEN_3 = "_flash_varlen_3"
_FLASH_3_HUB = "_flash_3_hub"
@@ -359,11 +358,6 @@ _HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
function_attr="sageattn",
version=1,
),
AttentionBackendName.FLASH_4_HUB: _HubKernelConfig(
repo_id="kernels-staging/flash-attn4",
function_attr="flash_attn_func",
version=0,
),
}
@@ -527,7 +521,6 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
AttentionBackendName._FLASH_3_HUB,
AttentionBackendName._FLASH_3_VARLEN_HUB,
AttentionBackendName.SAGE_HUB,
AttentionBackendName.FLASH_4_HUB,
]:
if not is_kernels_available():
raise RuntimeError(
@@ -582,7 +575,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
)
@functools.lru_cache(maxsize=128)
@lru_cache_unless_export(maxsize=128)
def _prepare_for_flash_attn_or_sage_varlen_without_mask(
batch_size: int,
seq_len_q: int,
@@ -2683,37 +2676,6 @@ def _flash_attention_3_varlen_hub(
return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
AttentionBackendName.FLASH_4_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
)
def _flash_attention_4_hub(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor | None = None,
scale: float | None = None,
is_causal: bool = False,
return_lse: bool = False,
_parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for flash-attn 4.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_4_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
)
if isinstance(out, tuple):
return (out[0], out[1]) if return_lse else out[0]
return out
@_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_VARLEN_3,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import math
from math import prod
from typing import Any
@@ -25,7 +24,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, deprecate, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ...utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
@@ -307,7 +306,7 @@ class QwenEmbedRope(nn.Module):
return vid_freqs, txt_freqs
@functools.lru_cache(maxsize=128)
@lru_cache_unless_export(maxsize=128)
def _compute_video_freqs(
self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
) -> torch.Tensor:
@@ -428,7 +427,7 @@ class QwenEmbedLayer3DRope(nn.Module):
return vid_freqs, txt_freqs
@functools.lru_cache(maxsize=None)
@lru_cache_unless_export(maxsize=None)
def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
seq_lens = frame * height * width
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
@@ -450,7 +449,7 @@ class QwenEmbedLayer3DRope(nn.Module):
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
return freqs.clone().contiguous()
@functools.lru_cache(maxsize=None)
@lru_cache_unless_export(maxsize=None)
def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
seq_lens = frame * height * width
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs

View File

@@ -19,11 +19,16 @@ from __future__ import annotations
import functools
import os
from typing import Callable, ParamSpec, TypeVar
from . import logging
from .import_utils import is_torch_available, is_torch_mlu_available, is_torch_npu_available, is_torch_version
T = TypeVar("T")
P = ParamSpec("P")
if is_torch_available():
import torch
from torch.fft import fftn, fftshift, ifftn, ifftshift
@@ -333,5 +338,21 @@ def disable_full_determinism():
torch.use_deterministic_algorithms(False)
@functools.wraps(functools.lru_cache)
def lru_cache_unless_export(maxsize=128, typed=False):
def outer_wrapper(fn: Callable[P, T]):
cached = functools.lru_cache(maxsize=maxsize, typed=typed)(fn)
@functools.wraps(fn)
def inner_wrapper(*args: P.args, **kwargs: P.kwargs):
if torch.compiler.is_exporting():
return fn(*args, **kwargs)
return cached(*args, **kwargs)
return inner_wrapper
return outer_wrapper
if is_torch_available():
torch_device = get_device()