vllm-anthropic/vllm/v1/attention/backends/mla/common.py
2025-10-16 21:40:25 +00:00

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
# MLA Common Components

This file implements common components for MLA implementations.

First we define:

Sq      as Q sequence length
Skv     as KV sequence length

MLA has two possible ways of computing: a data-movement friendly approach and a
compute friendly approach. We generally want to use the compute friendly
approach for "prefill" (i.e. the ratio Sq / Skv is "large", near 1) and the
data-movement friendly approach for "decode" (i.e. the ratio Sq / Skv is
"small", e.g. Sq = 1).

NOTE: what we deem small and large is currently determined by whether the
scheduler labels a request as prefill or decode, but this is something we
should probably tune.
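
As a rough illustration of how that label is applied, the sketch below classifies requests by query length. It follows the `reorder_batch_threshold` comment in `MLACommonMetadataBuilder`: it assumes the batch has already been reordered so decode requests come first, and it treats a request as a decode when its query length is at most the threshold. `split_decode_prefill_sketch` is a hypothetical helper written for illustration. It is not vLLM's `split_decodes_and_prefills`, which also handles the uniform-length requirement.

```python
def split_decode_prefill_sketch(
    query_lens: list[int], decode_threshold: int = 1
) -> tuple[int, int, int, int]:
    # Hypothetical helper: assumes decodes (query_len <= threshold) are
    # already at the front of the batch, followed by prefills.
    num_decodes = 0
    for query_len in query_lens:
        if query_len > decode_threshold:
            break
        num_decodes += 1
    num_prefills = len(query_lens) - num_decodes
    num_decode_tokens = sum(query_lens[:num_decodes])
    num_prefill_tokens = sum(query_lens[num_decodes:])
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens


# Three single-token decodes followed by two prefill chunks.
print(split_decode_prefill_sketch([1, 1, 1, 37, 5]))  # (3, 2, 3, 42)

# With speculative decoding a larger threshold lets 4-token queries be decodes.
print(split_decode_prefill_sketch([4, 4, 37], decode_threshold=4))  # (2, 1, 8, 37)
```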

Main reference: DeepseekV2 paper, and FlashInfer Implementation
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).

Deepseek's MLA attention works the following way:
* Use a single latent vector to represent the per-token entry of the KV cache.
* For decode (i.e. the data-movement friendly approach) the attention
  "simulates" a multi-head attention, while the compute is similar to
  multi-query attention.

Below are examples of both paths, assuming batch size = 1.

## More Extent Definitions:

C           Context length, `Skv - Sq`
H           hidden size
N           number of attention heads
Lq          latent dimension for Q              1536 in DSV3
Lkv         latent dimension for K/V            512 in DSV3
P           nope dimension, no rope.            128 in DSV3
R           rope dimension, goes through rope.  64 in DSV3
V           V head dim.                         128 in DSV3
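
Plugging the DSV3 values above into the per-token KV-cache footprint shows why caching the latent vector pays off. The sketch compares one MLA cache entry (`Lkv + R`, which matches the 576 head size `MLACommonBackend` accepts) with what a conventional MHA cache would hold after up-projection: `N * (P + R)` for K plus `N * V` for V, counted in elements per token per layer. It assumes N = 128 heads, the figure used in the workspace-size comment further down this file. Other configurations may differ.

```python
# DSV3 extents from the table above; N = 128 heads is an assumption
# taken from the workspace-size comment in this file.
N, Lkv, P, R, V = 128, 512, 128, 64, 128

mla_entry = Lkv + R              # one latent vector plus the shared rope key
mha_entry = N * (P + R) + N * V  # per-head K (nope + rope) and per-head V

print(mla_entry)                        # 576
print(mha_entry)                        # 40960
print(round(mha_entry / mla_entry, 1))  # 71.1
```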

## Vector/Matrix Definitions

h_t         hidden states (input to attention)  shape [Sq, H]
q_c         latent/compressed Q                 shape [Sq, Lq]
q_nope      uncompressed Q (no-rope)            shape [Sq, N, P]
q_pe        uncompressed Q (rope)               shape [Sq, N, R]
kv_c        latent/compressed KV                shape [Skv, Lkv]
k_pe        decoupled k position embeddings     shape [Skv, R]
new_kv_c    new kv_c from current iter          shape [Sq, Lkv]
new_k_pe    new k_pe from current iter          shape [Sq, R]
cache_kv_c  cached kv_c from previous iters     shape [C, Lkv]
cache_k_pe  cached k_pe from previous iters     shape [C, R]
W_DQ        project h_t to q_c                  shape [H, Lq]
W_UQ        project q_c to q_nope               shape [Lq, N * P]
W_QR        project q_c to q_pe                 shape [Lq, N * R]
W_DKV       project h_t to kv_c                 shape [H, Lkv]
W_UK        project kv_c to k_nope              shape [Lkv, N, P]
W_KR        project h_t to k_pe                 shape [H, R]
W_UV        project kv_c to v                   shape [Lkv, N, V]
W_O         project v to h_t                    shape [N * V, H]

## Compute Friendly Approach (i.e. "_forward_prefill"):

q_c      = h_t @ W_DQ
q_nope   = (q_c @ W_UQ).view(Sq, N, P)
q_pe     = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
kv_c     = torch.cat([cache_kv_c, new_kv_c], dim=0)
k_pe     = torch.cat([cache_k_pe, new_k_pe], dim=0)
k_nope   = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P)
v        = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V)

// MHA with QK headdim = P + R
//           V headdim = V
//      sdpa_o shape [Sq, N, V]
sdpa_o = scaled_dot_product_attention(
    torch.cat([q_nope, q_pe], dim=-1),
    torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
    v
)
return sdpa_o.view(-1, N * V) @ W_O

NOTE: in the actual code,
    `kv_b_proj` is [W_UK; W_UV] concatenated per head
    `q_b_proj` is [W_UQ; W_QR] concatenated per head
    `out_proj` is W_O
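
To make "concatenated per head" concrete, the sketch below builds a fused `kv_b_proj`-style weight from toy W_UK / W_UV tensors. It then checks that one fused projection followed by a per-head split reproduces `k_nope` and `v` from the pseudocode. The sketch uses NumPy, toy sizes, and the docstring's [in, out] weight convention. The real layer is a `ColumnParallelLinear` whose stored weight may be transposed or quantized, so this illustrates the layout only, not vLLM's weight-loading code.

```python
import numpy as np

rng = np.random.default_rng(0)
Skv, N, Lkv, P, V = 5, 4, 16, 8, 6  # toy extents

kv_c = rng.standard_normal((Skv, Lkv))
W_UK = rng.standard_normal((Lkv, N, P))
W_UV = rng.standard_normal((Lkv, N, V))

# Fused weight: for each head, the P columns of W_UK followed by the
# V columns of W_UV, flattened to [Lkv, N * (P + V)].
kv_b_proj = np.concatenate([W_UK, W_UV], axis=-1).reshape(Lkv, N * (P + V))

# One GEMM, then split each head's output into its nope-key and value parts.
kv = (kv_c @ kv_b_proj).reshape(Skv, N, P + V)
k_nope_fused, v_fused = kv[..., :P], kv[..., P:]

# Reference: the separate projections from the pseudocode above.
k_nope = (kv_c @ W_UK.reshape(Lkv, N * P)).reshape(Skv, N, P)
v = (kv_c @ W_UV.reshape(Lkv, N * V)).reshape(Skv, N, V)

print(np.allclose(k_nope_fused, k_nope), np.allclose(v_fused, v))  # True True
```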

## Data-Movement Friendly Approach (i.e. "_forward_decode"):

Runtime
q_c      = h_t @ W_DQ
q_nope   = (q_c @ W_UQ).view(-1, N, P)
ql_nope  = einsum("snh,lnh->snl", q_nope, W_UK)
q_pe     = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
kv_c     = torch.cat([cache_kv_c, new_kv_c], dim=0)
k_pe     = torch.cat([cache_k_pe, new_k_pe], dim=0)

// MQA with QK headdim = Lkv + R
//           V headdim = Lkv
//      sdpa_o shape [Sq, N, Lkv]
// NOTE: this is less compute-friendly since Lkv > P
//       but is more data-movement friendly since it's MQA vs MHA
// NOTE: the softmax scale must stay the same as in the compute friendly
//       path (based on P + R), not 1 / sqrt(Lkv + R)
sdpa_o = scaled_dot_product_attention(
    torch.cat([ql_nope, q_pe], dim=-1),
    torch.cat([kv_c, k_pe], dim=-1),
    kv_c
)

o = einsum("snl,lnv->snv", sdpa_o.reshape(-1, N, Lkv), W_UV)
return o.view(-1, N * V) @ W_O
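
The two paths compute the same attention output. The data-movement friendly one only re-associates the matrix products, folding W_UK into the query and W_UV into the output. The NumPy sketch below checks this numerically with toy sizes and random inputs. Its assumptions:

* plain float64 math with no quantization;
* RoPE skipped by drawing q_pe and k_pe at random, which is safe because rope is applied identically in both paths;
* the same softmax scale, 1 / sqrt(P + R), in both paths.

```python
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    x = x - x.max(axis=-1, keepdims=True)  # numerically stable softmax
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def compute_friendly(q_nope, q_pe, kv_c, k_pe, W_UK, W_UV, scale):
    Skv, N = kv_c.shape[0], W_UK.shape[1]
    k_nope = np.einsum("tl,lnp->tnp", kv_c, W_UK)            # [Skv, N, P]
    v = np.einsum("tl,lnv->tnv", kv_c, W_UV)                 # [Skv, N, V]
    q = np.concatenate([q_nope, q_pe], axis=-1)              # [Sq, N, P + R]
    k_pe_heads = np.broadcast_to(k_pe[:, None, :], (Skv, N, k_pe.shape[-1]))
    k = np.concatenate([k_nope, k_pe_heads], axis=-1)        # [Skv, N, P + R]
    p = softmax(np.einsum("snd,tnd->nst", q, k) * scale)     # MHA scores
    return np.einsum("nst,tnv->snv", p, v)                   # [Sq, N, V]


def data_movement_friendly(q_nope, q_pe, kv_c, k_pe, W_UK, W_UV, scale):
    ql_nope = np.einsum("snp,lnp->snl", q_nope, W_UK)        # absorb W_UK
    q = np.concatenate([ql_nope, q_pe], axis=-1)             # [Sq, N, Lkv + R]
    k = np.concatenate([kv_c, k_pe], axis=-1)                # [Skv, Lkv + R], one KV head
    p = softmax(np.einsum("snd,td->nst", q, k) * scale)      # MQA scores
    o_latent = np.einsum("nst,tl->snl", p, kv_c)             # [Sq, N, Lkv]
    return np.einsum("snl,lnv->snv", o_latent, W_UV)         # absorb W_UV


rng = np.random.default_rng(0)
Sq, Skv, N, Lkv, P, R, V = 1, 7, 4, 16, 8, 4, 8  # toy extents, one decode token
args = (
    rng.standard_normal((Sq, N, P)),   # q_nope
    rng.standard_normal((Sq, N, R)),   # q_pe (rope omitted)
    rng.standard_normal((Skv, Lkv)),   # kv_c
    rng.standard_normal((Skv, R)),     # k_pe (rope omitted)
    rng.standard_normal((Lkv, N, P)),  # W_UK
    rng.standard_normal((Lkv, N, V)),  # W_UV
    1.0 / np.sqrt(P + R),              # same scale for both paths
)
out_a = compute_friendly(*args)
out_b = data_movement_friendly(*args)
print(out_a.shape, np.allclose(out_a, out_b))  # (1, 4, 8) True
```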

## Chunked Prefill

For chunked prefill we want to use the compute friendly algorithm. We are
assuming a sufficiently large Sq / Skv ratio; in the future we may want to
switch to the data-movement friendly approach if the chunk (i.e. `Sq`) is small.

However, the compute-friendly approach can potentially run out of memory if Skv
is large, because it materializes per-head keys (and likewise values) for every
context token: `k_nope = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P)`
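
A back-of-the-envelope calculation shows the scale of the problem. The sketch below compares the bytes needed to hold the latent context with the bytes needed for its up-projected K and V, first for all C context tokens at once and then for one chunk. It assumes fp16 (2 bytes per element), the DSV3 extents above with N = 128 heads, and two hypothetical sizes: C = 128k context tokens and a 64k-token chunk. The 64k figure is the cap used by `determine_chunked_prefill_workspace_size`.

```python
BYTES = 2  # fp16
N, Lkv, P, R, V = 128, 512, 128, 64, 128  # DSV3 extents, N assumed to be 128


def latent_bytes(tokens: int) -> int:
    # kv_c plus k_pe, as stored in the cache / workspace
    return BYTES * tokens * (Lkv + R)


def up_projected_bytes(tokens: int) -> int:
    # k_nope plus expanded k_pe (P + R per head) and v (V per head)
    return BYTES * tokens * N * (P + R + V)


GiB = 1024**3
C = 128 * 1024   # hypothetical context length
MCC = 64 * 1024  # hypothetical chunk size

print(latent_bytes(C) / GiB)          # 0.140625
print(up_projected_bytes(C) / GiB)    # 10.0
print(up_projected_bytes(MCC) / GiB)  # 5.0
```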

To mitigate this, we chunk the computation of attention with respect to the
current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can use a
fixed workspace size.
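
In the metadata builder, the chunk size is derived from that fixed workspace. The workspace is split evenly across the prefills that have context and, on CUDA (`aot_schedule`), rounded down to a multiple of the KV-cache page size. Each request's context is then cut into chunks of that size. The pure-Python sketch below mirrors that arithmetic from `MLACommonMetadataBuilder.build` on plain lists and reproduces the `[[0, 0, 0, 0], [256, ...], [512, ...]]` example from the code comments. `plan_context_chunks` is a hypothetical name. The sketch ignores DCP and the device-side tensors the real code builds.

```python
def cdiv(a: int, b: int) -> int:
    return -(-a // b)  # ceiling division


def plan_context_chunks(
    context_lens: list[int], workspace_tokens: int, page_size: int
) -> tuple[int, list[list[int]], list[list[int]], list[list[int]]]:
    num_with_context = sum(1 for c in context_lens if c > 0)
    # Equal share of the workspace per prefill with context, page aligned.
    mcc = (workspace_tokens // num_with_context) // page_size * page_size
    assert mcc > 0
    num_chunks = cdiv(max(context_lens), mcc)
    starts = [[i * mcc] * len(context_lens) for i in range(num_chunks)]
    # Tokens of each request's context in each chunk (0 once exhausted).
    seq_lens = [
        [max(0, min(c, s + mcc) - s) for s, c in zip(row, context_lens)]
        for row in starts
    ]
    # Per-chunk prefix sums: where each request's slice starts in the workspace.
    cu_seq_lens = []
    for row in seq_lens:
        cu = [0]
        for n in row:
            cu.append(cu[-1] + n)
        cu_seq_lens.append(cu)
    return mcc, starts, seq_lens, cu_seq_lens


mcc, starts, seq_lens, cu = plan_context_chunks(
    context_lens=[700, 100, 300, 513], workspace_tokens=1024, page_size=16
)
print(mcc)       # 256
print(starts)    # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
print(seq_lens)  # [[256, 100, 256, 256], [256, 0, 44, 256], [188, 0, 0, 1]]
print(cu[0])     # [0, 256, 356, 612, 868]
```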

The chunked prefill approach is as follows:

MCC     Max chunk of context to process per iter, computed dynamically,
        used to bound the memory usage

q_c        = h_t @ W_DQ
q_nope     = (q_c @ W_UQ).view(Sq, N, P)
q_pe       = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c   = h_t @ W_DKV
new_k_pe   = RoPE(h_t @ W_KR)
new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P)
new_v      = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V)

// MHA between queries and new KV
//     with QK headdim = P + R
//           V headdim = V
//    curr_o   shape [Sq, N, V]
//    curr_lse shape [N, Sq], this is just the order FA returns it in
curr_o, curr_lse = scaled_dot_product_attention(
    torch.cat([q_nope, q_pe], dim=-1),
    torch.cat([new_k_nope, new_k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
    new_v,
    causal=True,
    return_softmax_lse=True
)

// Compute attention with the already existing context
for chunk_idx in range(cdiv(C, MCC)):
    chunk_start  = chunk_idx * MCC
    chunk_end    = min(chunk_start + MCC, C)
    Sc           = chunk_end - chunk_start
    cache_kv_c_chunk   = cache_kv_c[chunk_start:chunk_end]
    cache_k_pe_chunk   = cache_k_pe[chunk_start:chunk_end]
    cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK.view(Lkv, N * P)).view(Sc, N, P)
    cache_v_chunk      = (cache_kv_c_chunk @ W_UV.view(Lkv, N * V)).view(Sc, N, V)

    chunk_o, chunk_lse = scaled_dot_product_attention(
        torch.cat([q_nope, q_pe], dim=-1),
        torch.cat([cache_k_nope_chunk,
                   cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)],
                  dim=-1),
        cache_v_chunk,
        causal=False,
        return_softmax_lse=True
    )

    curr_o, curr_lse = merge_attn_states(
        suffix_output=curr_o,
        suffix_lse=curr_lse,
        prefix_output=chunk_o,
        prefix_lse=chunk_lse,
    )

return curr_o.view(-1, N * V) @ W_O
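
The correctness of this loop rests on `merge_attn_states` combining partial softmax results through their log-sum-exp (LSE) values. The NumPy sketch below checks that idea on plain per-head q/k/v, as the loop sees them after up-projection. It runs causal attention over the new tokens, merges in non-causal attention over each context chunk, and compares the result with a single causal pass over context plus new tokens. `attn_with_lse` and `merge` are hypothetical helpers written for this check, not vLLM's kernels. The LSE uses the [N, Sq] layout noted above.

```python
import numpy as np


def attn_with_lse(q, k, v, scale, mask=None):
    # q [Sq, N, D], k [Sk, N, D], v [Sk, N, Dv]; mask [Sq, Sk], True = blocked
    s = np.einsum("snd,tnd->nst", q, k) * scale           # [N, Sq, Sk]
    if mask is not None:
        s = np.where(mask[None], -np.inf, s)
    m = s.max(axis=-1, keepdims=True)
    e = np.exp(s - m)
    denom = e.sum(axis=-1, keepdims=True)
    lse = (m + np.log(denom))[..., 0]                     # [N, Sq]
    return np.einsum("nst,tnv->snv", e / denom, v), lse   # o [Sq, N, Dv]


def merge(prefix_o, prefix_lse, suffix_o, suffix_lse):
    lse = np.logaddexp(prefix_lse, suffix_lse)            # [N, Sq]
    w_prefix = np.exp(prefix_lse - lse).T[..., None]      # [Sq, N, 1]
    w_suffix = np.exp(suffix_lse - lse).T[..., None]
    return w_prefix * prefix_o + w_suffix * suffix_o, lse


rng = np.random.default_rng(0)
C, Sq, N, D, Dv, MCC = 10, 3, 2, 8, 4, 4  # toy sizes; 3 context chunks
scale = 1.0 / np.sqrt(D)
q = rng.standard_normal((Sq, N, D))
k_ctx, v_ctx = rng.standard_normal((C, N, D)), rng.standard_normal((C, N, Dv))
k_new, v_new = rng.standard_normal((Sq, N, D)), rng.standard_normal((Sq, N, Dv))

# Reference: new token i sits at position C + i and sees everything up to it.
full_mask = np.triu(np.ones((Sq, C + Sq), dtype=bool), k=C + 1)
ref_o, _ = attn_with_lse(q, np.concatenate([k_ctx, k_new]),
                         np.concatenate([v_ctx, v_new]), scale, full_mask)

# Chunked: causal over the new tokens first, then fold in context chunks.
curr_o, curr_lse = attn_with_lse(q, k_new, v_new, scale,
                                 np.triu(np.ones((Sq, Sq), dtype=bool), k=1))
for start in range(0, C, MCC):
    end = min(start + MCC, C)
    chunk_o, chunk_lse = attn_with_lse(q, k_ctx[start:end], v_ctx[start:end], scale)
    curr_o, curr_lse = merge(chunk_o, chunk_lse, curr_o, curr_lse)

print(np.allclose(curr_o, ref_o))  # True
```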
"""

import functools
from abc import abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import ClassVar, Generic, TypeVar

import torch
from tqdm import tqdm

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (
    AttentionBackend,
    AttentionLayer,
    AttentionMetadata,
    MLAAttentionImpl,
)
from vllm.attention.backends.utils import get_mla_dims
from vllm.attention.ops.common import cp_lse_ag_out_rs
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
    vllm_is_batch_invariant,
)
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    LinearBase,
    UnquantizedLinearMethod,
)
from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down
from vllm.utils.flashinfer import has_nvidia_artifactory
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
    get_per_layer_parameters,
    infer_global_hyperparameters,
    split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec


class QueryLenSupport(Enum):
    """Defines the level of query length support for an attention backend's
    decode pipeline.

    - SINGLE_ONLY: Decode pipeline only supports single-token queries
                   (query_len=1)
    - UNIFORM: Decode pipeline supports uniform multi-token queries
               (all requests must have same query_len > 1)
    - VARLEN: Decode pipeline supports variable-length queries
              (mixed query lengths in same batch)
    """

    SINGLE_ONLY = "single_only"
    UNIFORM = "uniform"
    VARLEN = "varlen"


try:
    from vllm.vllm_flash_attn import flash_attn_varlen_func

    is_vllm_fa = True
except ImportError:
    # For rocm use upstream flash attention
    if current_platform.is_rocm():
        from flash_attn import flash_attn_varlen_func
    is_vllm_fa = False

try:
    from flashinfer import BatchPrefillWithRaggedKVCacheWrapper
    from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache  # noqa: F401

    flashinfer_available = True
except ImportError:
    BatchPrefillWithRaggedKVCacheWrapper = object

    flashinfer_available = False


def is_rocm_aiter_fp8bmm_enabled() -> bool:
    return (
        current_platform.is_rocm()
        and envs.VLLM_ROCM_USE_AITER_FP8BMM
        and envs.VLLM_ROCM_USE_AITER
    )


if is_rocm_aiter_fp8bmm_enabled():
    from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import (  # noqa: E501
        batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm,  # noqa: E501
    )

    def dynamic_per_batched_tensor_quant(
        x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
    ):
        DTYPE_MAX = torch.finfo(dtype).max
        min_val, max_val = x.aminmax()
        amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
        scale = DTYPE_MAX / amax
        x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
        return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()


logger = init_logger(__name__)

CUDNN_WORKSPACE_SIZE = 12800


class MLACommonBackend(AttentionBackend):
    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        return "TRITON_MLA"

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        return MLACommonMetadata

    @staticmethod
    def get_builder_cls() -> type["MLACommonMetadataBuilder"]:
        return MLACommonMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        return (num_blocks, block_size, head_size)

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [576]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        supported_head_sizes = cls.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            attn_type = cls.__name__.removesuffix("Backend")
            raise ValueError(
                f"Head size {head_size} is not supported by {attn_type}. "
                f"Supported head sizes are: {supported_head_sizes}. "
                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
                "FlexAttention backend which supports all head sizes."
            )


@dataclass
class MLACommonPrefillMetadata:
    """Prefill Specific Metadata"""

    @dataclass
    class ChunkedContextMetadata:
        # New for MLA (compared to FlashAttention)
        # For handling chunked prefill
        cu_seq_lens: torch.Tensor
        starts: torch.Tensor
        seq_tot: list[int]
        max_seq_lens: list[int]
        seq_lens: torch.Tensor
        workspace: torch.Tensor

        # for mla DCP
        cp_chunk_seq_lens: list[list[int]] | None = None
        origin_context_lens: list[int] | None = None
        cp_cu_seq_lens: torch.Tensor | None = None
        chunk_size: int | None = None
        cu_seq_lens_lst: list[list[int]] | None = None

    block_table: torch.Tensor
    query_start_loc: torch.Tensor
    max_query_len: int
    chunked_context: ChunkedContextMetadata | None = None


@dataclass
class FlashInferPrefillMetadata(MLACommonPrefillMetadata):
    prefill_main: BatchPrefillWithRaggedKVCacheWrapper | None = None
    prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper] = field(
        default_factory=list
    )


@dataclass
class CudnnPrefillMetadata(MLACommonPrefillMetadata):
    class ChunkedContextMetadata(MLACommonPrefillMetadata.ChunkedContextMetadata):
        seq_lens: torch.Tensor

    query_seq_lens: torch.Tensor | None = None
    cudnn_workspace: torch.Tensor | None = None


@dataclass
class MLACommonDecodeMetadata:
    block_table: torch.Tensor
    seq_lens: torch.Tensor
    dcp_tot_seq_lens: torch.Tensor | None


D = TypeVar("D", bound=MLACommonDecodeMetadata)


@dataclass
class MLACommonMetadata(Generic[D]):
    """Metadata for MLACommon.

    NOTE: Please read the comment at the top of the file before trying to
    understand this class
    """

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_reqs: int
    max_query_len: int
    max_seq_len: int

    num_actual_tokens: int  # Number of tokens excluding padding.
    query_start_loc: torch.Tensor
    slot_mapping: torch.Tensor

    # New for MLA (compared to FlashAttention)
    # For handling prefill decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int

    # The dimension of the attention heads
    head_dim: int | None = None

    decode: D | None = None
    prefill: (
        MLACommonPrefillMetadata
        | FlashInferPrefillMetadata
        | CudnnPrefillMetadata
        | None
    ) = None

    def __post_init__(self):
        if self.head_dim is not None:
            MLACommonBackend.validate_head_size(self.head_dim)


M = TypeVar("M", bound=MLACommonMetadata)
A = TypeVar("A")


def use_flashinfer_prefill() -> bool:
    # For blackwell default to flashinfer prefill if it's available since
    # it is faster than FA2.
    return (
        not envs.VLLM_DISABLE_FLASHINFER_PREFILL
        and flashinfer_available
        and not envs.VLLM_USE_CUDNN_PREFILL
        and current_platform.is_device_capability(100)
    )


def use_cudnn_prefill() -> bool:
    return (
        flashinfer_available
        and envs.VLLM_USE_CUDNN_PREFILL
        and current_platform.is_device_capability(100)
        and has_nvidia_artifactory()
    )


# Currently 394MB, this can be tuned based on GEMM sizes used.
# Chosen to be the same as sglang:
#  https://github.com/sgl-project/sglang/blob/766392c6bda2558b61ce6d1c1bfd8081a549e1f1/python/sglang/global_config.py#L37
FLASHINFER_WORKSPACE_BUFFER_SIZE = 394 * 1024 * 1024


class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
    """
    NOTE: Please read the comment at the top of the file before trying to
    understand this class
    """

    # Defines the level of query length support for this backend.
    # - SINGLE_ONLY: Only single-token queries (no spec decode support)
    # - UNIFORM: Supports uniform multi-token queries (spec decode with uniform lengths)
    # - VARLEN: Supports variable-length queries (spec decode with mixed lengths)
    # If set to UNIFORM or VARLEN, this will increase `reorder_batch_threshold` when
    # speculative decoding is enabled.
    query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.SINGLE_ONLY

    # The threshold for reordering the batch into decode and prefill requests.
    # If > 1, the batch will be reordered such that requests with
    # query length <= threshold are classified as decode requests.
    # Use `query_len_support` (above) to set this automatically
    # when speculative decoding is enabled.
    reorder_batch_threshold: int = 1

    @staticmethod
    def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int:
        scheduler_config = vllm_config.scheduler_config
        cache_config = vllm_config.cache_config
        model_config = vllm_config.model_config

        chunked_prefill_workspace_size = min(
            # Try for 8 full-length requests or at least 4 pages per request
            max(
                8 * model_config.max_model_len,
                4 * scheduler_config.max_num_seqs * cache_config.block_size,
            ),
            # For long-context models try not to over-allocate (which would
            # limit kv-cache space) by capping the workspace at 64k tokens,
            # which would result in the workspace being:
            #   2*(576)*(64*1024) = 72mb
            #   (assuming 576 MLA head dim, and fp16)
            # which would result in up-projected context being
            #   2*(192*128)*(64*1024) = 3gb
            #   (assuming 192 QK head dim, 128 heads, and fp16)
            64 * 1024,
        )

        # Enforce that we have enough for at least 1 page per request
        chunked_prefill_workspace_size = max(
            chunked_prefill_workspace_size,
            scheduler_config.max_num_seqs * cache_config.block_size,
        )

        return chunked_prefill_workspace_size

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
        metadata_cls: type[M] | None = None,
    ):
        self.metadata_cls = (
            metadata_cls if metadata_cls is not None else MLACommonMetadata
        )
        self.kv_cache_spec = kv_cache_spec
        scheduler_config = vllm_config.scheduler_config
        self.model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        self.compilation_config = vllm_config.compilation_config
        self.vllm_config = vllm_config
        self.device = device

        self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
        self.mla_dims = get_mla_dims(self.model_config)
        self.aot_schedule = current_platform.is_cuda()
        try:
            self.dcp_world_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0

        # Don't try to access the runner on AMD
        if self.aot_schedule:
            self.page_size = self.kv_cache_spec.block_size

        self.chunked_prefill_workspace_size = (
            self.determine_chunked_prefill_workspace_size(vllm_config)
        )

        if self.dcp_world_size > 1:
            # Note(hc): When DCP is enabled the local kvcache is incomplete, so
            # an additional kvcache allgather across the DCP group is required;
            # the workspace is therefore enlarged by 1/DCP relative to the
            # original TP allocation.
            assert self.chunked_prefill_workspace_size % self.dcp_world_size == 0
            self.chunked_prefill_workspace = torch.empty(
                (
                    self.chunked_prefill_workspace_size
                    + self.chunked_prefill_workspace_size // self.dcp_world_size,
                    self.model_config.get_head_size(),
                ),
                dtype=self.model_config.dtype,
                device=device,
            )
        else:
            self.chunked_prefill_workspace = torch.empty(
                (
                    self.chunked_prefill_workspace_size,
                    self.model_config.get_head_size(),
                ),
                dtype=self.model_config.dtype,
                device=device,
            )

        self._use_cudnn_prefill = use_cudnn_prefill()
        self._use_fi_prefill = use_flashinfer_prefill()
        self.prefill_metadata_cls = (
            FlashInferPrefillMetadata
            if self._use_fi_prefill
            else CudnnPrefillMetadata
            if self._use_cudnn_prefill
            else MLACommonPrefillMetadata
        )

        if self._use_fi_prefill:
            self._workspace_buffer = torch.empty(
                FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=device
            )

            self._fi_prefill_main: BatchPrefillWithRaggedKVCacheWrapper | None = None
            self._fi_prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper] = []

            self._global_hyperparameters = infer_global_hyperparameters(
                get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl)
            )

        if self._use_cudnn_prefill:
            self.cudnn_workspace = torch.empty(
                CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs,
                dtype=torch.int8,
                device=device,
            )

        supports_spec_decode = self.query_len_support != QueryLenSupport.SINGLE_ONLY
        self._init_reorder_batch_threshold(
            self.reorder_batch_threshold, supports_spec_decode
        )

        # Validate consistency between query_len_support and reorder_batch_threshold
        if self.query_len_support == QueryLenSupport.SINGLE_ONLY:
            assert self.reorder_batch_threshold == 1, (
                f"reorder_batch_threshold must be 1 when query_len_support is "
                f"SINGLE_ONLY, got {self.reorder_batch_threshold}"
            )

    def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
        qo_indptr = prefill.query_start_loc

        has_context = False
        if prefill.chunked_context is not None:
            chunked_context = prefill.chunked_context
            has_context = True

        if self._fi_prefill_main is None:
            self._fi_prefill_main = BatchPrefillWithRaggedKVCacheWrapper(
                self._workspace_buffer, "NHD", backend="cutlass"
            )

        if has_context:
            num_chunks = chunked_context.cu_seq_lens.shape[0]
            # Allocate more prefill chunk wrappers if needed
            if len(self._fi_prefill_chunks) < num_chunks:
                for _ in range(len(self._fi_prefill_chunks), num_chunks):
                    self._fi_prefill_chunks.append(
                        BatchPrefillWithRaggedKVCacheWrapper(
                            self._workspace_buffer, "NHD", backend="cutlass"
                        )
                    )
            assert num_chunks <= len(self._fi_prefill_chunks)

        # In MLA, the non-latent num_qo_heads == num_kv_heads
        num_qo_heads = self.num_heads
        num_kv_heads = num_qo_heads

        # Sanity: Verify that num_kv_heads == 1 since it is latent space
        assert self.kv_cache_spec.num_kv_heads == 1

        # Get non-latent head_dim_qk and head_dim_vo
        head_dim_qk = self.mla_dims.qk_nope_head_dim + self.mla_dims.qk_rope_head_dim
        head_dim_vo = self.mla_dims.v_head_dim

        # For main run, qo_indptr == kv_indptr
        kv_indptr = qo_indptr.clone()

        # Prepare main prefill
        self._fi_prefill_main.plan(
            qo_indptr=qo_indptr,
            kv_indptr=kv_indptr,
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_dim_qk=head_dim_qk,
            head_dim_vo=head_dim_vo,
            causal=True,  # This is main run
            sm_scale=self._global_hyperparameters.sm_scale,
            window_left=self._global_hyperparameters.window_left,
            logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
            q_data_type=self.model_config.dtype,
        )

        # Prepare context prefills
        if has_context:
            for i in range(num_chunks):
                kv_indptr_chunk = chunked_context.cu_seq_lens[i]

                self._fi_prefill_chunks[i].plan(
                    qo_indptr=qo_indptr,
                    kv_indptr=kv_indptr_chunk,
                    num_qo_heads=num_qo_heads,
                    num_kv_heads=num_kv_heads,
                    head_dim_qk=head_dim_qk,
                    head_dim_vo=head_dim_vo,
                    causal=False,  # This is context run
                    sm_scale=self._global_hyperparameters.sm_scale,
                    window_left=self._global_hyperparameters.window_left,
                    logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
                    q_data_type=self.model_config.dtype,
                )

        prefill.prefill_main = self._fi_prefill_main
        prefill.prefill_chunks = self._fi_prefill_chunks

    def _build_decode(
        self,
        block_table_tensor: torch.Tensor,
        seq_lens_cpu: torch.Tensor,
        seq_lens_device: torch.Tensor,
        query_start_loc_cpu: torch.Tensor,
        query_start_loc_device: torch.Tensor,
        num_decode_tokens: int,
        dcp_tot_seq_lens_device: torch.Tensor | None,
    ) -> MLACommonDecodeMetadata:
        return MLACommonDecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens_device,
            dcp_tot_seq_lens=dcp_tot_seq_lens_device,
        )

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
    ) -> M:
        """
        This method builds the metadata for full cudagraph capture.
        Currently, only decode is supported for full cudagraphs with MLA.
        """
        m = common_attn_metadata
        assert m.num_reqs <= (m.num_actual_tokens * self.reorder_batch_threshold), (
            "MLA only supports decode-only full CUDAGraph capture. "
            "Make sure all cudagraph capture sizes <= max_num_seq."
        )

        assert m.max_query_len <= self.reorder_batch_threshold  # decode only

        return self.build(0, m)

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> M:
        num_reqs = common_attn_metadata.num_reqs
        num_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len
        max_seq_len = common_attn_metadata.max_seq_len

        # Note(simon): be careful about the CPU <> GPU memory movement in this
        # function. We should avoid GPU -> CPU sync as much as possible because
        # it blocks on all previous kernels.
        device = self.device
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping

        query_start_loc = common_attn_metadata.query_start_loc
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        seq_lens = common_attn_metadata.seq_lens
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
        dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens

        query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

        num_computed_tokens_cpu = common_attn_metadata.seq_lens_cpu - query_seq_lens_cpu

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold,
                require_uniform=(self.query_len_support != QueryLenSupport.VARLEN),
            )
        )

        # Note(hc): update seq_lens of decode reqs under DCP.
        if self.dcp_world_size > 1:
            assert dcp_local_seq_lens is not None
            dcp_local_seq_lens[:num_decodes] = seq_lens[
                :num_decodes
            ] // self.dcp_world_size + (
                self.dcp_rank <= (seq_lens[:num_decodes] - 1) % self.dcp_world_size
            )

        assert num_decodes + num_prefills == num_reqs
        assert num_decode_tokens + num_prefill_tokens == num_tokens

        prefill_metadata = None
        if num_prefills > 0:
            reqs_start = num_decodes  # prefill_start

            context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
            # Note(hc): The context lengths from the perspective of dcp rank0.
            cp_context_lens_cpu = torch.ceil(
                context_lens_cpu.float() / self.dcp_world_size
            ).int()
            origin_context_lens = context_lens_cpu.tolist()
            max_context_len_cpu = context_lens_cpu.max().item()
            num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
            prefill_query_start_loc = (
                query_start_loc[reqs_start:] - query_start_loc[reqs_start]
            )

            chunked_context_metadata = None
            if max_context_len_cpu > 0:
                # NOTE: it is recommended you read the `Chunked Prefill` section
                # in the comment at the top of the file before trying to
                # understand the following code

                # currently we allocate an equal amount of workspace for each
                # prefill in the batch, we could probably use a more advanced
                # algorithm here and allocate more workspace to prefills with
                # longer context lengths
                max_context_chunk = (
                    self.chunked_prefill_workspace_size // num_prefills_with_context_cpu
                )

                if self.aot_schedule:
                    # align max_context_chunk to page_size by rounding down,
                    # currently the `gather_and_maybe_dequant_cache` kernel
                    # cannot handle `context_chunk_starts` that are not aligned
                    # to page_size
                    max_context_chunk = round_down(max_context_chunk, self.page_size)

                assert max_context_chunk > 0
                num_chunks = cdiv(max_context_len_cpu, max_context_chunk)

                # if `max_context_chunk = 256`, `num_chunks = 3`, and
                #   `num_prefills_with_context = 4`, create a tensor that looks
                # like
                #  [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
                # Note(simon): this is done on the CPU because downstream code
                # calls `tolist` on these tensors.
                chunk_starts = (
                    torch.arange(num_chunks, dtype=torch.int32)
                    .unsqueeze(1)
                    .expand(-1, num_prefills)
                    * max_context_chunk
                )
                chunk_ends = torch.min(
                    context_lens_cpu.unsqueeze(0), chunk_starts + max_context_chunk
                )
                chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)

                cu_seq_lens_cpu = torch.zeros(
                    num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True
                )
                torch.cumsum(
                    chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32
                )

                if self.dcp_world_size > 1:
                    # Note(hc): The above max_context_chunk already enforces
                    # block_size alignment; DCP only needs block_size to be
                    # divisible by dcp_world_size, because DCP uses
                    # cp_gather_cache, which does not require `cp_chunk_starts`
                    # to be aligned to page_size.
                    assert max_context_chunk % self.dcp_world_size == 0
                    cp_max_context_chunk = max_context_chunk // self.dcp_world_size
                    cp_chunk_starts = (
                        torch.arange(num_chunks, dtype=torch.int32)
                        .unsqueeze(1)
                        .expand(-1, num_prefills)
                        * cp_max_context_chunk
                    )
                    cp_chunk_ends = torch.min(
                        cp_context_lens_cpu.unsqueeze(0),
                        cp_chunk_starts + cp_max_context_chunk,
                    )
                    cp_chunk_seq_lens = (cp_chunk_ends - cp_chunk_starts).clamp(min=0)

                    cp_cu_seq_lens_cpu = torch.zeros(
                        num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True
                    )
                    torch.cumsum(
                        cp_chunk_seq_lens,
                        dim=1,
                        out=cp_cu_seq_lens_cpu[:, 1:],
                        dtype=torch.int32,
                    )

                chunked_context_metadata_cls = (
                    CudnnPrefillMetadata.ChunkedContextMetadata
                    if self._use_cudnn_prefill
                    else MLACommonPrefillMetadata.ChunkedContextMetadata
                )
                if self.dcp_world_size > 1:
                    chunked_context_metadata = chunked_context_metadata_cls(
                        cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
                        starts=cp_chunk_starts.to(device, non_blocking=True),
                        seq_tot=cp_chunk_seq_lens.sum(dim=1).tolist(),
                        max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                        seq_lens=chunk_seq_lens,
                        workspace=self.chunked_prefill_workspace,
                        cp_chunk_seq_lens=cp_chunk_seq_lens.tolist(),
                        origin_context_lens=origin_context_lens,
                        cp_cu_seq_lens=cp_cu_seq_lens_cpu.to(device, non_blocking=True),
                        chunk_size=max_context_chunk,
                        cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
                    )
                else:
                    chunked_context_metadata = chunked_context_metadata_cls(
                        cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
                        starts=chunk_starts.to(device, non_blocking=True),
                        seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                        max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                        seq_lens=chunk_seq_lens,
                        workspace=self.chunked_prefill_workspace,
                    )

                if self._use_cudnn_prefill:
                    chunked_context_metadata.seq_lens = chunk_seq_lens

                assert (
                    max(chunked_context_metadata.max_seq_lens)
                    <= self.chunked_prefill_workspace_size
                )

            prefill_metadata = self.prefill_metadata_cls(
                block_table=block_table_tensor[reqs_start:, ...],
                query_start_loc=prefill_query_start_loc,
                max_query_len=max_query_len,
                chunked_context=chunked_context_metadata,
            )

            if self._use_cudnn_prefill:
                assert isinstance(prefill_metadata, CudnnPrefillMetadata)
                prefill_metadata.query_seq_lens = (
                    prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]
                )
                prefill_metadata.cudnn_workspace = self.cudnn_workspace

        decode_metadata = None
        if num_decodes > 0:
            decode_metadata = self._build_decode(
                block_table_tensor=block_table_tensor[:num_decodes, ...],
                seq_lens_cpu=seq_lens_cpu[:num_decodes],
                seq_lens_device=dcp_local_seq_lens[:num_decodes]
                if self.dcp_world_size > 1 and dcp_local_seq_lens is not None
                else seq_lens[:num_decodes],
                query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1],
                query_start_loc_device=query_start_loc[: num_decodes + 1],
                num_decode_tokens=num_decode_tokens,
                dcp_tot_seq_lens_device=seq_lens[:num_decodes]
                if self.dcp_world_size > 1
                else None,
            )

        attn_metadata = self.metadata_cls(
            num_reqs=common_attn_metadata.num_reqs,
            max_query_len=common_attn_metadata.max_query_len,
            max_seq_len=max_seq_len,
            num_actual_tokens=num_tokens,
            query_start_loc=query_start_loc,
            slot_mapping=slot_mapping,
            head_dim=self.model_config.get_head_size(),
            # MLACommonMetadata Chunk prefill specific
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            prefill=prefill_metadata,
            decode=decode_metadata,
        )

        if self._use_fi_prefill and num_prefills > 0:
            assert isinstance(attn_metadata.prefill, FlashInferPrefillMetadata)
            self._build_fi_prefill_wrappers(attn_metadata.prefill)

        return attn_metadata


def reorg_kvcache(
    allgatered_kv_c_normed: torch.Tensor,
    allgatered_k_pe: torch.Tensor,
    cp_chunk_seq_lens_lst: list[int],
    origin_context_lens: list[int],
    cp_world_size: int,
    sum_seq_len: int,
    max_seq_len: int,
    chunk_size: int,
    chunk_idx: int,
    toks: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Reorganize the kvcache after the CP local gather into the TP layout
    expected by the attention kernel.

    Args:
        cp_chunk_seq_lens_lst: chunk context lengths under CP.
        origin_context_lens: original full context lengths under CP.
        cp_world_size: CP size.
        sum_seq_len: the sum of cp_chunk_seq_lens_lst.
        max_seq_len: the max value of cp_chunk_seq_lens_lst.
        chunk_size: equal to max_context_chunk from
            chunked_context_metadata building.
        chunk_idx: chunk index of chunked_prefill.
        toks: the number of tokens for the local gather cache.
    """
    kv_c_segments = []
    k_pe_segments = []
    src_token_idx = 0
    max_seq_len_check = 0
    for cp_chunk_seq_len, origin_context_len in zip(
        cp_chunk_seq_lens_lst, origin_context_lens
    ):
        chunk_context_len = chunk_size
        if cp_chunk_seq_len != 0:
            chunk_context_len = min(
                chunk_context_len, origin_context_len - chunk_size * chunk_idx
            )
        cp_target_rank = (chunk_context_len - 1) % cp_world_size
        cur_seq_len = 0
        for rank in range(cp_world_size):
            if rank > cp_target_rank and cp_chunk_seq_len:
                real_cp_chunk_seq_len = cp_chunk_seq_len - 1
            else:
                real_cp_chunk_seq_len = cp_chunk_seq_len
            if real_cp_chunk_seq_len:
                kv_c_segment = allgatered_kv_c_normed[
                    rank * toks + src_token_idx : rank * toks
                    + src_token_idx
                    + real_cp_chunk_seq_len
                ]
                k_pe_segment = allgatered_k_pe[
                    rank * toks + src_token_idx : rank * toks
                    + src_token_idx
                    + real_cp_chunk_seq_len
                ]
                kv_c_segments.append(kv_c_segment)
                k_pe_segments.append(k_pe_segment)
                cur_seq_len += real_cp_chunk_seq_len
        max_seq_len_check = max(max_seq_len_check, cur_seq_len)
        src_token_idx += cp_chunk_seq_len
    reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)
    reorganized_k_pe = torch.cat(k_pe_segments, dim=0)
    assert reorganized_kv_c_normed.shape[0] == sum_seq_len
    assert reorganized_k_pe.shape[0] == sum_seq_len
    assert max_seq_len_check == max_seq_len
    return reorganized_kv_c_normed, reorganized_k_pe


# TODO(Lucas): rename MLACommonBaseImpl -> MLACommonImpl,
# and MLACommonImpl -> MLACommonDenseImpl or something like that
class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
    """
    NOTE: Please read the comment at the top of the file before trying to
    understand this class
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        q_lora_rank: int | None,
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        kv_b_proj: ColumnParallelLinear,
        indexer=None,
        q_pad_num_heads: int | None = None,
    ) -> None:
        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError("KV sharing is not supported for MLA")

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_head_dim
        self.v_head_dim = v_head_dim
        self.kv_b_proj = kv_b_proj
        self.indexer = indexer
        self.q_pad_num_heads = q_pad_num_heads

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        def get_layer_weight(layer):
            WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
            for attr in WEIGHT_NAMES:
                if hasattr(layer, attr):
                    return getattr(layer, attr)
            raise AttributeError(
                f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}."
            )

        def get_and_maybe_dequant_weights(layer: LinearBase):
            if not isinstance(layer.quant_method, UnquantizedLinearMethod):
                # NOTE: This should only be used offline, since it's O(N^3)
                eye = torch.eye(
                    layer.input_size_per_partition,
                    dtype=act_dtype,
                    device=get_layer_weight(layer).device,
                )
                dequant_weights = layer.quant_method.apply(layer, eye, bias=None)
                del eye
                # standardize to (output, input)
                return dequant_weights.T
            return layer.weight

        # We currently do not have the quantized BMMs needed for `W_UV` and
        # `W_UK_T`, so we store fp16/bf16 copies and perform the BMMs in
        # 16-bit; the extra memory overhead of this is fairly low.
        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
        assert kv_b_proj_weight.shape == (
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
        ), (
            f"{kv_b_proj_weight.shape=}, "
            f"{self.kv_lora_rank=}, "
            f"{self.num_heads=}, "
            f"{self.qk_nope_head_dim=}, "
            f"{self.v_head_dim=}"
        )
        kv_b_proj_weight = kv_b_proj_weight.view(
            self.kv_lora_rank,
            self.num_heads,
            self.qk_nope_head_dim + self.v_head_dim,
        )

        W_UK, W_UV = kv_b_proj_weight.split(
            [self.qk_nope_head_dim, self.v_head_dim], dim=-1
        )

        if is_rocm_aiter_fp8bmm_enabled():
            W_K = W_UK.transpose(0, 1)  # (N, L, P), e.g. 16 x 512 x 128
            W_V = W_UV.permute(1, 2, 0)  # (N, V, L), e.g. 16 x 128 x 512
            self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
                W_K, dtype=current_platform.fp8_dtype()
            )
            self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
                W_V, dtype=current_platform.fp8_dtype()
            )

            # The kernel operates on non-padded inputs, so we pre-compile the
            # Triton kernel to avoid runtime compilation for unseen batch sizes.
            # Pre-compile for batch sizes 1 to 1024 to cover most use cases.
            # On DS-R1, this step adds roughly 50s to the model loading time.
            max_batch_size = 1024  # [ToDo] Find the optimal upper limit
            pre_compilation_list = list(range(1, max_batch_size + 1))
            if is_global_first_rank():
                pre_compilation_list = tqdm(
                    pre_compilation_list,
                    desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
                    total=max_batch_size,
                )

            for m in pre_compilation_list:
                x = torch.empty(
                    (self.W_K.shape[0], m, self.W_K.shape[2]),
                    dtype=torch.bfloat16,
                    device=self.W_K.device,
                )
                aiter_triton_fp8_bmm(
                    x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
                )

                x = torch.empty(
                    (self.W_V.shape[0], m, self.W_V.shape[2]),
                    dtype=torch.bfloat16,
                    device=self.W_V.device,
                )
                aiter_triton_fp8_bmm(
                    x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
                )
        else:
            # Convert from (L, N, V) to (N, L, V)
            self.W_UV = W_UV.transpose(0, 1)
            # Convert from (L, N, P) to (N, P, L)
            self.W_UK_T = W_UK.permute(1, 2, 0)

    def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
        # Convert from (B, N, L) to (N, B, L)
        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
        if is_rocm_aiter_fp8bmm_enabled():
            # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
            x = aiter_triton_fp8_bmm(
                x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
            )
            # Convert from (B, N, V) to (B, N * V)
            x = x.reshape(-1, self.num_heads * self.v_head_dim)
            # Copy result
            out.copy_(x)
        else:
            # Convert from (B, N * V) to (N, B, V)
            out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)

            # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
            torch.bmm(x, self.W_UV, out=out)  # Reuse "out" to make it "hot"

            # Convert from (N, B, V) to (B, N * V)
            out_new = out.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)

            # Adjust output buffer shape back to the original (B, N * V)
            N, B, V = out.shape
            out.resize_((B, N * V))
            out.copy_(out_new)  # Copy result
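

# Illustrative sketch (hypothetical helpers, not used anywhere in vLLM): why
# the decode path can multiply the query by W_UK_T in forward and the latent
# attention output by W_UV in _v_up_proj, instead of up-projecting every cached
# latent. For one head and one query, with W_UK laid out as (L, P) and W_UV as
# (L, V) (the per-head slices split out in process_weights_after_loading),
# q_nope . (kv_c @ W_UK) == (W_UK @ q_nope) . kv_c, and mixing latents before
# applying W_UV equals mixing up-projected values. Plain lists stand in for
# tensors, and the attention weights are taken as given.
def _example_matvec(mat: list[list[float]], vec: list[float]) -> list[float]:
    # (rows, cols) @ (cols,) -> (rows,)
    return [sum(m * v for m, v in zip(row, vec)) for row in mat]


def _example_transpose(mat: list[list[float]]) -> list[list[float]]:
    return [list(col) for col in zip(*mat)]


def _example_mla_absorption(q_nope, kv_cs, attn_weights, W_UK, W_UV):
    # Returns ((scores, out) computed naively, (scores, out) computed absorbed).
    W_UK_T, W_UV_T = _example_transpose(W_UK), _example_transpose(W_UV)
    # Naive path: up-project every cached latent into k_nope (P,) and v (V,).
    k_nopes = [_example_matvec(W_UK_T, kv_c) for kv_c in kv_cs]
    vs = [_example_matvec(W_UV_T, kv_c) for kv_c in kv_cs]
    naive_scores = [sum(a * b for a, b in zip(q_nope, k)) for k in k_nopes]
    naive_out = [
        sum(w * v[j] for w, v in zip(attn_weights, vs)) for j in range(len(vs[0]))
    ]
    # Absorbed path: project the query into latent space once (ql_nope, (L,)),
    # score it against the latents, mix the latents, then up-project once.
    ql_nope = _example_matvec(W_UK, q_nope)
    absorbed_scores = [sum(a * b for a, b in zip(ql_nope, c)) for c in kv_cs]
    latent_out = [
        sum(w * c[i] for w, c in zip(attn_weights, kv_cs))
        for i in range(len(kv_cs[0]))
    ]
    absorbed_out = _example_matvec(W_UV_T, latent_out)
    return (naive_scores, naive_out), (absorbed_scores, absorbed_out)


# For example, with W_UK = [[1, 2], [3, 4], [5, 6]], W_UV = [[1, 0], [0, 1],
# [2, -1]], kv_cs = [[1, 0, 2], [0, 1, -1]], q_nope = [1, -1] and weights
# [2, 3], both paths give scores [-3, 0] and output [4, 2].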


class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
    """
    NOTE: Please read the comment at the top of the file before trying to
    understand this class.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        if use_flashinfer_prefill():
            logger.debug_once("Using FlashInfer prefill for MLA")
            self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi
            self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi
            self._pad_v = False
        elif use_cudnn_prefill():
            logger.debug_once("Using CUDNN prefill for MLA")
            self._run_prefill_context_chunk = self._run_prefill_context_chunk_cudnn
            self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn
            self._pad_v = False
        else:  # Use FlashAttention
            logger.debug_once("Using FlashAttention prefill for MLA")
            self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa
            self._run_prefill_new_tokens = self._run_prefill_new_tokens_fa

            # Handle the differences between flash_attn_varlen_func from
            # flash_attn and the one from vllm_flash_attn. The former is used
            # on ROCm; the latter has an additional parameter to select
            # FA2 vs FA3.
            self.flash_attn_varlen_func = flash_attn_varlen_func
            self.vllm_flash_attn_version = get_flash_attn_version()
            if self.vllm_flash_attn_version is not None:
                self.flash_attn_varlen_func = functools.partial(
                    flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version
                )

            # For MLA the V head dim is smaller than the QK head dim, so we pad
            # V with zeros to match the QK head dim for attention backends that
            # do not support different head dims.
            # We don't need to pad V on a Hopper system with FA3.
            self._pad_v = self.vllm_flash_attn_version is None or not (
                self.vllm_flash_attn_version == 3
                and current_platform.get_device_capability()[0] == 9
            )

        self.dcp_world_size: int | None = None

        self.chunked_prefill_workspace_size = (
            MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
                get_current_vllm_config()
            )
        )

    def _flash_attn_varlen_diff_headdims(
        self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
    ):
        maybe_padded_v = v
        if self._pad_v:
            maybe_padded_v = torch.nn.functional.pad(
                v, [0, q.shape[-1] - v.shape[-1]], value=0
            )

        if is_vllm_fa:
            kwargs["return_softmax_lse"] = return_softmax_lse
        else:
            # ROCm leverages the upstream flash_attn, which takes a parameter
            # called "return_attn_probs" instead of return_softmax_lse
            kwargs["return_attn_probs"] = return_softmax_lse
        if vllm_is_batch_invariant():
            kwargs["num_splits"] = 1

        attn_out = self.flash_attn_varlen_func(
            q=q,
            k=k,
            v=maybe_padded_v,
            softmax_scale=softmax_scale,
            **kwargs,
        )

        # Unpack the output if there are multiple results
        lse = None
        if isinstance(attn_out, tuple):
            attn_out, lse = attn_out[0], attn_out[1]

        # Remain consistent with the old `flash_attn_varlen_func`, which
        # returns a single output tensor if `return_softmax_lse` is False.
        if return_softmax_lse:
            return attn_out, lse
        return attn_out

    def _run_prefill_new_tokens_fa(
        self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
    ):
        return self._flash_attn_varlen_diff_headdims(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=prefill.query_start_loc,
            cu_seqlens_k=prefill.query_start_loc,
            max_seqlen_q=prefill.max_query_len,
            max_seqlen_k=prefill.max_query_len,
            softmax_scale=self.scale,
            causal=True,
            return_softmax_lse=return_softmax_lse,
        )

    def _run_prefill_new_tokens_fi(
        self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
    ):
        assert isinstance(prefill, FlashInferPrefillMetadata)
        assert prefill.prefill_main is not None

        ret = prefill.prefill_main.run(
            q=q,
            k=k,
            v=v,
            return_lse=return_softmax_lse,
        )

        if isinstance(ret, tuple):
            # Convert from (q_len, num_heads) to (num_heads, q_len)
            return ret[0], ret[1].transpose(0, 1).contiguous()
        return ret

    def _run_prefill_new_tokens_cudnn(
        self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
    ):
        assert isinstance(prefill, CudnnPrefillMetadata)
        assert prefill.query_seq_lens is not None
        output, lse = cudnn_batch_prefill_with_kv_cache(
            q=q,
            k_cache=k,
            v_cache=v,
            scale=self.scale,
            workspace_buffer=prefill.cudnn_workspace,
            max_token_per_sequence=prefill.max_query_len,
            max_sequence_kv=prefill.max_query_len,
            actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1),
            actual_seq_lens_kv=prefill.query_seq_lens.view(-1, 1, 1, 1),
            causal=True,
            # return_lse=False is not supported for now
            return_lse=True,
            # Indicates whether actual_seq_lens_* are on the GPU or the CPU.
            is_cuda_graph_compatible=True,
        )
        if return_softmax_lse:
            return output, lse
        return output

    def _run_prefill_context_chunk_fa(
        self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
    ):
        assert prefill.chunked_context is not None
        return self._flash_attn_varlen_diff_headdims(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=prefill.query_start_loc,
            cu_seqlens_k=prefill.chunked_context.cu_seq_lens[chunk_idx],
            max_seqlen_q=prefill.max_query_len,
            max_seqlen_k=prefill.chunked_context.max_seq_lens[chunk_idx],
            softmax_scale=self.scale,
            causal=False,  # Context is unmasked
            return_softmax_lse=True,
        )

    def _run_prefill_context_chunk_fi(
        self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
    ):
        assert isinstance(prefill, FlashInferPrefillMetadata)

        attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
            q=q,
            k=k,
            v=v,
            return_lse=True,
        )

        # Convert from (q_len, num_heads) to (num_heads, q_len)
        return attn_out, lse.transpose(0, 1).contiguous()

    def _run_prefill_context_chunk_cudnn(
        self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
    ):
        assert isinstance(prefill, CudnnPrefillMetadata)
        assert prefill.chunked_context is not None
        assert prefill.chunked_context.seq_lens[chunk_idx] is not None
        assert prefill.query_seq_lens is not None
        return cudnn_batch_prefill_with_kv_cache(
            q=q,
            k_cache=k,
            v_cache=v,
            scale=self.scale,
            workspace_buffer=prefill.cudnn_workspace,
            max_token_per_sequence=prefill.max_query_len,
            max_sequence_kv=prefill.chunked_context.max_seq_lens[chunk_idx],
            actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1),
            actual_seq_lens_kv=prefill.chunked_context.seq_lens[chunk_idx].view(
                -1, 1, 1, 1
            ),
            causal=False,
            return_lse=True,
            # Indicates whether actual_seq_lens_* are on the GPU or the CPU.
            is_cuda_graph_compatible=True,
        )

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        def get_layer_weight(layer):
            WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
            for attr in WEIGHT_NAMES:
                if hasattr(layer, attr):
                    return getattr(layer, attr)
            raise AttributeError(
                f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}."
            )

        def get_and_maybe_dequant_weights(layer: LinearBase):
            if not isinstance(layer.quant_method, UnquantizedLinearMethod):
                # NOTE: This should only be used offline, since it's O(N^3)
                eye = torch.eye(
                    layer.input_size_per_partition,
                    dtype=act_dtype,
                    device=get_layer_weight(layer).device,
                )
                dequant_weights = layer.quant_method.apply(layer, eye, bias=None)
                del eye
                # standardize to (output, input)
                return dequant_weights.T
            return layer.weight

        # We currently do not have the quantized BMMs needed for `W_UV` and
        # `W_UK_T`, so we store fp16/bf16 copies and perform the BMMs in
        # 16-bit; the extra memory overhead of this is fairly low.
        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
        assert kv_b_proj_weight.shape == (
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
        ), (
            f"{kv_b_proj_weight.shape=}, "
            f"{self.kv_lora_rank=}, "
            f"{self.num_heads=}, "
            f"{self.qk_nope_head_dim=}, "
            f"{self.v_head_dim=}"
        )
        kv_b_proj_weight = kv_b_proj_weight.view(
            self.kv_lora_rank,
            self.num_heads,
            self.qk_nope_head_dim + self.v_head_dim,
        )

        W_UK, W_UV = kv_b_proj_weight.split(
            [self.qk_nope_head_dim, self.v_head_dim], dim=-1
        )

        if is_rocm_aiter_fp8bmm_enabled():
            W_K = W_UK.transpose(0, 1)  # (N, L, P), e.g. 16 x 512 x 128
            W_V = W_UV.permute(1, 2, 0)  # (N, V, L), e.g. 16 x 128 x 512
            self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
                W_K, dtype=current_platform.fp8_dtype()
            )
            self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
                W_V, dtype=current_platform.fp8_dtype()
            )

            # The kernel operates on non-padded inputs, so we pre-compile the
            # Triton kernel to avoid runtime compilation for unseen batch sizes.
            # Pre-compile for batch sizes 1 to 1024 to cover most use cases.
            # On DS-R1, this step adds roughly 50s to the model loading time.
            max_batch_size = 1024  # [ToDo] Find the optimal upper limit
            pre_compilation_list = list(range(1, max_batch_size + 1))
            if is_global_first_rank():
                pre_compilation_list = tqdm(
                    pre_compilation_list,
                    desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
                    total=max_batch_size,
                )

            for m in pre_compilation_list:
                x = torch.empty(
                    (self.W_K.shape[0], m, self.W_K.shape[2]),
                    dtype=torch.bfloat16,
                    device=self.W_K.device,
                )
                aiter_triton_fp8_bmm(
                    x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
                )

                x = torch.empty(
                    (self.W_V.shape[0], m, self.W_V.shape[2]),
                    dtype=torch.bfloat16,
                    device=self.W_V.device,
                )
                aiter_triton_fp8_bmm(
                    x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
                )
        else:
            # Convert from (L, N, V) to (N, L, V)
            self.W_UV = W_UV.transpose(0, 1)
            # Convert from (L, N, P) to (N, P, L)
            self.W_UK_T = W_UK.permute(1, 2, 0)

    def _compute_prefill_context(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
        k_scale: torch.Tensor,
    ):
        assert attn_metadata.prefill is not None
        prefill_metadata = attn_metadata.prefill
        assert prefill_metadata.chunked_context is not None

        output = None
        iters = len(prefill_metadata.chunked_context.seq_tot)
        workspace = prefill_metadata.chunked_context.workspace

        for i in range(iters):
            toks = prefill_metadata.chunked_context.seq_tot[i]

            ops.gather_and_maybe_dequant_cache(
                src_cache=kv_c_and_k_pe_cache,
                dst=workspace,
                block_table=prefill_metadata.block_table,
                cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
                batch_size=attn_metadata.num_prefills,
                kv_cache_dtype=self.kv_cache_dtype,
                scale=k_scale,
                seq_starts=prefill_metadata.chunked_context.starts[i],
            )

            kv_c_normed = workspace[:toks][..., : self.kv_lora_rank]
            k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1)

            kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
                -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
            )
            k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

            k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

            attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
                prefill=prefill_metadata,
                chunk_idx=i,
                q=q,
                k=k,
                v=v,
            )

            if output is None:
                output = attn_output
                output_lse = attn_softmax_lse
            else:
                output_tmp = torch.empty_like(output)
                output_lse_tmp = torch.empty_like(output_lse)
                merge_attn_states(
                    output=output_tmp,
                    output_lse=output_lse_tmp,
                    prefix_output=output,
                    prefix_lse=output_lse,
                    suffix_output=attn_output,
                    suffix_lse=attn_softmax_lse,
                )
                output = output_tmp
                output_lse = output_lse_tmp

        return output, output_lse

    def _context_parallel_compute_prefill_context(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
        k_scale: torch.Tensor,
        dcp_world_size: int,
    ):
        assert k_scale is None, "DCP does not support a scaled KV cache yet."
        assert attn_metadata.prefill is not None
        prefill_metadata = attn_metadata.prefill
        assert prefill_metadata.chunked_context is not None
        assert prefill_metadata.chunked_context.cp_chunk_seq_lens is not None
        assert prefill_metadata.chunked_context.origin_context_lens is not None
        assert prefill_metadata.chunked_context.cp_cu_seq_lens is not None
        assert prefill_metadata.chunked_context.chunk_size is not None
        assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None

        output = None
        iters = len(prefill_metadata.chunked_context.seq_tot)
        workspace = prefill_metadata.chunked_context.workspace

        for i in range(iters):
            toks = prefill_metadata.chunked_context.seq_tot[i]
            ops.cp_gather_cache(
                src_cache=kv_c_and_k_pe_cache,
                dst=workspace,
                block_table=prefill_metadata.block_table,
                cu_seq_lens=prefill_metadata.chunked_context.cp_cu_seq_lens[i],
                batch_size=attn_metadata.num_prefills,
                seq_starts=prefill_metadata.chunked_context.starts[i],
            )
            # workspace
            # |-------- N tokens --------|--------- N*dcp_size tokens ----------|
            # |<- use for local_gather ->|<--------- use for allgather -------->|
            allgather_offset = workspace.shape[0] // (dcp_world_size + 1)
            assert allgather_offset * (dcp_world_size + 1) == workspace.shape[0]
            assert toks <= allgather_offset
            local_gathered_kvcache = workspace[:toks]
            cur_allgather_workspace = workspace[
                allgather_offset : allgather_offset * (1 + dcp_world_size)
            ]
            assert toks * dcp_world_size <= cur_allgather_workspace.shape[0]
            cur_allgather_kvcache = cur_allgather_workspace[: toks * dcp_world_size]
            cur_allgather_kvcache.copy_(
                get_dcp_group().all_gather(local_gathered_kvcache, dim=0)
            )
            assert (
                cur_allgather_kvcache.shape[-1]
                == self.kv_lora_rank + self.qk_rope_head_dim
            )
            allgatered_kv_c_normed, allgatered_k_pe = cur_allgather_kvcache.unsqueeze(
                1
            ).split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

            kv_c_normed, k_pe = reorg_kvcache(
                allgatered_kv_c_normed,
                allgatered_k_pe,
                cp_chunk_seq_lens_lst=prefill_metadata.chunked_context.cp_chunk_seq_lens[
                    i
                ],
                origin_context_lens=prefill_metadata.chunked_context.origin_context_lens,
                cp_world_size=dcp_world_size,
                sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i][-1],
                max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i],
                chunk_size=prefill_metadata.chunked_context.chunk_size,
                chunk_idx=i,
                toks=toks,
            )

            kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
                -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
            )
            k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

            attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
                prefill=prefill_metadata,
                chunk_idx=i,
                q=q,
                k=k,
                v=v,
            )

            if output is None:
                output = attn_output
                output_lse = attn_softmax_lse
            else:
                output_tmp = torch.empty_like(output)
                output_lse_tmp = torch.empty_like(output_lse)
                merge_attn_states(
                    output=output_tmp,
                    output_lse=output_lse_tmp,
                    prefix_output=output,
                    prefix_lse=output_lse,
                    suffix_output=attn_output,
                    suffix_lse=attn_softmax_lse,
                )
                output = output_tmp
                output_lse = output_lse_tmp

        return output, output_lse

    def _forward_prefill(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
        k_scale: torch.Tensor,
    ) -> torch.Tensor:
        # TODO (zyongye): Prefill function here
        assert attn_metadata.prefill is not None
        assert self.dcp_world_size is not None

        has_context = attn_metadata.prefill.chunked_context is not None
        kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
            -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
        )
        k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

        output = self._run_prefill_new_tokens(
            prefill=attn_metadata.prefill,
            q=q,
            k=k,
            v=v,
            return_softmax_lse=has_context,
        )

        if has_context:
            suffix_output, suffix_lse = output
            if self.dcp_world_size > 1:
                context_output, context_lse = (
                    self._context_parallel_compute_prefill_context(
                        q,
                        kv_c_and_k_pe_cache,
                        attn_metadata,
                        k_scale=None,
                        dcp_world_size=self.dcp_world_size,
                    )
                )
            else:
                context_output, context_lse = self._compute_prefill_context(
                    q, kv_c_and_k_pe_cache, attn_metadata, k_scale
                )

            output = torch.empty_like(suffix_output)
            merge_attn_states(
                output=output,
                prefix_output=context_output,
                prefix_lse=context_lse,
                suffix_output=suffix_output,
                suffix_lse=suffix_lse,
            )

        # unpad if necessary
        if self._pad_v:
            output = output[..., : v.shape[-1]]

        return output.flatten(start_dim=-2)

    @abstractmethod
    def _forward_decode(
        self,
        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: M,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        raise NotImplementedError

    def forward(
        self,
        layer: AttentionLayer,
        q: torch.Tensor,
        k_c_normed: torch.Tensor,  # key in unified attn
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: M,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported for MLACommonImpl"
            )

        if attn_metadata is None:
            # During the profile run, simulate the worst-case output size of
            # `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`,
            # since it can be large.
            _ = torch.empty(
                (
                    self.chunked_prefill_workspace_size,
                    self.num_heads,
                    self.qk_nope_head_dim + self.v_head_dim,
                ),
                device=k_c_normed.device,
                dtype=k_c_normed.dtype,
            )

            # The zero fill is required when used with DP + EP
            # to ensure all ranks within a DP group compute the
            # same expert outputs.
            return output.fill_(0)

        if self.dcp_world_size is None:
            self.dcp_world_size = get_dcp_group().world_size

        fp8_attention = self.kv_cache_dtype.startswith("fp8")

        num_actual_toks = attn_metadata.num_actual_tokens

        # Inputs and outputs may be padded for CUDA graphs
        output_padded = output
        output = output[:num_actual_toks, ...]
        q = q[:num_actual_toks, ...]
        k_c_normed = k_c_normed[:num_actual_toks, ...]
        k_pe = k_pe[:num_actual_toks, ...]

        assert (
            attn_metadata.num_decodes is not None
            and attn_metadata.num_prefills is not None
            and attn_metadata.num_decode_tokens is not None
        )

        has_decode = attn_metadata.num_decodes > 0
        has_prefill = attn_metadata.num_prefills > 0
        num_decode_tokens = attn_metadata.num_decode_tokens

        decode_q = q[:num_decode_tokens]

        prefill_q = q[num_decode_tokens:]
        prefill_k_pe = k_pe[num_decode_tokens:]
        prefill_k_c_normed = k_c_normed[num_decode_tokens:]

        # write the latent and rope to kv cache
        if kv_cache.numel() > 0:
            ops.concat_and_cache_mla(
                k_c_normed,
                k_pe.squeeze(1),
                kv_cache,
                attn_metadata.slot_mapping.flatten(),
                kv_cache_dtype=self.kv_cache_dtype,
                scale=layer._k_scale,
            )

        if fp8_attention:
            kv_cache = kv_cache.view(current_platform.fp8_dtype())

        if has_prefill:
            output[num_decode_tokens:] = self._forward_prefill(
                prefill_q,
                prefill_k_c_normed,
                prefill_k_pe,
                kv_cache,
                attn_metadata,
                layer._k_scale,
            )

        if has_decode:
            assert attn_metadata.decode is not None

            decode_q_nope, decode_q_pe = decode_q.split(
                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
            )

            # Convert from (B, N, P) to (N, B, P)
            decode_q_nope = decode_q_nope.transpose(0, 1)

            # Pad the number of heads if necessary (for the underlying kernel)
            if self.q_pad_num_heads is not None:
                B, N, L = decode_q_pe.shape
                decode_pe_padded = decode_q_pe.new_empty((B, self.q_pad_num_heads, L))
                decode_pe_padded.resize_((B, N, L))
                decode_pe_padded.copy_(decode_q_pe)
                decode_q_pe = decode_pe_padded

            if is_rocm_aiter_fp8bmm_enabled():
                # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
                decode_ql_nope = aiter_triton_fp8_bmm(
                    decode_q_nope,
                    self.W_K,
                    self.W_K_scale,
                    group_size=128,
                    transpose_bm=True,
                )
            else:
                # Pad the number of heads if necessary (for the underlying kernel)
                N, B, P = decode_q_nope.shape
                _, _, L = self.W_UK_T.shape

                if self.q_pad_num_heads is not None:
                    decode_ql_nope = decode_q_nope.new_empty(
                        (self.q_pad_num_heads, B, L)
                    )
                    decode_ql_nope.resize_((N, B, L))
                else:
                    decode_ql_nope = decode_q_nope.new_empty((N, B, L))

                # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
                torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope)

                # Convert from (N, B, L) to (B, N, L)
                decode_ql_nope = decode_ql_nope.transpose(0, 1)

            if fp8_attention:
                ql_nope_shape = decode_ql_nope.shape
                decode_ql_nope, _ = ops.scaled_fp8_quant(
                    decode_ql_nope.reshape(
                        [ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2]]
                    ),
                    layer._q_scale,
                )
                decode_ql_nope = decode_ql_nope.reshape(ql_nope_shape)
                q_pe_shape = decode_q_pe.shape
                decode_q_pe, _ = ops.scaled_fp8_quant(
                    decode_q_pe.reshape([q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]),
                    layer._q_scale,
                )
                decode_q_pe = decode_q_pe.reshape(q_pe_shape)

            decode_q = (decode_ql_nope, decode_q_pe)
            if self.dcp_world_size > 1:
                assert not fp8_attention, "DCP does not support an fp8 KV cache yet."
                # Concatenate decode_ql_nope and decode_q_pe -> (B, N, L + R)
                decode_q = torch.cat(decode_q, dim=-1)
                # All-gather decode_q along the head dim.
                decode_q = get_dcp_group().all_gather(decode_q, dim=1)

            # call decode attn
            attn_out, lse = self._forward_decode(
                decode_q, kv_cache, attn_metadata, layer
            )

            # Correct the DCP attn_out using the LSE.
            if self.dcp_world_size > 1:
                attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())

            # v_up projection
            self._v_up_proj(attn_out, out=output[:num_decode_tokens])
        return output_padded
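

# Illustrative sketch (hypothetical helpers, not used anywhere in vLLM): the
# log-sum-exp merge behind the merge_attn_states calls in
# _compute_prefill_context and _forward_prefill. We assume merge_attn_states
# applies this standard rescaling; the version below handles one query, one
# head and scalar values, with each chunk's output normalized by its own softmax.
def _example_partial_attention(
    scores: list[float], values: list[float]
) -> tuple[float, float]:
    import math

    # Softmax-weighted average over one chunk of keys, plus its log-sum-exp.
    m = max(scores)
    lse = m + math.log(sum(math.exp(s - m) for s in scores))
    out = sum(math.exp(s - lse) * v for s, v in zip(scores, values))
    return out, lse


def _example_merge_attn_states(
    prefix_out: float, prefix_lse: float, suffix_out: float, suffix_lse: float
) -> tuple[float, float]:
    import math

    # Combined normalizer of both chunks, computed stably via the max.
    m = max(prefix_lse, suffix_lse)
    lse = m + math.log(math.exp(prefix_lse - m) + math.exp(suffix_lse - m))
    # Re-weight each chunk's output by its share of the total softmax mass.
    out = (
        math.exp(prefix_lse - lse) * prefix_out
        + math.exp(suffix_lse - lse) * suffix_out
    )
    return out, lse


# Merging _example_partial_attention over scores[:k] and scores[k:] matches a
# single pass over all scores, up to floating-point rounding.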