[Kernel] Optimize rms_norm kernel (#27931)

Signed-off-by: Xin Yang <xyangx@amazon.com>
Author: Xin Yang
Date: 2025-11-11 10:02:23 -08:00
Committed by: GitHub
Parent: 684f254585
Commit: 6c3c0f8235
3 changed files with 86 additions and 25 deletions

csrc/dispatch_utils.h

@@ -88,3 +88,32 @@
#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
+#define VLLM_DISPATCH_VEC_SIZE(VEC_SIZE, ...) \
+switch (VEC_SIZE) { \
+case 16: { \
+constexpr int vec_size = 16; \
+__VA_ARGS__(); \
+break; \
+} \
+case 8: { \
+constexpr int vec_size = 8; \
+__VA_ARGS__(); \
+break; \
+} \
+case 4: { \
+constexpr int vec_size = 4; \
+__VA_ARGS__(); \
+break; \
+} \
+case 2: { \
+constexpr int vec_size = 2; \
+__VA_ARGS__(); \
+break; \
+} \
+default: { \
+constexpr int vec_size = 1; \
+__VA_ARGS__(); \
+break; \
+} \
+}
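
The macro maps a runtime vector width onto a compile-time constant vec_size that the caller's lambda can pass as a template argument, falling back to 1 for widths it does not special-case. A minimal usage sketch, with a hypothetical my_kernel that is not part of this change:

    int runtime_vec = 8;  // any of 16/8/4/2; anything else hits the default case
    VLLM_DISPATCH_VEC_SIZE(runtime_vec, [&] {
      // Inside the lambda, vec_size is constexpr and usable as a template argument.
      my_kernel<float, vec_size><<<grid, block, 0, stream>>>(/* args */);
    });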

csrc/layernorm_kernels.cu

@@ -10,7 +10,7 @@
namespace vllm {
// TODO(woosuk): Further optimize this kernel.
-template <typename scalar_t>
+template <typename scalar_t, int VEC_SIZE>
__global__ void rms_norm_kernel(
scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
@@ -21,7 +21,6 @@ __global__ void rms_norm_kernel(
float variance = 0.0f;
const scalar_t* input_row = input + blockIdx.x * input_stride;
-constexpr int VEC_SIZE = 8;
auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
@@ -45,10 +44,20 @@ __global__ void rms_norm_kernel(
}
__syncthreads();
-for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-float x = (float)input[blockIdx.x * input_stride + idx];
-out[blockIdx.x * hidden_size + idx] =
-((scalar_t)(x * s_variance)) * weight[idx];
+scalar_t* out_row = out + blockIdx.x * hidden_size;
+auto* v_in = reinterpret_cast<const vec_n_t<scalar_t, VEC_SIZE>*>(input_row);
+auto* v_w = reinterpret_cast<const vec_n_t<scalar_t, VEC_SIZE>*>(weight);
+auto* v_out = reinterpret_cast<vec_n_t<scalar_t, VEC_SIZE>*>(out_row);
+for (int i = threadIdx.x; i < hidden_size / VEC_SIZE; i += blockDim.x) {
+vec_n_t<scalar_t, VEC_SIZE> dst;
+vec_n_t<scalar_t, VEC_SIZE> src1 = v_in[i];
+vec_n_t<scalar_t, VEC_SIZE> src2 = v_w[i];
+#pragma unroll
+for (int j = 0; j < VEC_SIZE; j++) {
+float x = static_cast<float>(src1.val[j]);
+dst.val[j] = ((scalar_t)(x * s_variance)) * src2.val[j];
+}
+v_out[i] = dst;
}
}
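
The rewritten second pass replaces the per-element loop with wide loads and stores: one vec_n_t load each from the input row and the weight, per-lane math in registers, and one vec_n_t store to the output row. A simplified, self-contained sketch of the same pattern, using a float-only Vec wrapper as a stand-in for vLLM's vec_n_t (an illustration, not the real type):

    // Stand-in for vec_n_t<scalar_t, N>; float-only to keep the sketch small.
    template <int N>
    struct alignas(sizeof(float) * N) Vec {
      float val[N];
    };

    // One row: out[i] = (in[i] * s) * w[i], handled N elements at a time.
    template <int N>
    __global__ void scale_mul_row(float* out, const float* in, const float* w,
                                  float s, int n) {
      auto* v_in = reinterpret_cast<const Vec<N>*>(in);
      auto* v_w = reinterpret_cast<const Vec<N>*>(w);
      auto* v_out = reinterpret_cast<Vec<N>*>(out);
      for (int i = threadIdx.x; i < n / N; i += blockDim.x) {
        Vec<N> a = v_in[i], b = v_w[i], d;
        #pragma unroll
        for (int j = 0; j < N; ++j) d.val[j] = a.val[j] * s * b.val[j];
        v_out[i] = d;
      }
    }

The sketch assumes n is a multiple of N and that the row pointers are suitably aligned; the host code below picks the vector width so that it always divides hidden_size.
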
@@ -168,16 +177,24 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
int num_tokens = input_view.numel() / hidden_size;
int64_t input_stride = input_view.stride(-2);
+// For large num_tokens, use smaller blocks to increase SM concurrency.
+const int max_block_size = (num_tokens < 256) ? 1024 : 256;
dim3 grid(num_tokens);
-dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input_view));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input_view.scalar_type(), "rms_norm_kernel", [&] {
-vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
-out.data_ptr<scalar_t>(), input_view.data_ptr<scalar_t>(),
-input_stride, weight.data_ptr<scalar_t>(), epsilon, num_tokens,
-hidden_size);
+const int calculated_vec_size =
+std::gcd(16 / sizeof(scalar_t), hidden_size);
+const int block_size =
+std::min(hidden_size / calculated_vec_size, max_block_size);
+dim3 block(block_size);
+VLLM_DISPATCH_VEC_SIZE(calculated_vec_size, [&] {
+vllm::rms_norm_kernel<scalar_t, vec_size><<<grid, block, 0, stream>>>(
+out.data_ptr<scalar_t>(), input_view.data_ptr<scalar_t>(),
+input_stride, weight.data_ptr<scalar_t>(), epsilon, num_tokens,
+hidden_size);
+});
});
}
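
On the host, the vector width is the widest lane count that both fits a 16-byte access for the element type (16 / sizeof(scalar_t)) and evenly divides hidden_size; the block size then covers one row's vectors without exceeding the token-count-dependent cap. A worked host-side sketch of that arithmetic; the shapes are assumptions for illustration:

    #include <algorithm>
    #include <cstdio>
    #include <numeric>

    int main() {
      const int hidden_size = 4096;  // example hidden size (assumption)
      const int num_tokens = 1024;   // example token count (assumption)
      const int elem_bytes = 2;      // e.g. half / bfloat16
      const int max_block_size = (num_tokens < 256) ? 1024 : 256;
      const int vec_size = std::gcd(16 / elem_bytes, hidden_size);
      const int block_size = std::min(hidden_size / vec_size, max_block_size);
      std::printf("vec_size=%d block_size=%d\n", vec_size, block_size);  // 8, 256
      return 0;
    }

With an odd hidden size such as 4095, std::gcd(8, 4095) is 1, so the dispatch falls back to the scalar vec_size = 1 case of VLLM_DISPATCH_VEC_SIZE.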

csrc/layernorm_quant_kernels.cu

@@ -18,7 +18,7 @@
namespace vllm {
// TODO(woosuk): Further optimize this kernel.
-template <typename scalar_t, typename fp8_type>
+template <typename scalar_t, typename fp8_type, int VEC_SIZE>
__global__ void rms_norm_static_fp8_quant_kernel(
fp8_type* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
@@ -31,7 +31,6 @@ __global__ void rms_norm_static_fp8_quant_kernel(
const scalar_t* input_row = input + blockIdx.x * input_stride;
-constexpr int VEC_SIZE = 8;
auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
@@ -58,11 +57,18 @@ __global__ void rms_norm_static_fp8_quant_kernel(
// invert scale to avoid division
float const scale_inv = 1.0f / *scale;
-for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-float x = (float)input[blockIdx.x * input_stride + idx];
-float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
-out[blockIdx.x * hidden_size + idx] =
-scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
+auto* v_in = reinterpret_cast<const vec_n_t<scalar_t, VEC_SIZE>*>(input_row);
+auto* v_w = reinterpret_cast<const vec_n_t<scalar_t, VEC_SIZE>*>(weight);
+for (int idx = threadIdx.x; idx < hidden_size / VEC_SIZE; idx += blockDim.x) {
+vec_n_t<scalar_t, VEC_SIZE> src1 = v_in[idx];
+vec_n_t<scalar_t, VEC_SIZE> src2 = v_w[idx];
+#pragma unroll
+for (int j = 0; j < VEC_SIZE; j++) {
+float x = static_cast<float>(src1.val[j]);
+float const out_norm = ((scalar_t)(x * s_variance)) * src2.val[j];
+out[blockIdx.x * hidden_size + idx * VEC_SIZE + j] =
+scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
+}
}
}
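
Here the input and weight reads are vectorized while the fp8 store stays element-wise at idx * VEC_SIZE + j, presumably because the output element type is fp8_type rather than scalar_t. A small host-side check, with example sizes chosen for illustration, that this indexing covers each of the hidden_size outputs exactly once when VEC_SIZE divides hidden_size:

    #include <cassert>

    int main() {
      const int hidden_size = 4096;  // example value (assumption)
      const int VEC_SIZE = 8;        // example value; divides hidden_size
      int expected = 0;
      for (int idx = 0; idx < hidden_size / VEC_SIZE; ++idx)
        for (int j = 0; j < VEC_SIZE; ++j)
          assert(idx * VEC_SIZE + j == expected++);
      assert(expected == hidden_size);
      return 0;
    }
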
@@ -188,20 +194,29 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
int input_stride = input.stride(-2);
int num_tokens = input.numel() / hidden_size;
+// For large num_tokens, use smaller blocks to increase SM concurrency.
+const int max_block_size = (num_tokens < 256) ? 1024 : 256;
dim3 grid(num_tokens);
-dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "rms_norm_kernel_scalar_type", [&] {
VLLM_DISPATCH_FP8_TYPES(
out.scalar_type(), "rms_norm_kernel_fp8_type", [&] {
-vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t>
-<<<grid, block, 0, stream>>>(
-out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
-input_stride, weight.data_ptr<scalar_t>(),
-scale.data_ptr<float>(), epsilon, num_tokens,
-hidden_size);
+const int calculated_vec_size =
+std::gcd(16 / sizeof(scalar_t), hidden_size);
+const int block_size =
+std::min(hidden_size / calculated_vec_size, max_block_size);
+dim3 block(block_size);
+VLLM_DISPATCH_VEC_SIZE(calculated_vec_size, [&] {
+vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t,
+vec_size>
+<<<grid, block, 0, stream>>>(
+out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
+input_stride, weight.data_ptr<scalar_t>(),
+scale.data_ptr<float>(), epsilon, num_tokens,
+hidden_size);
+});
});
});
}
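
Both kernels keep the same per-row math the original scalar loops expressed: accumulate the sum of squares, form s_variance = rsqrt(mean(x^2) + epsilon), then scale each element and multiply by the weight. A plain C++ sketch of that row-level semantics, in float throughout for clarity only:

    #include <cmath>
    #include <vector>

    // Row-level RMSNorm reference: out[i] = (x[i] * s_variance) * w[i],
    // with s_variance = 1 / sqrt(mean(x^2) + eps).
    std::vector<float> rms_norm_row(const std::vector<float>& x,
                                    const std::vector<float>& w, float eps) {
      float sum_sq = 0.0f;
      for (float v : x) sum_sq += v * v;
      const float s_variance = 1.0f / std::sqrt(sum_sq / x.size() + eps);
      std::vector<float> out(x.size());
      for (size_t i = 0; i < x.size(); ++i) out[i] = (x[i] * s_variance) * w[i];
      return out;
    }

The static fp8 variant additionally multiplies by scale_inv and converts each result to the fp8 output type via scaled_fp8_conversion.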