[kernel][perf] support uncontiguous input for rms_norm kernel (#28103)
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Signed-off-by: izhuhaoran <izhuhaoran@qq.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
@@ -117,3 +117,24 @@
       break;                                                               \
     }                                                                      \
   }
+
+#define VLLM_DISPATCH_RANK234(NUM_DIMS, ...)                               \
+  switch (NUM_DIMS) {                                                      \
+    case 2: {                                                              \
+      constexpr int tensor_rank = 2;                                       \
+      __VA_ARGS__();                                                       \
+      break;                                                               \
+    }                                                                      \
+    case 3: {                                                              \
+      constexpr int tensor_rank = 3;                                       \
+      __VA_ARGS__();                                                       \
+      break;                                                               \
+    }                                                                      \
+    case 4: {                                                              \
+      constexpr int tensor_rank = 4;                                       \
+      __VA_ARGS__();                                                       \
+      break;                                                               \
+    }                                                                      \
+    default:                                                               \
+      TORCH_CHECK(false, "Expects rank 2, 3 or 4 tensors but got ",        \
+                  NUM_DIMS);                                               \
+  }
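The macro above takes the runtime rank (NUM_DIMS) and re-exposes it inside each switch arm as the compile-time constant tensor_rank, so the lambda passed to it can use the rank as a template argument. Below is a minimal standalone sketch of the same idiom, with hypothetical names (launch_for_rank, dispatch_rank234) and a plain exception standing in for TORCH_CHECK; it is an illustration of the pattern, not vLLM code.

// Sketch: a runtime integer selects a branch in which the rank is constexpr,
// so it can parameterize a template (in the real launcher, the CUDA kernel).
#include <cstdio>
#include <stdexcept>

template <int Rank>
void launch_for_rank() {
  // In the real launcher this would be
  // rms_norm_kernel<scalar_t, vec_size, Rank><<<grid, block, 0, stream>>>(...).
  std::printf("launched rank-%d specialization\n", Rank);
}

void dispatch_rank234(int num_dims) {
  switch (num_dims) {
    case 2: { constexpr int tensor_rank = 2; launch_for_rank<tensor_rank>(); break; }
    case 3: { constexpr int tensor_rank = 3; launch_for_rank<tensor_rank>(); break; }
    case 4: { constexpr int tensor_rank = 4; launch_for_rank<tensor_rank>(); break; }
    default: throw std::invalid_argument("Expects rank 2, 3 or 4 tensors");
  }
}

int main() {
  dispatch_rank234(3);  // prints "launched rank-3 specialization"
  return 0;
}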
@@ -10,16 +10,38 @@
 namespace vllm {
 
 // TODO(woosuk): Further optimize this kernel.
-template <typename scalar_t, int VEC_SIZE>
+template <typename scalar_t, int VEC_SIZE, int NUM_DIMS>
 __global__ void rms_norm_kernel(
-    scalar_t* __restrict__ out,          // [..., hidden_size]
-    const scalar_t* __restrict__ input,  // [..., hidden_size]
-    const int64_t input_stride,
+    scalar_t* __restrict__ out,          // [..., hidden_size]
+    const scalar_t* __restrict__ input,  // [..., hidden_size]
+    const int64_t input_stride_d2,       // input.stride(-2)
+    const int64_t input_stride_d3,       // input.stride(-3)
+    const int64_t input_stride_d4,       // input.stride(-4)
+    const int64_t input_shape_d2,        // input.size(-2)
+    const int64_t input_shape_d3,        // input.size(-3)
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float epsilon, const int num_tokens, const int hidden_size) {
   __shared__ float s_variance;
   float variance = 0.0f;
-  const scalar_t* input_row = input + blockIdx.x * input_stride;
+  const scalar_t* input_row;
+  if constexpr (NUM_DIMS == 2) {
+    // 2D for layernorm normal case [batch_size, hidden]
+    input_row = input + blockIdx.x * input_stride_d2;
+  } else if constexpr (NUM_DIMS == 3) {
+    // 3D for q/k norm [batch_size, num_heads, head_size]
+    int batch_idx = blockIdx.x / input_shape_d2;
+    int head_idx = blockIdx.x % input_shape_d2;
+    input_row =
+        input + batch_idx * input_stride_d3 + head_idx * input_stride_d2;
+  } else if constexpr (NUM_DIMS == 4) {
+    // 4D for transformers model_impl qk norm [batch, seq, head, head_dim]
+    int batch_idx = blockIdx.x / (input_shape_d3 * input_shape_d2);
+    int remaining = blockIdx.x % (input_shape_d3 * input_shape_d2);
+    int seq_idx = remaining / input_shape_d2;
+    int head_idx = remaining % input_shape_d2;
+    input_row = input + batch_idx * input_stride_d4 +
+                seq_idx * input_stride_d3 + head_idx * input_stride_d2;
+  }
 
   auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
 #pragma unroll
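The kernel launches one block per normalized row; blockIdx.x is a flattened row index that the branches above decompose back into (batch, head) or (batch, seq, head) coordinates and then combine with the raw element strides. This is what allows non-contiguous inputs, provided the last dimension has stride 1. Below is a host-side sketch of the same rank-4 index math, using a hypothetical helper (row_offset_rank4) and made-up strides, checked against a naive nested loop; it mirrors the arithmetic above rather than reproducing vLLM code.

// Decompose a flat row index the way blockIdx.x is decomposed for NUM_DIMS == 4,
// then apply the element strides to locate the first element of that row.
#include <cassert>
#include <cstdint>

int64_t row_offset_rank4(int64_t row,
                         int64_t shape_d3 /* seq */, int64_t shape_d2 /* heads */,
                         int64_t stride_d4, int64_t stride_d3, int64_t stride_d2) {
  const int64_t batch_idx = row / (shape_d3 * shape_d2);
  const int64_t remaining = row % (shape_d3 * shape_d2);
  const int64_t seq_idx = remaining / shape_d2;
  const int64_t head_idx = remaining % shape_d2;
  return batch_idx * stride_d4 + seq_idx * stride_d3 + head_idx * stride_d2;
}

int main() {
  // Example: [batch=2, seq=3, head=4, head_dim=8] carved out of a wider buffer
  // (e.g. q heads sliced from a fused qkv projection), so the outer strides are
  // larger than the contiguous ones would be.
  const int64_t stride_d2 = 8, stride_d3 = 96, stride_d4 = 288;
  int64_t row = 0;
  for (int64_t b = 0; b < 2; ++b)
    for (int64_t s = 0; s < 3; ++s)
      for (int64_t h = 0; h < 4; ++h, ++row)
        assert(row_offset_rank4(row, /*shape_d3=*/3, /*shape_d2=*/4,
                                stride_d4, stride_d3, stride_d2) ==
               b * stride_d4 + s * stride_d3 + h * stride_d2);
  return 0;
}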
@@ -164,38 +186,44 @@ void rms_norm(torch::Tensor& out,  // [..., hidden_size]
               torch::Tensor& weight,  // [hidden_size]
               double epsilon) {
   TORCH_CHECK(out.is_contiguous());
+  if (input.stride(-1) != 1) {
+    input = input.contiguous();
+  }
   TORCH_CHECK(input.stride(-1) == 1);
   TORCH_CHECK(weight.is_contiguous());
 
   int hidden_size = input.size(-1);
 
-  // We cannot just use `input.stride(-2)` if the tensor is not row-major.
-  // Instead, we use a 2d view to get the second-innermost stride.
-  // That way the dimensions (except the last one) can be arbitrarily permuted.
-  torch::Tensor input_view = input.view({-1, hidden_size});
-
-  int num_tokens = input_view.numel() / hidden_size;
-  int64_t input_stride = input_view.stride(-2);
+  int num_tokens = input.numel() / hidden_size;
+  int num_dims = input.dim();
+  int64_t input_stride_d2 = input.stride(-2);
+  int64_t input_stride_d3 = (num_dims >= 3) ? input.stride(-3) : 0;
+  int64_t input_stride_d4 = (num_dims >= 4) ? input.stride(-4) : 0;
+  int64_t input_shape_d2 = (num_dims >= 3) ? input.size(-2) : 0;
+  int64_t input_shape_d3 = (num_dims >= 4) ? input.size(-3) : 0;
 
   // For large num_tokens, use smaller blocks to increase SM concurrency.
   const int max_block_size = (num_tokens < 256) ? 1024 : 256;
   dim3 grid(num_tokens);
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(input_view));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(
-      input_view.scalar_type(), "rms_norm_kernel", [&] {
-        const int calculated_vec_size =
-            std::gcd(16 / sizeof(scalar_t), hidden_size);
-        const int block_size =
-            std::min(hidden_size / calculated_vec_size, max_block_size);
-        dim3 block(block_size);
-        VLLM_DISPATCH_VEC_SIZE(calculated_vec_size, [&] {
-          vllm::rms_norm_kernel<scalar_t, vec_size><<<grid, block, 0, stream>>>(
-              out.data_ptr<scalar_t>(), input_view.data_ptr<scalar_t>(),
-              input_stride, weight.data_ptr<scalar_t>(), epsilon, num_tokens,
-              hidden_size);
-        });
-      });
+  VLLM_DISPATCH_RANK234(num_dims, [&] {
+    VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
+      const int calculated_vec_size =
+          std::gcd(16 / sizeof(scalar_t), hidden_size);
+      const int block_size =
+          std::min(hidden_size / calculated_vec_size, max_block_size);
+      dim3 block(block_size);
+      VLLM_DISPATCH_VEC_SIZE(calculated_vec_size, [&] {
+        vllm::rms_norm_kernel<scalar_t, vec_size, tensor_rank>
+            <<<grid, block, 0, stream>>>(
+                out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
+                input_stride_d2, input_stride_d3, input_stride_d4,
+                input_shape_d2, input_shape_d3, weight.data_ptr<scalar_t>(),
+                epsilon, num_tokens, hidden_size);
+      });
+    });
+  });
 }
 
 #define LAUNCH_FUSED_ADD_RMS_NORM(width)                                \
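For context, the case this launcher now serves without an extra copy is a q/k view sliced out of a fused QKV projection: the last dimension stays unit-stride, but the outer stride is larger than the contiguous one, so the old input.view({-1, hidden_size}) path would have required a copy. A hedged libtorch illustration (CPU tensor, made-up sizes; not part of this change) of the values the rank-3 path would pick up:

// Build a non-contiguous [num_tokens, num_heads, head_size] Q view from a
// fused QKV buffer and print the strides/shape the launcher would read.
#include <torch/torch.h>
#include <iostream>

int main() {
  const int64_t num_tokens = 8, num_heads = 4, head_size = 16;
  auto qkv = torch::randn({num_tokens, 3 * num_heads * head_size});
  auto q = qkv.narrow(/*dim=*/1, /*start=*/0, /*length=*/num_heads * head_size)
               .view({num_tokens, num_heads, head_size});

  std::cout << "contiguous: " << q.is_contiguous() << "\n";  // 0 (false)
  // Values a rank-3 launch would pass to the kernel:
  std::cout << "stride(-3)=" << q.stride(-3)   // 192, not the contiguous 64
            << " stride(-2)=" << q.stride(-2)  // 16
            << " stride(-1)=" << q.stride(-1)  // 1
            << " size(-2)=" << q.size(-2) << "\n";  // 4
  // The launcher's num_tokens (= numel / hidden_size) is 8 * 4 = 32:
  // one CUDA block per (token, head) row.
  return 0;
}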
@@ -328,10 +328,7 @@ def rotary_embedding(
 def rms_norm(
     out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, epsilon: float
 ) -> None:
-    # TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input
-    # If removed, also need to remove contiguous in MatcherRMSNorm
-    input_contiguous = input.contiguous()
-    torch.ops._C.rms_norm(out, input_contiguous, weight, epsilon)
+    torch.ops._C.rms_norm(out, input, weight, epsilon)
 
 
 def fused_add_rms_norm(
@@ -162,12 +162,10 @@ class MatcherRMSNorm(MatcherCustomOp):
         weight: torch.Tensor,
     ) -> torch.Tensor:
         result = torch.empty_like(input)
-        # TODO: support non-contiguous input for RMSNorm and remove this
-        input_contiguous = input.contiguous()
         _, result = auto_functionalized(
             RMS_OP,
             result=result,
-            input=input_contiguous,
+            input=input,
             weight=weight,
             epsilon=self.epsilon,
         )