Compare commits

..

2 Commits

Author SHA1 Message Date
Sayak Paul
611034eb74 Update docs/source/en/optimization/attention_backends.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2026-03-18 23:31:40 +05:30
Sayak Paul
052d5e6d5f Update attention_backends.md 2026-03-18 15:43:53 +05:30
8 changed files with 13 additions and 33 deletions

View File

@@ -35,7 +35,7 @@ The [`~ModelMixin.set_attention_backend`] method iterates through all the module
The example below demonstrates how to enable the `_flash_3_hub` implementation for FlashAttention-3 from the [`kernels`](https://github.com/huggingface/kernels) library, which allows you to instantly use optimized compute kernels from the Hub without requiring any setup.
> [!NOTE]
> FlashAttention-3 is not supported for non-Hopper architectures, in which case, use FlashAttention with `set_attention_backend("flash")`.
> FlashAttention-3 requires Ampere GPUs at a minimum.
```py
import torch

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import functools
import inspect
from dataclasses import dataclass
from typing import Type
@@ -31,7 +32,7 @@ from ..models._modeling_parallel import (
gather_size_by_comm,
)
from ..utils import get_logger
from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph, unwrap_module
from ..utils.torch_utils import maybe_allow_in_graph, unwrap_module
from .hooks import HookRegistry, ModelHook
@@ -326,7 +327,7 @@ class PartitionAnythingSharder:
return tensor
@lru_cache_unless_export(maxsize=64)
@functools.lru_cache(maxsize=64)
def _fill_gather_shapes(shape: tuple[int], gather_dims: tuple[int], dim: int, world_size: int) -> list[list[int]]:
gather_shapes = []
for i in range(world_size):

View File

@@ -49,7 +49,7 @@ from ..utils import (
is_xformers_version,
)
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
from ..utils.torch_utils import maybe_allow_in_graph
from ._modeling_parallel import gather_size_by_comm
@@ -575,7 +575,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
)
@lru_cache_unless_export(maxsize=128)
@functools.lru_cache(maxsize=128)
def _prepare_for_flash_attn_or_sage_varlen_without_mask(
batch_size: int,
seq_len_q: int,

View File

@@ -13,6 +13,7 @@
# limitations under the License.
import math
from functools import lru_cache
from typing import Any
import torch
@@ -342,6 +343,7 @@ class HeliosRotaryPosEmbed(nn.Module):
return freqs.cos(), freqs.sin()
@torch.no_grad()
@lru_cache(maxsize=32)
def _get_spatial_meshgrid(self, height, width, device_str):
device = torch.device(device_str)
grid_y_coords = torch.arange(height, device=device, dtype=torch.float32)

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import math
from math import prod
from typing import Any
@@ -24,7 +25,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, deprecate, logging
from ...utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
@@ -306,7 +307,7 @@ class QwenEmbedRope(nn.Module):
return vid_freqs, txt_freqs
@lru_cache_unless_export(maxsize=128)
@functools.lru_cache(maxsize=128)
def _compute_video_freqs(
self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
) -> torch.Tensor:
@@ -427,7 +428,7 @@ class QwenEmbedLayer3DRope(nn.Module):
return vid_freqs, txt_freqs
@lru_cache_unless_export(maxsize=None)
@functools.lru_cache(maxsize=None)
def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
seq_lens = frame * height * width
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
@@ -449,7 +450,7 @@ class QwenEmbedLayer3DRope(nn.Module):
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
return freqs.clone().contiguous()
@lru_cache_unless_export(maxsize=None)
@functools.lru_cache(maxsize=None)
def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
seq_lens = frame * height * width
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs

View File

@@ -720,7 +720,6 @@ class LDMBertModel(LDMBertPreTrainedModel):
super().__init__(config)
self.model = LDMBertEncoder(config)
self.to_logits = nn.Linear(config.hidden_size, config.vocab_size)
self.post_init()
def forward(
self,

View File

@@ -35,8 +35,6 @@ class PaintByExampleImageEncoder(CLIPPreTrainedModel):
# uncondition for scaling
self.uncond_vector = nn.Parameter(torch.randn((1, 1, self.proj_size)))
self.post_init()
def forward(self, pixel_values, return_uncond_vector=False):
clip_output = self.model(pixel_values=pixel_values)
latent_states = clip_output.pooler_output

View File

@@ -19,16 +19,11 @@ from __future__ import annotations
import functools
import os
from typing import Callable, ParamSpec, TypeVar
from . import logging
from .import_utils import is_torch_available, is_torch_mlu_available, is_torch_npu_available, is_torch_version
T = TypeVar("T")
P = ParamSpec("P")
if is_torch_available():
import torch
from torch.fft import fftn, fftshift, ifftn, ifftshift
@@ -338,21 +333,5 @@ def disable_full_determinism():
torch.use_deterministic_algorithms(False)
@functools.wraps(functools.lru_cache)
def lru_cache_unless_export(maxsize=128, typed=False):
def outer_wrapper(fn: Callable[P, T]):
cached = functools.lru_cache(maxsize=maxsize, typed=typed)(fn)
@functools.wraps(fn)
def inner_wrapper(*args: P.args, **kwargs: P.kwargs):
if torch.compiler.is_exporting():
return fn(*args, **kwargs)
return cached(*args, **kwargs)
return inner_wrapper
return outer_wrapper
if is_torch_available():
torch_device = get_device()