[Perf][Deepseek] Optimize the gather_and_maybe_dequant_cache kernel's perf for extremely long sequences (#28029)

Signed-off-by: ganyi <ygan@amd.com>
This commit is contained in:
Pleaplusone
2025-11-25 10:05:46 +08:00
committed by GitHub
parent 6f1355a1b7
commit 77e10c9cab
6 changed files with 131 additions and 105 deletions


@@ -41,11 +41,12 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const double scale, const std::string& kv_cache_dtype);
void gather_and_maybe_dequant_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& cu_seq_lens, // [BATCH+1]
int64_t batch_size, const std::string& kv_cache_dtype,
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& cu_seq_lens, // [BATCH+1]
torch::Tensor const& token_to_seq, // [MAX_TOKEN_ACROSS_CHUNKS]
int64_t num_tokens, const std::string& kv_cache_dtype,
torch::Tensor const& scale,
std::optional<torch::Tensor> seq_starts = std::nullopt);
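
For reference, the new token_to_seq argument is a per-token reverse mapping from each destination row to its batch index, and num_tokens replaces batch_size as the launch extent. A minimal Python sketch of how such a tensor can be derived from cu_seq_lens (the local names batch_ids and seq_lens are illustrative, not from this diff):

import torch

# cu_seq_lens holds cumulative sequence lengths, e.g. three sequences of 3, 1 and 2 tokens.
cu_seq_lens = torch.tensor([0, 3, 4, 6], dtype=torch.int32)
seq_lens = (cu_seq_lens[1:] - cu_seq_lens[:-1]).long()    # [3, 1, 2]
batch_ids = torch.arange(seq_lens.numel(), dtype=torch.int32)
# Repeat each batch id once per token it owns -> [0, 0, 0, 1, 2, 2]
token_to_seq = torch.repeat_interleave(batch_ids, seq_lens)
num_tokens = int(cu_seq_lens[-1])                          # 6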


@@ -905,91 +905,79 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
namespace vllm {
// grid is launched with one thread block per output token (gridDim.x == num_tokens)
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt,
int ENTRY_SIZE, int CTA_SIZE>
__global__ void gather_and_maybe_dequant_cache(
const cache_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
// ENTRIES...]
scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...]
const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
const int32_t* __restrict__ cu_seq_lens, // [BATCH+1]
const int32_t block_size, const int32_t entry_size,
const cache_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
// ENTRIES...]
scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...]
const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
const int32_t* __restrict__ cu_seq_lens, // [BATCH+1]
const int32_t* __restrict__ token_to_seq, // [MAX_TOKEN_ACROSS_CHUNKS]
const int32_t num_tokens, const int32_t block_size,
const int64_t block_table_stride, const int64_t cache_block_stride,
const int64_t cache_entry_stride, const int64_t dst_entry_stride,
const float* __restrict__ scale,
const int32_t* __restrict__ seq_starts) { // Optional: starting offsets per
// batch
constexpr int vec_size = sizeof(float4) / sizeof(scalar_t);
using ltype = vllm::vec_n_t<cache_t, vec_size>;
using stype = vllm::vec_n_t<scalar_t, vec_size>;
// We keep this assert for code readability; it is optimized out in
// release builds.
assert(CTA_SIZE == blockDim.x);
const int64_t bid = blockIdx.x; // Batch ID
const int32_t num_splits = gridDim.y;
const int32_t split = blockIdx.y;
const int32_t seq_start = cu_seq_lens[bid];
const int32_t seq_end = cu_seq_lens[bid + 1];
const int32_t seq_len = seq_end - seq_start;
const int32_t tot_blocks = cuda_utils::ceil_div(seq_len, block_size);
const int32_t split_blocks = cuda_utils::ceil_div(tot_blocks, num_splits);
#pragma unroll
for (int token_id = blockIdx.x; token_id < num_tokens;
token_id += gridDim.x) {
int64_t batch_id = token_to_seq[token_id];
int64_t batch_start = cu_seq_lens[batch_id];
int64_t batch_end = cu_seq_lens[batch_id + 1];
int32_t batch_offset = token_id - batch_start;
const int32_t split_start = split * split_blocks;
const int32_t split_end = min((split + 1) * split_blocks, tot_blocks);
if (token_id >= batch_end) return;
int32_t offset = 0;
if (seq_starts != nullptr) {
offset = seq_starts[batch_id];
}
batch_offset += offset;
int32_t block_table_id = batch_offset / block_size;
int32_t slot_id = batch_offset % block_size;
int32_t block_table_offset = batch_id * block_table_stride + block_table_id;
int32_t block_id = block_table[block_table_offset];
int64_t cache_offset =
block_id * cache_block_stride + slot_id * cache_entry_stride;
constexpr int32_t vec_iter_cnt = ENTRY_SIZE / vec_size;
scalar_t* dst_ = dst + token_id * dst_entry_stride;
cache_t* src_ = const_cast<cache_t*>(src_cache) + cache_offset;
const bool is_active_split = (split_start < tot_blocks);
const bool is_last_split = (split_end == tot_blocks);
if (!is_active_split) return;
int32_t full_blocks_end = split_end;
int32_t partial_block_size = 0;
// Adjust the pointer for the block_table for this batch.
// If seq_starts is provided, compute an offset based on (seq_starts[bid] /
// page_size)
const int32_t batch_offset = bid * block_table_stride;
int32_t offset = 0;
if (seq_starts != nullptr) {
offset = seq_starts[bid] / block_size;
}
const int32_t* batch_block_table = block_table + batch_offset + offset;
// Adjust dst pointer based on the cumulative sequence lengths.
dst += seq_start * dst_entry_stride;
if (is_last_split) {
partial_block_size = seq_len % block_size;
if (partial_block_size) full_blocks_end -= 1;
}
auto copy_entry = [&](const cache_t* __restrict__ _src,
scalar_t* __restrict__ _dst) {
for (int i = threadIdx.x; i < entry_size; i += blockDim.x) {
#pragma unroll
for (int idx = threadIdx.x; idx < vec_iter_cnt; idx += CTA_SIZE) {
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
_dst[i] = static_cast<scalar_t>(_src[i]);
reinterpret_cast<stype*>(dst_)[idx] =
static_cast<stype>(reinterpret_cast<ltype*>(src_)[idx]);
} else {
_dst[i] =
fp8::scaled_convert<scalar_t, cache_t, kv_dt>(_src[i], *scale);
ltype loaded_val = reinterpret_cast<ltype*>(src_)[idx];
stype store_val;
#pragma unroll
for (int j = 0; j < vec_size; ++j) {
store_val.val[j] = fp8::scaled_convert<scalar_t, cache_t, kv_dt>(
loaded_val.val[j], *scale);
}
reinterpret_cast<stype*>(dst_)[idx] = store_val;
}
}
};
const auto loop_end =
std::min((int64_t)full_blocks_end, block_table_stride - offset);
for (int pid = split_start; pid < loop_end; ++pid) {
auto block_id = batch_block_table[pid];
auto block_start_ptr = src_cache + block_id * cache_block_stride;
auto block_dst_ptr = dst + pid * block_size * dst_entry_stride;
for (int eid = 0; eid < block_size; ++eid) {
copy_entry(block_start_ptr + eid * cache_entry_stride,
block_dst_ptr + eid * dst_entry_stride);
}
}
if (partial_block_size) {
if (offset + full_blocks_end < block_table_stride) {
auto block_id = batch_block_table[full_blocks_end];
auto block_start_ptr = src_cache + block_id * cache_block_stride;
auto block_dst_ptr =
dst + full_blocks_end * block_size * dst_entry_stride;
for (int eid = 0; eid < partial_block_size; ++eid) {
copy_entry(block_start_ptr + eid * cache_entry_stride,
block_dst_ptr + eid * dst_entry_stride);
// process tail
constexpr int32_t tail_cnt = ENTRY_SIZE % vec_size;
dst_ = dst_ + ENTRY_SIZE - tail_cnt;
src_ = src_ + ENTRY_SIZE - tail_cnt;
#pragma unroll
for (int idx = threadIdx.x; idx < tail_cnt; idx += CTA_SIZE) {
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
dst_[idx] = static_cast<scalar_t>(src_[idx]);
} else {
dst_[idx] =
fp8::scaled_convert<scalar_t, cache_t, kv_dt>(src_[idx], *scale);
}
}
}
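
As a readability aid (not part of the diff), here is a pure-Python sketch of the per-token gather the rewritten kernel performs, mirroring its indexing but ignoring the fp8 dequantization path and the float4 vectorization; gather_reference is a hypothetical helper name:

import torch

def gather_reference(src_cache, dst, block_table, cu_seq_lens,
                     token_to_seq, num_tokens, block_size, seq_starts=None):
    # src_cache: [NUM_BLOCKS, BLOCK_SIZE, ENTRY_SIZE], dst: [TOT_TOKENS, ENTRY_SIZE]
    for token_id in range(num_tokens):
        batch_id = int(token_to_seq[token_id])
        batch_start = int(cu_seq_lens[batch_id])
        batch_end = int(cu_seq_lens[batch_id + 1])
        if token_id >= batch_end:           # guard against padded token_to_seq entries
            continue
        batch_offset = token_id - batch_start
        if seq_starts is not None:          # optional per-batch starting token offset
            batch_offset += int(seq_starts[batch_id])
        block_table_id, slot_id = divmod(batch_offset, block_size)
        block_id = int(block_table[batch_id, block_table_id])
        dst[token_id] = src_cache[block_id, slot_id]
    return dst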
@@ -1001,34 +989,38 @@ __global__ void gather_and_maybe_dequant_cache(
// SCALAR_T is the data type of the destination tensor.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE) \
vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<CACHE_T*>(src_cache.data_ptr()), \
reinterpret_cast<SCALAR_T*>(dst.data_ptr()), \
block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
block_size, entry_size, block_table_stride, cache_block_stride, \
cache_entry_stride, dst_entry_stride, \
reinterpret_cast<const float*>(scale.data_ptr()), seq_starts_ptr);
#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE) \
vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE, 576, \
thread_block_size> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<CACHE_T*>(src_cache.data_ptr()), \
reinterpret_cast<SCALAR_T*>(dst.data_ptr()), \
block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
token_to_seq.data_ptr<int32_t>(), num_tokens, block_size, \
block_table_stride, cache_block_stride, cache_entry_stride, \
dst_entry_stride, reinterpret_cast<const float*>(scale.data_ptr()), \
seq_starts_ptr);
// Gather sequences from the cache into the destination tensor.
// - cu_seq_lens contains the cumulative sequence lengths for each batch
// - block_table contains the cache block indices for each sequence
// - token_to_seq contains the reverse mapping from each token_id to its batch_id
// - Optionally, seq_starts (if provided) offsets the starting block index by
// (seq_starts[bid] / page_size)
void gather_and_maybe_dequant_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& cu_seq_lens, // [BATCH+1]
int64_t batch_size, const std::string& kv_cache_dtype,
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& cu_seq_lens, // [BATCH+1]
torch::Tensor const& token_to_seq, // [MAX_TOKEN_ACROSS_CHUNKS]
int64_t num_tokens, const std::string& kv_cache_dtype,
torch::Tensor const& scale,
std::optional<torch::Tensor> seq_starts = std::nullopt) {
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int32_t block_size = src_cache.size(1);
int32_t entry_size = src_cache.flatten(2, -1).size(2);
int32_t head_dim = dst.size(-1);
TORCH_CHECK(block_table.dtype() == torch::kInt32,
"block_table must be int32");
@@ -1038,6 +1030,9 @@ void gather_and_maybe_dequant_cache(
TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32,
"seq_starts must be int32");
}
TORCH_CHECK(head_dim == 576,
"gather_and_maybe_dequant_cache only supports head_dim == 576 "
"for better performance");
TORCH_CHECK(src_cache.device() == dst.device(),
"src_cache and dst must be on the same device");
@@ -1055,10 +1050,9 @@ void gather_and_maybe_dequant_cache(
int64_t cache_entry_stride = src_cache.stride(1);
int64_t dst_entry_stride = dst.stride(0);
// Decide on the number of splits based on the batch size.
int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
dim3 grid(batch_size, num_splits);
dim3 block(1024);
constexpr int32_t thread_block_size = 64;
dim3 grid(num_tokens);
dim3 block(thread_block_size);
const int32_t* seq_starts_ptr =
seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;
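
For intuition on the new launch shape and the head_dim == 576 specialization (a back-of-the-envelope check, not code from this diff): the grid is now one 64-thread CTA per output token rather than (batch, num_splits) blocks of 1024 threads, and each 576-element entry is copied with float4-sized vectors.

# Vectorization arithmetic for a 576-wide MLA entry, assuming a bf16 destination.
entry_size, cta_size = 576, 64
vec_size = 16 // 2                        # sizeof(float4) / sizeof(bf16) = 8
vec_iters, tail = divmod(entry_size, vec_size)
assert (vec_iters, tail) == (72, 0)       # 72 vectorized stores per token, no scalar tail
per_thread = -(-vec_iters // cta_size)    # at most 2 vectorized iterations per thread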


@@ -695,7 +695,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
cache_ops.def(
"gather_and_maybe_dequant_cache(Tensor src_cache, Tensor! dst, "
" Tensor block_table, Tensor cu_seq_lens, "
" int batch_size, "
" Tensor token_to_seq, "
" int num_tokens, "
" str kv_cache_dtype, "
" Tensor scale, Tensor? seq_starts) -> ()");
cache_ops.impl("gather_and_maybe_dequant_cache", torch::kCUDA,
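
With this schema, the op can also be invoked directly through the dispatcher; an illustrative call, assuming vLLM's usual _C_cache_ops extension namespace:

torch.ops._C_cache_ops.gather_and_maybe_dequant_cache(
    src_cache, dst, block_table, cu_seq_lens,
    token_to_seq, num_tokens, kv_cache_dtype, scale, seq_starts)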


@@ -921,12 +921,16 @@ def test_gather_and_maybe_dequant_cache_mla(
)
_fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
seq_len_tensor = torch.randint(0, max_seq_len + 1, (batch_size,), device=device)
seq_len_tensor = torch.randint(
max_seq_len, max_seq_len + 1, (batch_size,), device=device
)
total_tokens = seq_len_tensor.sum()
cu_seq_lens = torch.empty((batch_size + 1), dtype=torch.int32, device=device)
cu_seq_lens[0] = 0
cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)
token_to_seq = torch.arange(0, batch_size, dtype=torch.int32, device=device)
token_to_seq = torch.repeat_interleave(token_to_seq, seq_len_tensor)
print("seq_len_tensor", seq_len_tensor)
tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
@@ -977,7 +981,8 @@ def test_gather_and_maybe_dequant_cache_mla(
dst,
block_table,
cu_seq_lens,
batch_size,
token_to_seq,
total_tokens,
kv_cache_dtype,
scale,
None,
@@ -990,7 +995,8 @@ def test_gather_and_maybe_dequant_cache_mla(
dst,
block_table,
cu_seq_lens,
batch_size,
token_to_seq,
total_tokens,
kv_cache_dtype,
scale,
None,
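
One way to sanity-check the new test inputs against each other (an illustrative helper, not part of the test; check_token_to_seq is a hypothetical name and assumes the unpadded layout used here):

import torch

def check_token_to_seq(token_to_seq: torch.Tensor, cu_seq_lens: torch.Tensor) -> None:
    # Every destination row must fall inside the span of the sequence it maps back to,
    # and the number of rows must equal the total token count.
    rows = torch.arange(token_to_seq.numel(), device=token_to_seq.device)
    starts = cu_seq_lens[token_to_seq.long()]
    ends = cu_seq_lens[token_to_seq.long() + 1]
    assert token_to_seq.numel() == int(cu_seq_lens[-1])
    assert torch.all((starts <= rows) & (rows < ends))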


@@ -2201,7 +2201,8 @@ def gather_and_maybe_dequant_cache(
dst: torch.Tensor,
block_table: torch.Tensor,
cu_seq_lens: torch.Tensor,
batch_size: int,
token_to_seq: torch.Tensor,
num_tokens: int,
kv_cache_dtype: str,
scale: torch.Tensor,
seq_starts: torch.Tensor | None = None,
@@ -2211,7 +2212,8 @@ def gather_and_maybe_dequant_cache(
dst,
block_table,
cu_seq_lens,
batch_size,
token_to_seq,
num_tokens,
kv_cache_dtype,
scale,
seq_starts,


@@ -340,6 +340,8 @@ class MLACommonPrefillMetadata:
max_seq_lens: list[int]
seq_lens: torch.Tensor
workspace: torch.Tensor
token_to_seq: torch.Tensor
chunk_total_token: list[int]
# for mla DCP
padded_local_chunk_seq_lens: list[list[int]] | None = None
@@ -839,6 +841,19 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
torch.cumsum(
chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32
)
chunk_total_token = cu_seq_lens_cpu[:, -1]
max_token_num_over_chunk = chunk_total_token.max().item()
token_to_seq_tensor_cpu = torch.zeros(
[num_chunks, max_token_num_over_chunk], dtype=torch.int32
)
range_idx = torch.arange(num_prefills, dtype=torch.int32)
for i in range(num_chunks):
chunk_token_to_seq_tensor = torch.repeat_interleave(
range_idx, chunk_seq_lens[i]
)
chunk_len = chunk_token_to_seq_tensor.shape[0]
token_to_seq_tensor_cpu[i, :chunk_len] = chunk_token_to_seq_tensor
if self.dcp_world_size > 1:
local_context_lens_allranks = get_dcp_local_seq_lens(
@@ -906,6 +921,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
seq_lens=chunk_seq_lens,
token_to_seq=token_to_seq_tensor_cpu.to(
device, non_blocking=True
),
chunk_total_token=chunk_total_token.tolist(),
workspace=self.chunked_prefill_workspace,
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
local_context_lens_allranks=local_context_lens_allranks.tolist(),
@@ -922,6 +941,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
seq_lens=chunk_seq_lens,
token_to_seq=token_to_seq_tensor_cpu.to(
device, non_blocking=True
),
chunk_total_token=chunk_total_token.tolist(),
workspace=self.chunked_prefill_workspace,
)
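
To make the padded per-chunk layout concrete, a toy example that mirrors the builder's loop (the chunk_seq_lens values below are assumed, not taken from the diff):

import torch

# Two chunks over three prefill requests with these per-chunk context lengths.
chunk_seq_lens = torch.tensor([[3, 2, 4], [1, 0, 2]], dtype=torch.int32)
num_chunks, num_prefills = chunk_seq_lens.shape
chunk_total_token = chunk_seq_lens.sum(dim=1)                # tensor([9, 3])
max_token_num_over_chunk = int(chunk_total_token.max())      # 9
token_to_seq = torch.zeros((num_chunks, max_token_num_over_chunk), dtype=torch.int32)
range_idx = torch.arange(num_prefills, dtype=torch.int32)
for i in range(num_chunks):
    row = torch.repeat_interleave(range_idx, chunk_seq_lens[i])
    token_to_seq[i, : row.shape[0]] = row
# token_to_seq[0] -> [0, 0, 0, 1, 1, 2, 2, 2, 2]
# token_to_seq[1] -> [0, 2, 2, 0, 0, 0, 0, 0, 0]   (zero padding past chunk_total_token[1];
# only chunk_total_token[i] rows are ever gathered for chunk i)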
@@ -1638,16 +1661,15 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
output = None
iters = len(prefill_metadata.chunked_context.seq_tot)
workspace = prefill_metadata.chunked_context.workspace
for i in range(iters):
toks = prefill_metadata.chunked_context.seq_tot[i]
ops.gather_and_maybe_dequant_cache(
src_cache=kv_c_and_k_pe_cache,
dst=workspace,
block_table=prefill_metadata.block_table,
cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
batch_size=attn_metadata.num_prefills,
token_to_seq=prefill_metadata.chunked_context.token_to_seq[i],
num_tokens=prefill_metadata.chunked_context.chunk_total_token[i],
kv_cache_dtype=self.kv_cache_dtype,
scale=k_scale,
seq_starts=prefill_metadata.chunked_context.starts[i],