Mirror of https://github.com/vllm-project/vllm.git (synced 2025-12-06 06:53:12 +08:00)
[Perf][Deepseek] optimize gather_and_maybe_dequant_cache kernel's perf for extremely long sequence (#28029)
Signed-off-by: ganyi <ygan@amd.com>
csrc/cache.h
@@ -41,11 +41,12 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
                  const double scale, const std::string& kv_cache_dtype);
 
 void gather_and_maybe_dequant_cache(
-    torch::Tensor const& src_cache,   // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
-    torch::Tensor const& dst,         // [TOT_TOKENS, ENTRIES...]
-    torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
-    torch::Tensor const& cu_seq_lens, // [BATCH+1]
-    int64_t batch_size, const std::string& kv_cache_dtype,
+    torch::Tensor const& src_cache,     // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
+    torch::Tensor const& dst,           // [TOT_TOKENS, ENTRIES...]
+    torch::Tensor const& block_table,   // [BATCH, BLOCK_INDICES]
+    torch::Tensor const& cu_seq_lens,   // [BATCH+1]
+    torch::Tensor const& token_to_seq,  // [MAX_TOKEN_ACROSS_CHUNKS]
+    int64_t num_tokens, const std::string& kv_cache_dtype,
     torch::Tensor const& scale,
     std::optional<torch::Tensor> seq_starts = std::nullopt);
 
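The header change above replaces the per-batch `batch_size` argument with a per-token view of the gather: `token_to_seq` maps every output token back to its sequence, and `num_tokens` is the total token count of the chunk. A minimal sketch of how these arguments relate to `cu_seq_lens`, mirroring the `repeat_interleave` construction used in the test later in this diff (variable names here are illustrative, not part of the commit):

import torch

# Toy construction of the new arguments for two sequences of length 2 and 3.
seq_lens = torch.tensor([2, 3])
cu_seq_lens = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32)
cu_seq_lens[1:] = seq_lens.cumsum(0)                      # [0, 2, 5]
token_to_seq = torch.repeat_interleave(
    torch.arange(seq_lens.numel(), dtype=torch.int32), seq_lens
)                                                         # [0, 0, 1, 1, 1]
num_tokens = int(cu_seq_lens[-1])                         # 5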
@@ -905,91 +905,79 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
 namespace vllm {
 
-// grid is launched with dimensions (batch, num_splits)
-template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
+template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt,
+          int ENTRY_SIZE, int CTA_SIZE>
 __global__ void gather_and_maybe_dequant_cache(
-    const cache_t* __restrict__ src_cache,    // [NUM_BLOCKS, BLOCK_SIZE,
-                                              //  ENTRIES...]
-    scalar_t* __restrict__ dst,               // [TOT_TOKENS, ENTRIES...]
-    const int32_t* __restrict__ block_table,  // [BATCH, BLOCK_INDICES]
-    const int32_t* __restrict__ cu_seq_lens,  // [BATCH+1]
-    const int32_t block_size, const int32_t entry_size,
+    const cache_t* __restrict__ src_cache,     // [NUM_BLOCKS, BLOCK_SIZE,
+                                               //  ENTRIES...]
+    scalar_t* __restrict__ dst,                // [TOT_TOKENS, ENTRIES...]
+    const int32_t* __restrict__ block_table,   // [BATCH, BLOCK_INDICES]
+    const int32_t* __restrict__ cu_seq_lens,   // [BATCH+1]
+    const int32_t* __restrict__ token_to_seq,  // [MAX_TOKEN_ACROSS_CHUNK]
+    const int32_t num_tokens, const int32_t block_size,
     const int64_t block_table_stride, const int64_t cache_block_stride,
     const int64_t cache_entry_stride, const int64_t dst_entry_stride,
     const float* __restrict__ scale,
     const int32_t* __restrict__ seq_starts) {  // Optional: starting offsets per
                                                // batch
-  const int64_t bid = blockIdx.x;  // Batch ID
-  const int32_t num_splits = gridDim.y;
-  const int32_t split = blockIdx.y;
-  const int32_t seq_start = cu_seq_lens[bid];
-  const int32_t seq_end = cu_seq_lens[bid + 1];
-  const int32_t seq_len = seq_end - seq_start;
-  const int32_t tot_blocks = cuda_utils::ceil_div(seq_len, block_size);
-  const int32_t split_blocks = cuda_utils::ceil_div(tot_blocks, num_splits);
-
-  const int32_t split_start = split * split_blocks;
-  const int32_t split_end = min((split + 1) * split_blocks, tot_blocks);
-
-  const bool is_active_split = (split_start < tot_blocks);
-  const bool is_last_split = (split_end == tot_blocks);
-
-  if (!is_active_split) return;
-
-  int32_t full_blocks_end = split_end;
-  int32_t partial_block_size = 0;
-
-  // Adjust the pointer for the block_table for this batch.
-  // If seq_starts is provided, compute an offset based on (seq_starts[bid] /
-  // page_size)
-  const int32_t batch_offset = bid * block_table_stride;
-  int32_t offset = 0;
-  if (seq_starts != nullptr) {
-    offset = seq_starts[bid] / block_size;
-  }
-  const int32_t* batch_block_table = block_table + batch_offset + offset;
-
-  // Adjust dst pointer based on the cumulative sequence lengths.
-  dst += seq_start * dst_entry_stride;
-
-  if (is_last_split) {
-    partial_block_size = seq_len % block_size;
-    if (partial_block_size) full_blocks_end -= 1;
-  }
-
-  auto copy_entry = [&](const cache_t* __restrict__ _src,
-                        scalar_t* __restrict__ _dst) {
-    for (int i = threadIdx.x; i < entry_size; i += blockDim.x) {
-      if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
-        _dst[i] = static_cast<scalar_t>(_src[i]);
-      } else {
-        _dst[i] =
-            fp8::scaled_convert<scalar_t, cache_t, kv_dt>(_src[i], *scale);
-      }
-    }
-  };
-
-  const auto loop_end =
-      std::min((int64_t)full_blocks_end, block_table_stride - offset);
-  for (int pid = split_start; pid < loop_end; ++pid) {
-    auto block_id = batch_block_table[pid];
-    auto block_start_ptr = src_cache + block_id * cache_block_stride;
-    auto block_dst_ptr = dst + pid * block_size * dst_entry_stride;
-    for (int eid = 0; eid < block_size; ++eid) {
-      copy_entry(block_start_ptr + eid * cache_entry_stride,
-                 block_dst_ptr + eid * dst_entry_stride);
-    }
-  }
-
-  if (partial_block_size) {
-    if (offset + full_blocks_end < block_table_stride) {
-      auto block_id = batch_block_table[full_blocks_end];
-      auto block_start_ptr = src_cache + block_id * cache_block_stride;
-      auto block_dst_ptr =
-          dst + full_blocks_end * block_size * dst_entry_stride;
-      for (int eid = 0; eid < partial_block_size; ++eid) {
-        copy_entry(block_start_ptr + eid * cache_entry_stride,
-                   block_dst_ptr + eid * dst_entry_stride);
-      }
-    }
-  }
+  constexpr int vec_size = sizeof(float4) / sizeof(scalar_t);
+  using ltype = vllm::vec_n_t<cache_t, vec_size>;
+  using stype = vllm::vec_n_t<scalar_t, vec_size>;
+  // We are adding this for code readability which will be optimized out when
+  // build in release.
+  assert(CTA_SIZE == blockDim.x);
+
+#pragma unroll
+  for (int token_id = blockIdx.x; token_id < num_tokens;
+       token_id += gridDim.x) {
+    int64_t batch_id = token_to_seq[token_id];
+    int64_t batch_start = cu_seq_lens[batch_id];
+    int64_t batch_end = cu_seq_lens[batch_id + 1];
+    int32_t batch_offset = token_id - batch_start;
+    if (token_id >= batch_end) return;
+    int32_t offset = 0;
+    if (seq_starts != nullptr) {
+      offset = seq_starts[batch_id];
+    }
+    batch_offset += offset;
+    int32_t block_table_id = batch_offset / block_size;
+    int32_t slot_id = batch_offset % block_size;
+    int32_t block_table_offset = batch_id * block_table_stride + block_table_id;
+    int32_t block_id = block_table[block_table_offset];
+    int64_t cache_offset =
+        block_id * cache_block_stride + slot_id * cache_entry_stride;
+    constexpr int32_t vec_iter_cnt = ENTRY_SIZE / vec_size;
+    scalar_t* dst_ = dst + token_id * dst_entry_stride;
+    cache_t* src_ = const_cast<cache_t*>(src_cache) + cache_offset;
+
+#pragma unroll
+    for (int idx = threadIdx.x; idx < vec_iter_cnt; idx += CTA_SIZE) {
+      if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+        reinterpret_cast<stype*>(dst_)[idx] =
+            static_cast<stype>(reinterpret_cast<ltype*>(src_)[idx]);
+      } else {
+        ltype loaded_val = reinterpret_cast<ltype*>(src_)[idx];
+        stype store_val;
+#pragma unroll
+        for (int j = 0; j < vec_size; ++j) {
+          store_val.val[j] = fp8::scaled_convert<scalar_t, cache_t, kv_dt>(
+              loaded_val.val[j], *scale);
+        }
+        reinterpret_cast<stype*>(dst_)[idx] = store_val;
+      }
+    }
+
+    // process tail
+    constexpr int32_t tail_cnt = ENTRY_SIZE % vec_size;
+    dst_ = dst_ + ENTRY_SIZE - tail_cnt;
+    src_ = src_ + ENTRY_SIZE - tail_cnt;
+#pragma unroll
+    for (int idx = threadIdx.x; idx < tail_cnt; idx += CTA_SIZE) {
+      if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+        dst_[idx] = static_cast<scalar_t>(src_[idx]);
+      } else {
+        dst_[idx] =
+            fp8::scaled_convert<scalar_t, cache_t, kv_dt>(src_[idx], *scale);
+      }
+    }
+  }
 }
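In the rewritten kernel each thread block handles one output token (grid-strided over num_tokens) and copies the cache entry with float4-wide vector loads, instead of the old (batch, num_splits) decomposition that serialized over a sequence's blocks. A rough PyTorch reference of the per-token indexing, for intuition only; the function name and the simplified scale handling are assumptions, not code from this diff:

import torch

def gather_reference(src_cache, dst, block_table, cu_seq_lens,
                     token_to_seq, num_tokens, block_size,
                     scale=None, seq_starts=None):
    """Serial per-token gather mirroring the kernel's indexing (illustrative only)."""
    # src_cache: [num_blocks, block_size, entry_size]; dst: [tot_tokens, entry_size]
    for token_id in range(num_tokens):
        batch_id = int(token_to_seq[token_id])
        batch_end = int(cu_seq_lens[batch_id + 1])
        if token_id >= batch_end:          # padded slot in token_to_seq
            continue
        batch_offset = token_id - int(cu_seq_lens[batch_id])
        if seq_starts is not None:         # optional per-batch start offset
            batch_offset += int(seq_starts[batch_id])
        block_id = int(block_table[batch_id, batch_offset // block_size])
        slot_id = batch_offset % block_size
        entry = src_cache[block_id, slot_id].to(torch.float32)
        if scale is not None:              # stand-in for the fp8 scaled_convert path
            entry = entry * scale
        dst[token_id] = entry.to(dst.dtype)
    return dst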
@@ -1001,34 +989,38 @@ __global__ void gather_and_maybe_dequant_cache(
 // SCALAR_T is the data type of the destination tensor.
 // CACHE_T is the stored data type of kv-cache.
 // KV_DTYPE is the real data type of kv-cache.
-#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE)                       \
-  vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE>          \
-      <<<grid, block, 0, stream>>>(                                          \
-          reinterpret_cast<CACHE_T*>(src_cache.data_ptr()),                  \
-          reinterpret_cast<SCALAR_T*>(dst.data_ptr()),                       \
-          block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(),  \
-          block_size, entry_size, block_table_stride, cache_block_stride,    \
-          cache_entry_stride, dst_entry_stride,                              \
-          reinterpret_cast<const float*>(scale.data_ptr()), seq_starts_ptr);
+#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE)                         \
+  vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE, 576,      \
+                                       thread_block_size>                     \
+      <<<grid, block, 0, stream>>>(                                           \
+          reinterpret_cast<CACHE_T*>(src_cache.data_ptr()),                   \
+          reinterpret_cast<SCALAR_T*>(dst.data_ptr()),                        \
+          block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(),   \
+          token_to_seq.data_ptr<int32_t>(), num_tokens, block_size,           \
+          block_table_stride, cache_block_stride, cache_entry_stride,         \
+          dst_entry_stride, reinterpret_cast<const float*>(scale.data_ptr()), \
+          seq_starts_ptr);
 
 // Gather sequences from the cache into the destination tensor.
 // - cu_seq_lens contains the cumulative sequence lengths for each batch
 // - block_table contains the cache block indices for each sequence
+// - token_to_seq contains the back mapping from token_id to batch_id
 // - Optionally, seq_starts (if provided) offsets the starting block index by
 //   (seq_starts[bid] / page_size)
 void gather_and_maybe_dequant_cache(
-    torch::Tensor const& src_cache,   // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
-    torch::Tensor const& dst,         // [TOT_TOKENS, ENTRIES...]
-    torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
-    torch::Tensor const& cu_seq_lens, // [BATCH+1]
-    int64_t batch_size, const std::string& kv_cache_dtype,
+    torch::Tensor const& src_cache,     // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
+    torch::Tensor const& dst,           // [TOT_TOKENS, ENTRIES...]
+    torch::Tensor const& block_table,   // [BATCH, BLOCK_INDICES]
+    torch::Tensor const& cu_seq_lens,   // [BATCH+1]
+    torch::Tensor const& token_to_seq,  // [MAX_TOKEN_ACROSS_CHUNKS]
+    int64_t num_tokens, const std::string& kv_cache_dtype,
     torch::Tensor const& scale,
     std::optional<torch::Tensor> seq_starts = std::nullopt) {
   at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   int32_t block_size = src_cache.size(1);
-  int32_t entry_size = src_cache.flatten(2, -1).size(2);
+  int32_t head_dim = dst.size(-1);
 
   TORCH_CHECK(block_table.dtype() == torch::kInt32,
               "block_table must be int32");
@@ -1038,6 +1030,9 @@ void gather_and_maybe_dequant_cache(
     TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32,
                 "seq_starts must be int32");
   }
+  TORCH_CHECK(head_dim == 576,
+              "gather_and_maybe_dequant_cache only support the head_dim to 576 "
+              "for better performance")
 
   TORCH_CHECK(src_cache.device() == dst.device(),
               "src_cache and dst must be on the same device");
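The new check pins the fast path to a 576-wide cache entry, matching the ENTRY_SIZE template argument hard-coded in CALL_GATHER_CACHE. The 576 presumably comes from the DeepSeek-style MLA cache layout (compressed KV latent plus rope dims); the decomposition below is an assumption for illustration, not something stated in the diff:

# Assumed MLA cache-entry layout behind the head_dim == 576 check
# (these dims are an assumption based on DeepSeek-style MLA, not on this diff).
kv_lora_rank = 512        # compressed latent KV width
qk_rope_head_dim = 64     # rotary position-embedding portion
entry_size = kv_lora_rank + qk_rope_head_dim
assert entry_size == 576  # matches the ENTRY_SIZE=576 template specialization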
@@ -1055,10 +1050,9 @@ void gather_and_maybe_dequant_cache(
   int64_t cache_entry_stride = src_cache.stride(1);
   int64_t dst_entry_stride = dst.stride(0);
 
-  // Decide on the number of splits based on the batch size.
-  int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
-  dim3 grid(batch_size, num_splits);
-  dim3 block(1024);
+  constexpr int32_t thread_block_size = 64;
+  dim3 grid(num_tokens);
+  dim3 block(thread_block_size);
 
   const int32_t* seq_starts_ptr =
       seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;
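The launch now uses one 64-thread block per token rather than 1024-thread blocks over a (batch, num_splits) grid. A quick arithmetic sketch of the per-token work, assuming a 2-byte destination element (the diff fixes only ENTRY_SIZE and CTA_SIZE, not the element width):

# Vectorization / launch arithmetic for the specialized kernel
# (scalar width is an assumption; e.g. bf16/fp16 destination).
ENTRY_SIZE = 576                       # elements copied per token
CTA_SIZE = 64                          # threads per block (thread_block_size)
scalar_bytes = 2
vec_size = 16 // scalar_bytes          # sizeof(float4) / sizeof(scalar_t) = 8
vec_iter_cnt = ENTRY_SIZE // vec_size  # 72 vectorized copies per token
tail_cnt = ENTRY_SIZE % vec_size       # 0, so the tail loop is a no-op here
iters_per_thread = -(-vec_iter_cnt // CTA_SIZE)  # ceil(72 / 64) = 2
print(vec_size, vec_iter_cnt, tail_cnt, iters_per_thread)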
@@ -695,7 +695,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
   cache_ops.def(
       "gather_and_maybe_dequant_cache(Tensor src_cache, Tensor! dst, "
       "  Tensor block_table, Tensor cu_seq_lens, "
-      "  int batch_size, "
+      "  Tensor token_to_seq, "
+      "  int num_tokens, "
       "  str kv_cache_dtype, "
       "  Tensor scale, Tensor? seq_starts) -> ()");
   cache_ops.impl("gather_and_maybe_dequant_cache", torch::kCUDA,
@@ -921,12 +921,16 @@ def test_gather_and_maybe_dequant_cache_mla(
     )
     _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
 
-    seq_len_tensor = torch.randint(0, max_seq_len + 1, (batch_size,), device=device)
+    seq_len_tensor = torch.randint(
+        max_seq_len, max_seq_len + 1, (batch_size,), device=device
+    )
 
     total_tokens = seq_len_tensor.sum()
     cu_seq_lens = torch.empty((batch_size + 1), dtype=torch.int32, device=device)
     cu_seq_lens[0] = 0
     cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)
+    token_to_seq = torch.arange(0, batch_size, dtype=torch.int32, device=device)
+    token_to_seq = torch.repeat_interleave(token_to_seq, seq_len_tensor)
     print("seq_len_tensor", seq_len_tensor)
 
     tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
@@ -977,7 +981,8 @@ def test_gather_and_maybe_dequant_cache_mla(
         dst,
         block_table,
         cu_seq_lens,
-        batch_size,
+        token_to_seq,
+        total_tokens,
         kv_cache_dtype,
         scale,
         None,
@@ -990,7 +995,8 @@ def test_gather_and_maybe_dequant_cache_mla(
         dst,
         block_table,
         cu_seq_lens,
-        batch_size,
+        token_to_seq,
+        total_tokens,
        kv_cache_dtype,
        scale,
        None,
@@ -2201,7 +2201,8 @@ def gather_and_maybe_dequant_cache(
     dst: torch.Tensor,
     block_table: torch.Tensor,
     cu_seq_lens: torch.Tensor,
-    batch_size: int,
+    token_to_seq: torch.Tensor,
+    num_tokens: int,
     kv_cache_dtype: str,
     scale: torch.Tensor,
     seq_starts: torch.Tensor | None = None,
@@ -2211,7 +2212,8 @@ def gather_and_maybe_dequant_cache(
         dst,
         block_table,
         cu_seq_lens,
-        batch_size,
+        token_to_seq,
+        num_tokens,
         kv_cache_dtype,
         scale,
         seq_starts,
@@ -340,6 +340,8 @@ class MLACommonPrefillMetadata:
         max_seq_lens: list[int]
         seq_lens: torch.Tensor
         workspace: torch.Tensor
+        token_to_seq: torch.Tensor
+        chunk_total_token: list[int]
 
         # for mla DCP
         padded_local_chunk_seq_lens: list[list[int]] | None = None
@@ -839,6 +841,19 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
             torch.cumsum(
                 chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32
             )
+            chunk_total_token = cu_seq_lens_cpu[:, -1]
+
+            max_token_num_over_chunk = chunk_total_token.max().item()
+            token_to_seq_tensor_cpu = torch.zeros(
+                [num_chunks, max_token_num_over_chunk], dtype=torch.int32
+            )
+            range_idx = torch.arange(num_prefills, dtype=torch.int32)
+            for i in range(num_chunks):
+                chunk_token_to_seq_tensor = torch.repeat_interleave(
+                    range_idx, chunk_seq_lens[i]
+                )
+                chunk_len = chunk_token_to_seq_tensor.shape[0]
+                token_to_seq_tensor_cpu[i, :chunk_len] = chunk_token_to_seq_tensor
 
             if self.dcp_world_size > 1:
                 local_context_lens_allranks = get_dcp_local_seq_lens(
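The builder now precomputes, per prefill chunk, a zero-padded [num_chunks, max_token_num_over_chunk] token-to-sequence table; the padding is why the kernel bails out when token_id >= cu_seq_lens[batch_id + 1]. A toy version of that construction with made-up chunk lengths (values are illustrative only):

import torch

# Toy per-chunk token_to_seq table, mirroring the builder code above.
chunk_seq_lens = torch.tensor([[2, 3], [1, 0]])              # [num_chunks, num_prefills]
num_chunks, num_prefills = chunk_seq_lens.shape
chunk_total_token = chunk_seq_lens.sum(dim=1)                # [5, 1]
max_token_num_over_chunk = int(chunk_total_token.max())      # 5
token_to_seq = torch.zeros(num_chunks, max_token_num_over_chunk, dtype=torch.int32)
range_idx = torch.arange(num_prefills, dtype=torch.int32)
for i in range(num_chunks):
    row = torch.repeat_interleave(range_idx, chunk_seq_lens[i])
    token_to_seq[i, : row.shape[0]] = row
# token_to_seq -> [[0, 0, 1, 1, 1],
#                  [0, 0, 0, 0, 0]]   (trailing zeros are padding)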
@@ -906,6 +921,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                     seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
                     max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                     seq_lens=chunk_seq_lens,
+                    token_to_seq=token_to_seq_tensor_cpu.to(
+                        device, non_blocking=True
+                    ),
+                    chunk_total_token=chunk_total_token.tolist(),
                     workspace=self.chunked_prefill_workspace,
                     padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
                     local_context_lens_allranks=local_context_lens_allranks.tolist(),
@@ -922,6 +941,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                     seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                     max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                     seq_lens=chunk_seq_lens,
+                    token_to_seq=token_to_seq_tensor_cpu.to(
+                        device, non_blocking=True
+                    ),
+                    chunk_total_token=chunk_total_token,
                     workspace=self.chunked_prefill_workspace,
                 )
@@ -1638,16 +1661,15 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         output = None
         iters = len(prefill_metadata.chunked_context.seq_tot)
         workspace = prefill_metadata.chunked_context.workspace
 
         for i in range(iters):
-            toks = prefill_metadata.chunked_context.seq_tot[i]
-
             ops.gather_and_maybe_dequant_cache(
                 src_cache=kv_c_and_k_pe_cache,
                 dst=workspace,
                 block_table=prefill_metadata.block_table,
                 cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
-                batch_size=attn_metadata.num_prefills,
+                token_to_seq=prefill_metadata.chunked_context.token_to_seq[i],
+                num_tokens=prefill_metadata.chunked_context.chunk_total_token[i],
                 kv_cache_dtype=self.kv_cache_dtype,
                 scale=k_scale,
                 seq_starts=prefill_metadata.chunked_context.starts[i],