[Kernel] Optimize rms_norm kernel (#27931)
Signed-off-by: Xin Yang <xyangx@amazon.com>
@@ -88,3 +88,32 @@
 #define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(                                              \
       TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
+
+#define VLLM_DISPATCH_VEC_SIZE(VEC_SIZE, ...) \
+  switch (VEC_SIZE) {                         \
+    case 16: {                                \
+      constexpr int vec_size = 16;            \
+      __VA_ARGS__();                          \
+      break;                                  \
+    }                                         \
+    case 8: {                                 \
+      constexpr int vec_size = 8;             \
+      __VA_ARGS__();                          \
+      break;                                  \
+    }                                         \
+    case 4: {                                 \
+      constexpr int vec_size = 4;             \
+      __VA_ARGS__();                          \
+      break;                                  \
+    }                                         \
+    case 2: {                                 \
+      constexpr int vec_size = 2;             \
+      __VA_ARGS__();                          \
+      break;                                  \
+    }                                         \
+    default: {                                \
+      constexpr int vec_size = 1;             \
+      __VA_ARGS__();                          \
+      break;                                  \
+    }                                         \
+  }
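
The new macro follows the same pattern as the AT_DISPATCH_* macros above it: a runtime vec size selects a case whose local constexpr int vec_size is visible to the lambda passed in __VA_ARGS__, so the lambda can use it as a template argument without capturing it. A minimal standalone sketch of that mechanism (not part of the diff; launch_stub and dispatch_example are illustrative names):

#include <cstdio>

// Hypothetical stand-in for a templated kernel launch.
template <int VEC_SIZE>
void launch_stub() {
  std::printf("instantiated with VEC_SIZE=%d\n", VEC_SIZE);
}

void dispatch_example(int runtime_vec_size) {
  // Roughly what VLLM_DISPATCH_VEC_SIZE(runtime_vec_size, [&] { ... })
  // expands to: each case defines a constexpr, then invokes the lambda,
  // so the lambda body is compiled once per case with a compile-time value.
  switch (runtime_vec_size) {
    case 8: {
      constexpr int vec_size = 8;
      [&] { launch_stub<vec_size>(); }();  // instantiates launch_stub<8>
      break;
    }
    default: {
      constexpr int vec_size = 1;
      [&] { launch_stub<vec_size>(); }();  // scalar fallback
      break;
    }
  }
}

int main() {
  dispatch_example(8);  // prints: instantiated with VEC_SIZE=8
  dispatch_example(3);  // prints: instantiated with VEC_SIZE=1
}
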
@@ -10,7 +10,7 @@
 namespace vllm {
 
 // TODO(woosuk): Further optimize this kernel.
-template <typename scalar_t>
+template <typename scalar_t, int VEC_SIZE>
 __global__ void rms_norm_kernel(
     scalar_t* __restrict__ out,          // [..., hidden_size]
     const scalar_t* __restrict__ input,  // [..., hidden_size]
@@ -21,7 +21,6 @@ __global__ void rms_norm_kernel(
   float variance = 0.0f;
   const scalar_t* input_row = input + blockIdx.x * input_stride;
 
-  constexpr int VEC_SIZE = 8;
   auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
 #pragma unroll
     for (int i = 0; i < VEC_SIZE; ++i) {
@@ -45,10 +44,20 @@ __global__ void rms_norm_kernel(
   }
   __syncthreads();
 
-  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x = (float)input[blockIdx.x * input_stride + idx];
-    out[blockIdx.x * hidden_size + idx] =
-        ((scalar_t)(x * s_variance)) * weight[idx];
+  scalar_t* out_row = out + blockIdx.x * hidden_size;
+  auto* v_in = reinterpret_cast<const vec_n_t<scalar_t, VEC_SIZE>*>(input_row);
+  auto* v_w = reinterpret_cast<const vec_n_t<scalar_t, VEC_SIZE>*>(weight);
+  auto* v_out = reinterpret_cast<vec_n_t<scalar_t, VEC_SIZE>*>(out_row);
+  for (int i = threadIdx.x; i < hidden_size / VEC_SIZE; i += blockDim.x) {
+    vec_n_t<scalar_t, VEC_SIZE> dst;
+    vec_n_t<scalar_t, VEC_SIZE> src1 = v_in[i];
+    vec_n_t<scalar_t, VEC_SIZE> src2 = v_w[i];
+#pragma unroll
+    for (int j = 0; j < VEC_SIZE; j++) {
+      float x = static_cast<float>(src1.val[j]);
+      dst.val[j] = ((scalar_t)(x * s_variance)) * src2.val[j];
+    }
+    v_out[i] = dst;
   }
 }
 
@@ -168,16 +177,24 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
   int num_tokens = input_view.numel() / hidden_size;
   int64_t input_stride = input_view.stride(-2);
 
+  // For large num_tokens, use smaller blocks to increase SM concurrency.
+  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
   dim3 grid(num_tokens);
-  dim3 block(std::min(hidden_size, 1024));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input_view));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input_view.scalar_type(), "rms_norm_kernel", [&] {
-        vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
-            out.data_ptr<scalar_t>(), input_view.data_ptr<scalar_t>(),
-            input_stride, weight.data_ptr<scalar_t>(), epsilon, num_tokens,
-            hidden_size);
+        const int calculated_vec_size =
+            std::gcd(16 / sizeof(scalar_t), hidden_size);
+        const int block_size =
+            std::min(hidden_size / calculated_vec_size, max_block_size);
+        dim3 block(block_size);
+        VLLM_DISPATCH_VEC_SIZE(calculated_vec_size, [&] {
+          vllm::rms_norm_kernel<scalar_t, vec_size><<<grid, block, 0, stream>>>(
+              out.data_ptr<scalar_t>(), input_view.data_ptr<scalar_t>(),
+              input_stride, weight.data_ptr<scalar_t>(), epsilon, num_tokens,
+              hidden_size);
+        });
       });
 }
 
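
To see what the new launch configuration does in practice, the host-only sketch below (illustrative, not part of the commit) mirrors the arithmetic above: 16 bytes per vectorized access divided by the element size gives the widest candidate vec size, and taking the gcd with hidden_size guarantees the chosen vec size divides hidden_size exactly, falling back to 1 for odd sizes.

#include <algorithm>
#include <cstdio>
#include <numeric>

// Illustrative only: reproduces the launcher's vec-size / block-size math.
void show(const char* dtype, int elem_bytes, int hidden_size, int num_tokens) {
  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
  const int calculated_vec_size = std::gcd(16 / elem_bytes, hidden_size);
  const int block_size =
      std::min(hidden_size / calculated_vec_size, max_block_size);
  std::printf("%-4s hidden=%d tokens=%d -> vec_size=%d block=%d\n", dtype,
              hidden_size, num_tokens, calculated_vec_size, block_size);
}

int main() {
  show("fp16", 2, 4096, 8192);  // vec_size=8, block=256
  show("fp32", 4, 4096, 8192);  // vec_size=4, block=256
  show("fp16", 2, 5120, 64);    // vec_size=8, block=640
  show("fp16", 2, 999, 64);     // vec_size=1, block=999
}
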
@@ -18,7 +18,7 @@
 namespace vllm {
 
 // TODO(woosuk): Further optimize this kernel.
-template <typename scalar_t, typename fp8_type>
+template <typename scalar_t, typename fp8_type, int VEC_SIZE>
 __global__ void rms_norm_static_fp8_quant_kernel(
     fp8_type* __restrict__ out,          // [..., hidden_size]
     const scalar_t* __restrict__ input,  // [..., hidden_size]
@@ -31,7 +31,6 @@ __global__ void rms_norm_static_fp8_quant_kernel(
 
   const scalar_t* input_row = input + blockIdx.x * input_stride;
 
-  constexpr int VEC_SIZE = 8;
   auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
 #pragma unroll
     for (int i = 0; i < VEC_SIZE; ++i) {
@@ -58,11 +57,18 @@ __global__ void rms_norm_static_fp8_quant_kernel(
   // invert scale to avoid division
   float const scale_inv = 1.0f / *scale;
 
-  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x = (float)input[blockIdx.x * input_stride + idx];
-    float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
-    out[blockIdx.x * hidden_size + idx] =
-        scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
+  auto* v_in = reinterpret_cast<const vec_n_t<scalar_t, VEC_SIZE>*>(input_row);
+  auto* v_w = reinterpret_cast<const vec_n_t<scalar_t, VEC_SIZE>*>(weight);
+  for (int idx = threadIdx.x; idx < hidden_size / VEC_SIZE; idx += blockDim.x) {
+    vec_n_t<scalar_t, VEC_SIZE> src1 = v_in[idx];
+    vec_n_t<scalar_t, VEC_SIZE> src2 = v_w[idx];
+#pragma unroll
+    for (int j = 0; j < VEC_SIZE; j++) {
+      float x = static_cast<float>(src1.val[j]);
+      float const out_norm = ((scalar_t)(x * s_variance)) * src2.val[j];
+      out[blockIdx.x * hidden_size + idx * VEC_SIZE + j] =
+          scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
+    }
   }
 }
 
@@ -188,20 +194,29 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
   int input_stride = input.stride(-2);
   int num_tokens = input.numel() / hidden_size;
 
+  // For large num_tokens, use smaller blocks to increase SM concurrency.
+  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
   dim3 grid(num_tokens);
-  dim3 block(std::min(hidden_size, 1024));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "rms_norm_kernel_scalar_type", [&] {
         VLLM_DISPATCH_FP8_TYPES(
             out.scalar_type(), "rms_norm_kernel_fp8_type", [&] {
-              vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t>
-                  <<<grid, block, 0, stream>>>(
-                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
-                      input_stride, weight.data_ptr<scalar_t>(),
-                      scale.data_ptr<float>(), epsilon, num_tokens,
-                      hidden_size);
+              const int calculated_vec_size =
+                  std::gcd(16 / sizeof(scalar_t), hidden_size);
+              const int block_size =
+                  std::min(hidden_size / calculated_vec_size, max_block_size);
+              dim3 block(block_size);
+              VLLM_DISPATCH_VEC_SIZE(calculated_vec_size, [&] {
+                vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t,
+                                                       vec_size>
+                    <<<grid, block, 0, stream>>>(
+                        out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
+                        input_stride, weight.data_ptr<scalar_t>(),
+                        scale.data_ptr<float>(), epsilon, num_tokens,
+                        hidden_size);
+              });
             });
       });
 }
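
For reference when validating the vectorized path against the previous scalar path, the small host-side sketch below (illustrative, not part of the commit) computes the per-row result both kernels are expected to produce, assuming the standard RMSNorm definition where the scale factor is rsqrt(mean(x^2) + epsilon):

#include <cmath>
#include <cstdio>
#include <vector>

// Minimal host-side RMSNorm reference:
//   y[i] = x[i] * rsqrt(mean(x^2) + eps) * w[i]
void rms_norm_ref(const std::vector<float>& x, const std::vector<float>& w,
                  float eps, std::vector<float>& y) {
  double sum_sq = 0.0;
  for (float v : x) sum_sq += static_cast<double>(v) * v;
  const float s_variance =
      1.0f / std::sqrt(static_cast<float>(sum_sq / x.size()) + eps);
  y.resize(x.size());
  for (size_t i = 0; i < x.size(); ++i) y[i] = x[i] * s_variance * w[i];
}

int main() {
  std::vector<float> x = {1.0f, -2.0f, 3.0f, -4.0f};
  std::vector<float> w = {1.0f, 1.0f, 1.0f, 1.0f};
  std::vector<float> y;
  rms_norm_ref(x, w, 1e-6f, y);
  for (float v : y) std::printf("%f ", v);  // ~0.365 -0.730 1.095 -1.461
  std::printf("\n");
}
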