[Deepseek v3.2] Remove extra logics in indexer (#26465)

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
Signed-off-by: Lain <siyuanf@nvidia.com>
Co-authored-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
This commit is contained in:
Lain
2025-10-21 16:34:03 -07:00
committed by GitHub
parent 6c2eef5a5d
commit 09a7e6f617
5 changed files with 141 additions and 40 deletions

View File

@@ -101,6 +101,10 @@ void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
const torch::Tensor& rowEnds, torch::Tensor& indices,
int64_t numRows, int64_t stride0, int64_t stride1);
void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
const torch::Tensor& seq_lens, torch::Tensor& indices,
int64_t numRows, int64_t stride0, int64_t stride1);
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& weight, torch::Tensor& scale,
double epsilon);

View File

@@ -54,15 +54,10 @@ static inline __device__ uint16_t extractBinIdx(float x) {
return 511 - (tmp.u16 >> 7);
}
template <int kNumThreadsPerBlock = 512>
static __global__ void topKPerRow(const float* logits, const int* rowStarts,
const int* rowEnds, int* outIndices,
int stride0, int stride1) {
// The number of bins in the histogram.
static constexpr int kNumBins = 512;
// The top-k width.
static constexpr int kTopK = 2048;
template <int kNumThreadsPerBlock = 512, int kNumBins = 512, int kTopK = 2048>
__device__ void topKPerRowJob(const float* logits, const int rowStart,
const int rowEnd, const int rowIdx,
int* outIndices, int stride0, int stride1) {
// The number of elements per thread for the final top-k sort.
static constexpr int kNumTopKItemsPerThread = kTopK / kNumThreadsPerBlock;
// The class to sort the elements during the final top-k sort.
@@ -108,10 +103,6 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts,
// Shared memory counter to register the candidates for the final phase.
__shared__ int smemFinalDstIdx[1];
// The row computed by this block.
int rowIdx = blockIdx.x;
// The range of logits within the row.
int rowStart = rowStarts[rowIdx], rowEnd = rowEnds[rowIdx];
// The length of the row.
int rowLen = rowEnd - rowStart;
@@ -260,6 +251,49 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts,
}
}
template <int kNumThreadsPerBlock = 512>
static __global__ void topKPerRow(const float* logits, const int* rowStarts,
const int* rowEnds, int* outIndices,
int stride0, int stride1) {
// The number of bins in the histogram.
static constexpr int kNumBins = 512;
// The top-k width.
static constexpr int kTopK = 2048;
// The row computed by this block.
int rowIdx = blockIdx.x;
// The range of logits within the row.
int rowStart = rowStarts[rowIdx];
int rowEnd = rowEnds[rowIdx];
topKPerRowJob<kNumThreadsPerBlock, kNumBins, kTopK>(
logits, rowStart, rowEnd, rowIdx, outIndices, stride0, stride1);
}
template <int kNumThreadsPerBlock = 512>
static __global__ void topKPerRowDecode(const float* logits, const int* seqLens,
int* outIndices, int stride0,
int stride1, int next_n) {
// The number of bins in the histogram.
static constexpr int kNumBins = 512;
// The top-k width.
static constexpr int kTopK = 2048;
// The row computed by this block.
int rowIdx = blockIdx.x;
// The range of logits within the row.
int rowStart = 0;
int seq_len = seqLens[rowIdx / next_n];
int rowEnd = seq_len - next_n + (rowIdx % next_n) + 1;
topKPerRowJob<kNumThreadsPerBlock, kNumBins, kTopK>(
logits, rowStart, rowEnd, rowIdx, outIndices, stride0, stride1);
}
} // namespace vllm
void apply_repetition_penalties_(
@@ -303,6 +337,20 @@ void apply_repetition_penalties_(
});
}
void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
const torch::Tensor& seqLens, torch::Tensor& indices,
int64_t numRows, int64_t stride0, int64_t stride1) {
// Compute the results on the device.
constexpr int kNumThreadsPerBlock = 512;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
vllm::topKPerRowDecode<kNumThreadsPerBlock>
<<<numRows, kNumThreadsPerBlock, 0, stream>>>(
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
indices.data_ptr<int>(), static_cast<int>(stride0),
static_cast<int>(stride1), static_cast<int>(next_n));
}
void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
const torch::Tensor& rowEnds, torch::Tensor& indices,
int64_t numRows, int64_t stride0, int64_t stride1) {

View File

@@ -189,6 +189,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"int stride1) -> ()");
ops.impl("top_k_per_row", torch::kCUDA, &top_k_per_row);
ops.def(
"top_k_per_row_decode(Tensor logits, int next_n, "
"Tensor seq_lens, Tensor! indices, int numRows, "
"int stride0, int stride1) -> ()");
ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode);
// Layernorm-quant
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(

View File

@@ -10,6 +10,8 @@ from vllm.platforms import current_platform
# Test parameters
NUM_ROWS = [1, 32, 2050]
TOP_K_VALUES = [2048]
BATCH_SIZE = [1, 2, 4, 2048, 4096]
NEXT_N = [1, 2, 4, 8]
def create_random_logits(
@@ -114,7 +116,7 @@ def test_top_k_per_row(
logits = create_random_logits(row_starts, row_ends, vocab_size, torch.float32, 42)
# Create output tensors
indices = torch.empty((num_rows, 2048), dtype=torch.int32, device="cuda")
indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
# Run CUDA implementation
torch.ops._C.top_k_per_row(
@@ -138,3 +140,59 @@ def test_top_k_per_row(
assert compare_top_k_results(
logits, indices, torch_indices, row_starts, row_ends, top_k
), "CUDA top_k_per_row results don't match torch.topk"
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
@pytest.mark.parametrize("next_n", NEXT_N)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@torch.inference_mode()
def test_top_k_per_row_decode(
top_k: int,
batch_size: int,
next_n: int,
) -> None:
"""
Test top_k_per_row with seq_lens tensor.
"""
torch.set_default_device("cuda:0")
# Create test data
num_rows = batch_size * next_n
vocab_size = 20000
seq_lens = torch.randint(
vocab_size, (batch_size,), dtype=torch.int32, device="cuda"
)
row_starts = torch.zeros(num_rows, dtype=torch.int32, device="cuda")
row_indices = torch.arange(num_rows, device="cuda") // next_n
next_n_offset = torch.arange(num_rows, device="cuda") % next_n
row_ends = seq_lens[row_indices] - next_n + next_n_offset + 1
logits = create_random_logits(row_starts, row_ends, vocab_size, torch.float32, 42)
# Create output tensors
indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
# Run CUDA implementation
torch.ops._C.top_k_per_row_decode(
logits,
next_n,
seq_lens,
indices,
num_rows,
logits.stride(0),
logits.stride(1),
)
torch.cuda.synchronize()
# Run reference implementation
torch_indices = logits.topk(min(top_k, max(row_ends)), dim=-1)[1]
mask_lo = torch_indices >= 0
mask_hi = (torch_indices - (row_ends - row_starts)[:, None]) < 0
mask = mask_lo & mask_hi
torch_indices = torch_indices.masked_fill(~mask, -1)
# Compare results
assert compare_top_k_results(
logits, indices, torch_indices, row_starts, row_ends, top_k
), "CUDA top_k_per_row results don't match torch.topk"

View File

@@ -580,9 +580,9 @@ def sparse_attn_indexer(
)
num_rows = logits.shape[0]
assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
topk_indices = torch.empty(
num_rows, topk_tokens, dtype=torch.int32, device=logits.device
)
topk_indices = topk_indices_buffer[
chunk.token_start : chunk.token_end, :topk_tokens
]
torch.ops._C.top_k_per_row(
logits,
chunk.cu_seqlen_ks,
@@ -592,9 +592,6 @@ def sparse_attn_indexer(
logits.stride(0),
logits.stride(1),
)
topk_indices_buffer[
chunk.token_start : chunk.token_end, : topk_indices.shape[-1]
] = topk_indices.to(dtype=torch.int32)
if has_decode:
decode_metadata = attn_metadata.decode
@@ -628,26 +625,14 @@ def sparse_attn_indexer(
decode_metadata.schedule_metadata,
max_model_len=max_model_len,
)
# padded query len
current_device = padded_q_fp8_decode_tokens.device
padded_num_tokens = batch_size * next_n
row_indices = torch.arange(padded_num_tokens, device=current_device) // next_n
next_n_offset = (
torch.arange(padded_num_tokens, device=padded_q_fp8_decode_tokens.device)
% next_n
)
index_end_pos = (
decode_metadata.seq_lens[row_indices] - next_n + next_n_offset + 1
).unsqueeze(1)
num_rows = logits.shape[0]
assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
topk_indices = torch.empty(
num_rows, topk_tokens, dtype=torch.int32, device=logits.device
)
torch.ops._C.top_k_per_row(
topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens]
torch.ops._C.top_k_per_row_decode(
logits,
torch.zeros(num_rows, dtype=torch.int32, device=logits.device),
index_end_pos.to(dtype=torch.int32, device=logits.device),
next_n,
decode_metadata.seq_lens,
topk_indices,
num_rows,
logits.stride(0),
@@ -660,9 +645,9 @@ def sparse_attn_indexer(
topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
decode_lens,
)
topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = (
topk_indices.to(dtype=torch.int32)
)
topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = (
topk_indices
)
return topk_indices_buffer