[Performance][MLA][ROCm] Remove redundant D2D copy in deepseek (#27457)

Signed-off-by: ganyi <ygan@amd.com>
This commit is contained in:
Pleaplusone
2025-11-26 12:45:28 +08:00
committed by GitHub
parent 53d7f1f601
commit d9d342d214
5 changed files with 49 additions and 41 deletions

View File

@@ -16,7 +16,8 @@ __global__ void merge_attn_states_kernel(
scalar_t* output, float* output_lse, const scalar_t* prefix_output,
const float* prefix_lse, const scalar_t* suffix_output,
const float* suffix_lse, const uint num_tokens, const uint num_heads,
const uint head_size) {
const uint head_size, const uint prefix_head_stride,
const uint output_head_stride) {
using pack_128b_t = uint4;
const uint pack_size = 16 / sizeof(scalar_t);
const uint threads_per_head = head_size / pack_size;
@@ -34,11 +35,13 @@ __global__ void merge_attn_states_kernel(
const uint head_idx = token_head_idx % num_heads;
const uint pack_offset = pack_idx * pack_size; // (0~15)*8, etc.
const uint head_offset =
token_idx * num_heads * head_size + head_idx * head_size;
const scalar_t* prefix_head_ptr = prefix_output + head_offset;
const scalar_t* suffix_head_ptr = suffix_output + head_offset;
scalar_t* output_head_ptr = output + head_offset;
const uint src_head_offset = token_idx * num_heads * prefix_head_stride +
head_idx * prefix_head_stride;
const uint dst_head_offset = token_idx * num_heads * output_head_stride +
head_idx * output_head_stride;
const scalar_t* prefix_head_ptr = prefix_output + src_head_offset;
const scalar_t* suffix_head_ptr = suffix_output + src_head_offset;
scalar_t* output_head_ptr = output + dst_head_offset;
float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
@@ -140,7 +143,7 @@ __global__ void merge_attn_states_kernel(
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
num_heads, head_size); \
num_heads, head_size, prefix_head_stride, output_head_stride); \
}
/*@brief Merges the attention states from prefix and suffix
@@ -166,17 +169,11 @@ void merge_attn_states_launcher(torch::Tensor& output,
const uint num_tokens = output.size(0);
const uint num_heads = output.size(1);
const uint head_size = output.size(2);
const uint prefix_head_stride = prefix_output.stride(1);
const uint output_head_stride = output.stride(1);
const uint pack_size = 16 / sizeof(scalar_t);
TORCH_CHECK(head_size % pack_size == 0,
"headsize must be multiple of pack_size:", pack_size);
TORCH_CHECK(output.stride(-2) == head_size && output.stride(-1) == 1,
"output heads must be contiguous in memory");
TORCH_CHECK(
prefix_output.stride(-2) == head_size && prefix_output.stride(-1) == 1,
"prefix_output heads must be contiguous in memory");
TORCH_CHECK(
suffix_output.stride(-2) == head_size && suffix_output.stride(-1) == 1,
"suffix_output heads must be contiguous in memory");
float* output_lse_ptr = nullptr;
if (output_lse.has_value()) {
output_lse_ptr = output_lse.value().data_ptr<float>();

View File

@@ -52,14 +52,13 @@ void paged_attention_v2(
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
#ifndef USE_ROCM
void merge_attn_states(torch::Tensor& output,
std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output,
const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse);
#ifndef USE_ROCM
void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]

View File

@@ -63,7 +63,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
#ifndef USE_ROCM
// Merge attn states
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case)
@@ -76,7 +75,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor suffix_output,"
" Tensor suffix_lse) -> ()");
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
#ifndef USE_ROCM
ops.def(
"convert_vertical_slash_indexes("
" Tensor! block_count, Tensor! block_offset, "

View File

@@ -20,7 +20,11 @@ def merge_attn_states(
num_query_heads = output.shape[1]
head_size = output.shape[2]
padded_head_size = triton.next_power_of_2(head_size)
# We assume the output stride on num_head is not always as same as the
# `suffix_output` and `prefix_output`, as them might be padded by the attention
# backend.
prefix_head_stride = prefix_output.stride(1)
output_head_stride = output.stride(1)
# TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
merge_attn_states_kernel[(num_tokens, num_query_heads)](
output,
@@ -29,6 +33,8 @@ def merge_attn_states(
prefix_lse,
suffix_output,
suffix_lse,
prefix_head_stride,
output_head_stride,
head_size,
padded_head_size,
output_lse is not None,
@@ -43,6 +49,8 @@ def merge_attn_states_kernel(
prefix_lse, # [NUM_HEADS, NUM_TOKENS]
suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
suffix_lse, # [NUM_HEADS, NUM_TOKENS]
prefix_head_stride,
output_head_stride,
HEAD_SIZE: tl.constexpr,
PADDED_HEAD_SIZE: tl.constexpr,
OUTPUT_LSE: tl.constexpr,
@@ -79,15 +87,15 @@ def merge_attn_states_kernel(
head_mask = head_arange < HEAD_SIZE
p_out = tl.load(
prefix_output
+ token_idx * num_heads * HEAD_SIZE
+ head_idx * HEAD_SIZE
+ token_idx * num_heads * prefix_head_stride
+ head_idx * prefix_head_stride
+ head_arange,
mask=head_mask,
)
s_out = tl.load(
suffix_output
+ token_idx * num_heads * HEAD_SIZE
+ head_idx * HEAD_SIZE
+ token_idx * num_heads * prefix_head_stride
+ head_idx * prefix_head_stride
+ head_arange,
mask=head_mask,
)
@@ -99,7 +107,10 @@ def merge_attn_states_kernel(
s_scale = s_se / out_se
out = p_out * p_scale + s_out * s_scale
tl.store(
output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange,
output
+ token_idx * num_heads * output_head_stride
+ head_idx * output_head_stride
+ head_arange,
out,
mask=head_mask,
)

View File

@@ -1238,15 +1238,13 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
if self.is_aiter_triton_fp8_bmm_enabled:
out = out.view(-1, self.num_heads, self.v_head_dim)
# Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
x = rocm_aiter_ops.triton_fp8_bmm(
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out
)
# Convert from (B, N, V) to (B, N * V)
x = x.reshape(-1, self.num_heads * self.v_head_dim)
# Copy result
out.copy_(x)
else:
# Convert from (B, N * V) to (N, B, V)
out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)
@@ -1824,7 +1822,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor,
) -> torch.Tensor:
output: torch.Tensor,
) -> None:
# TODO (zyongye): Prefill function here
assert attn_metadata.prefill is not None
assert self.dcp_world_size is not None
@@ -1837,7 +1836,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
output = self._run_prefill_new_tokens(
output_prefill = self._run_prefill_new_tokens(
prefill=attn_metadata.prefill,
q=q,
k=k,
@@ -1846,7 +1845,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)
if has_context:
suffix_output, suffix_lse = output
suffix_output, suffix_lse = output_prefill
if self.dcp_world_size > 1:
context_output, context_lse = (
self._context_parallel_compute_prefill_context(
@@ -1862,7 +1861,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
q, kv_c_and_k_pe_cache, attn_metadata, k_scale
)
output = torch.empty_like(suffix_output)
# unpad if necessary
if self._pad_v:
context_output = context_output[..., : v.shape[-1]]
suffix_output = suffix_output[..., : v.shape[-1]]
output = output.view(-1, self.num_heads, self.v_head_dim)
merge_attn_states(
output=output,
prefix_output=context_output,
@@ -1870,12 +1874,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
suffix_output=suffix_output,
suffix_lse=suffix_lse,
)
# unpad if necessary
if self._pad_v:
output = output[..., : v.shape[-1]]
return output.flatten(start_dim=-2)
else:
output_prefill = output_prefill[..., : v.shape[-1]].flatten(start_dim=-2)
output.copy_(output_prefill)
@abstractmethod
def _forward_decode(
@@ -1970,13 +1971,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
kv_cache = kv_cache.view(current_platform.fp8_dtype())
if has_prefill:
output[num_decode_tokens:] = self._forward_prefill(
self._forward_prefill(
prefill_q,
prefill_k_c_normed,
prefill_k_pe,
kv_cache,
attn_metadata,
layer._k_scale,
output=output[num_decode_tokens:],
)
if has_decode: