Compare commits

..

2 Commits

Author SHA1 Message Date
Sayak Paul
611034eb74 Update docs/source/en/optimization/attention_backends.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2026-03-18 23:31:40 +05:30
Sayak Paul
052d5e6d5f Update attention_backends.md 2026-03-18 15:43:53 +05:30
5 changed files with 3 additions and 43 deletions

View File

@@ -35,7 +35,7 @@ The [`~ModelMixin.set_attention_backend`] method iterates through all the module
The example below demonstrates how to enable the `_flash_3_hub` implementation for FlashAttention-3 from the [`kernels`](https://github.com/huggingface/kernels) library, which allows you to instantly use optimized compute kernels from the Hub without requiring any setup.
> [!NOTE]
> FlashAttention-3 is not supported for non-Hopper architectures, in which case, use FlashAttention with `set_attention_backend("flash")`.
> FlashAttention-3 requires Ampere GPUs at a minimum.
```py
import torch
@@ -143,7 +143,6 @@ Refer to the table below for a complete list of available attention backends and
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
| `flash_varlen_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention from kernels |
| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
| `flash_4_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-4 |
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |

View File

@@ -229,7 +229,6 @@ class AttentionBackendName(str, Enum):
FLASH_HUB = "flash_hub"
FLASH_VARLEN = "flash_varlen"
FLASH_VARLEN_HUB = "flash_varlen_hub"
FLASH_4_HUB = "flash_4_hub"
_FLASH_3 = "_flash_3"
_FLASH_VARLEN_3 = "_flash_varlen_3"
_FLASH_3_HUB = "_flash_3_hub"
@@ -359,11 +358,6 @@ _HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
function_attr="sageattn",
version=1,
),
AttentionBackendName.FLASH_4_HUB: _HubKernelConfig(
repo_id="kernels-staging/flash-attn4",
function_attr="flash_attn_func",
version=0,
),
}
@@ -527,7 +521,6 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
AttentionBackendName._FLASH_3_HUB,
AttentionBackendName._FLASH_3_VARLEN_HUB,
AttentionBackendName.SAGE_HUB,
AttentionBackendName.FLASH_4_HUB,
]:
if not is_kernels_available():
raise RuntimeError(
@@ -2683,37 +2676,6 @@ def _flash_attention_3_varlen_hub(
return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
    AttentionBackendName.FLASH_4_HUB,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
    supports_context_parallel=False,
)
def _flash_attention_4_hub(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    scale: float | None = None,
    is_causal: bool = False,
    return_lse: bool = False,
    _parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Run attention through the kernel registered for ``FLASH_4_HUB``.

    The kernel callable is looked up in ``_HUB_KERNELS_REGISTRY`` and invoked
    with ``q``/``k``/``v``/``softmax_scale``/``causal`` keyword arguments.

    Args:
        query: Query tensor, forwarded to the kernel as ``q``.
        key: Key tensor, forwarded to the kernel as ``k``.
        value: Value tensor, forwarded to the kernel as ``v``.
        attn_mask: Must be ``None``; this backend rejects explicit masks.
        scale: Forwarded to the kernel as ``softmax_scale``.
        is_causal: Forwarded to the kernel as ``causal``.
        return_lse: When the kernel returns a tuple, selects between returning
            its first two elements (``True``) or only the first (``False``).
        _parallel_config: Accepted for signature parity; not read by this
            function.

    Returns:
        The kernel output tensor, or a 2-tuple of the kernel's first two
        results when ``return_lse`` is ``True`` and the kernel returned a
        tuple (the second element is presumably the log-sum-exp; confirm
        against the kernel).

    Raises:
        ValueError: If ``attn_mask`` is not ``None``.
    """
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for flash-attn 4.")

    kernel_fn = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_4_HUB].kernel_fn
    result = kernel_fn(
        q=query,
        k=key,
        v=value,
        softmax_scale=scale,
        causal=is_causal,
    )

    if not isinstance(result, tuple):
        # NOTE(review): a bare (non-tuple) kernel result is returned as-is even
        # when `return_lse=True`; confirm callers never request the LSE from a
        # kernel build that does not provide it.
        return result

    if return_lse:
        return (result[0], result[1])
    return result[0]
@_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_VARLEN_3,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],

View File

@@ -13,6 +13,7 @@
# limitations under the License.
import math
from functools import lru_cache
from typing import Any
import torch
@@ -342,6 +343,7 @@ class HeliosRotaryPosEmbed(nn.Module):
return freqs.cos(), freqs.sin()
@torch.no_grad()
@lru_cache(maxsize=32)
def _get_spatial_meshgrid(self, height, width, device_str):
device = torch.device(device_str)
grid_y_coords = torch.arange(height, device=device, dtype=torch.float32)

View File

@@ -720,7 +720,6 @@ class LDMBertModel(LDMBertPreTrainedModel):
super().__init__(config)
self.model = LDMBertEncoder(config)
self.to_logits = nn.Linear(config.hidden_size, config.vocab_size)
self.post_init()
def forward(
self,

View File

@@ -35,8 +35,6 @@ class PaintByExampleImageEncoder(CLIPPreTrainedModel):
# uncondition for scaling
self.uncond_vector = nn.Parameter(torch.randn((1, 1, self.proj_size)))
self.post_init()
def forward(self, pixel_values, return_uncond_vector=False):
clip_output = self.model(pixel_values=pixel_values)
latent_states = clip_output.pooler_output