[ROCm] [Bugfix] Fix fused_qknorm_rope_kernel rocm compatibility (#28500)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
@@ -35,10 +35,12 @@
   CHECK_TH_CUDA(x); \
   CHECK_CONTIGUOUS(x)
 
-#define FINAL_MASK 0xffffffff
+#ifdef USE_ROCM
+#define FINAL_MASK 0xffffffffffffffffULL
+#else
+#define FINAL_MASK 0xffffffff
+#endif
 
-// TODO: suport for AMD ROCM platform
-#ifndef USE_ROCM
 namespace tensorrt_llm::common {
 template <typename T, int num>
 struct packed_as;
@@ -60,7 +62,7 @@ struct packed_as<uint, 4> {
 
 template <typename T>
 __inline__ __device__ T warpReduceSum(T val) {
 #pragma unroll
   for (int mask = 16; mask > 0; mask >>= 1)
     val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
   return val;
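
Why the mask is widened: HIP's lane-participation mask for __shfl_xor_sync is 64-bit (wavefronts can be 64 lanes wide), while a CUDA warp is 32 lanes with a 32-bit mask; the butterfly reduction itself still runs over 32 logical lanes (width = 32). A minimal sketch of the same pattern, using illustrative names rather than the kernel's own:

// Sketch only; FULL_WARP_MASK and butterflySum32 are illustrative names.
#ifdef USE_ROCM
#define FULL_WARP_MASK 0xffffffffffffffffULL  // 64-bit mask on ROCm
#else
#define FULL_WARP_MASK 0xffffffffu  // 32-bit mask for a 32-lane CUDA warp
#endif

template <typename T>
__device__ __forceinline__ T butterflySum32(T val) {
  // XOR-butterfly reduction across 32 logical lanes, matching the
  // width-32 shuffles used by warpReduceSum above.
#pragma unroll
  for (int mask = 16; mask > 0; mask >>= 1) {
    val += __shfl_xor_sync(FULL_WARP_MASK, val, mask, 32);
  }
  return val;  // every lane ends up holding the 32-lane sum
}
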
@@ -97,12 +99,12 @@ __global__ void fusedQKNormRopeKernel(
     int64_t const* position_ids,  // Position IDs for RoPE
     int const num_tokens          // Number of tokens
 ) {
-#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
+#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
   if constexpr ((std::is_same_v<scalar_t_in, c10::BFloat16>) ||
                 std::is_same_v<scalar_t_cache, c10::BFloat16>) {
     return;
   } else {
 #endif
 
   using Converter = vllm::_typeConvert<scalar_t_in>;
   static_assert(Converter::exists,
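
This guard (closed again near the end of the kernel in a later hunk) compiles the body out when bf16 is requested on pre-SM80 CUDA; adding && !defined(USE_ROCM) keeps the body on ROCm, where bf16 is usable. A compressed, hedged sketch of the open/close pattern in one place (the kernel name and body are illustrative, not the real signature):

// Sketch of the guard pattern only; not the kernel's real code.
#include <type_traits>
#include <c10/util/BFloat16.h>  // for c10::BFloat16, assumed available as in the file above

template <typename scalar_t>
__global__ void exampleBf16GuardedKernel(scalar_t* data, int n) {
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
  // Pre-Ampere CUDA has no usable bf16 path, so the body is compiled out.
  if constexpr (std::is_same_v<scalar_t, c10::BFloat16>) {
    return;
  } else {
#endif
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) data[i] = data[i];  // placeholder body
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
  }
#endif
}
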
@@ -179,7 +181,7 @@ __global__ void fusedQKNormRopeKernel(
   {
     vec_T vec = *reinterpret_cast<vec_T const*>(&qkv[offsetThread]);
     constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
 #pragma unroll
     for (int i = 0; i < num_packed_elems; i++) {
       // Interpret the generic vector chunk as the specific packed type
       T2_in packed_val = *(reinterpret_cast<T2_in*>(&vec) + i);
@@ -200,7 +202,7 @@ __global__ void fusedQKNormRopeKernel(
   float rms_rcp = rsqrtf(sumOfSquares / static_cast<float>(head_dim) + eps);
 
   // Normalize elements
 #pragma unroll
   for (int i = 0; i < numElemsPerThread; i++) {
     int dim = laneId * numElemsPerThread + i;
     float weight = isQ ? Converter::convert(q_weight[dim])
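
For reference, the per-head math this hunk implements is plain RMSNorm: each head is scaled by rsqrt(mean(x^2) + eps) and then by the learned q/k norm weight. A scalar sketch (function name and layout are illustrative):

// Scalar sketch of per-head RMSNorm; the kernel does the same thing with the
// sum of squares reduced across a warp and the elements kept in registers.
__device__ inline void rmsNormHead(float* head, float const* weight,
                                   int head_dim, float eps) {
  float sum_sq = 0.f;
  for (int i = 0; i < head_dim; ++i) sum_sq += head[i] * head[i];
  float const rms_rcp = rsqrtf(sum_sq / static_cast<float>(head_dim) + eps);
  for (int i = 0; i < head_dim; ++i) head[i] = head[i] * rms_rcp * weight[i];
}
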
@@ -222,7 +224,7 @@ __global__ void fusedQKNormRopeKernel(
 
   if constexpr (interleave) {
     // Perform interleaving. Use pre-computed cos/sin values.
 #pragma unroll
     for (int i = 0; i < numElemsPerThread / 2; ++i) {
       int const idx0 = 2 * i;
       int const idx1 = 2 * i + 1;
@@ -245,9 +247,9 @@ __global__ void fusedQKNormRopeKernel(
     __syncwarp();
     // Get the data from the other half of the warp. Use pre-computed cos/sin
     // values.
 #pragma unroll
     for (int i = 0; i < numElemsPerThread; i++) {
-      elements2[i] = __shfl_xor_sync(0xffffffff, elements[i], 16);
+      elements2[i] = __shfl_xor_sync(FINAL_MASK, elements[i], 16);
       if (laneId < 16) {
         elements2[i] = -elements2[i];
       }
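
The two branches above are the two usual RoPE layouts: the interleaved form rotates adjacent pairs (2i, 2i+1), while the non-interleaved (NeoX-style) form pairs dimension i with i + head_dim/2; the kernel realizes the second pairing by shuffling across lane offset 16 and negating one half. A scalar sketch of the rotate-half form, under one common sign convention (names are illustrative):

// Scalar sketch of non-interleaved (rotate-half) RoPE for a single head.
// cos_vals/sin_vals hold head_dim/2 precomputed values for this position.
__device__ inline void ropeRotateHalf(float* head, float const* cos_vals,
                                      float const* sin_vals, int head_dim) {
  int const half = head_dim / 2;
  for (int i = 0; i < half; ++i) {
    float const x1 = head[i];
    float const x2 = head[i + half];
    head[i]        = x1 * cos_vals[i] - x2 * sin_vals[i];
    head[i + half] = x2 * cos_vals[i] + x1 * sin_vals[i];
  }
}
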
@@ -269,7 +271,7 @@ __global__ void fusedQKNormRopeKernel(
   {
     vec_T vec;
     constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
 #pragma unroll
     for (int i = 0; i < num_packed_elems; i++) {
       // Convert from float2 back to the specific packed type
       T2_in packed_val = Converter::convert(
@@ -280,21 +282,21 @@ __global__ void fusedQKNormRopeKernel(
     *reinterpret_cast<vec_T*>(&qkv[offsetThread]) = vec;
   }
 
-#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
+#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
   }
 #endif
 }
 
 // Borrowed from
 // https://github.com/flashinfer-ai/flashinfer/blob/8125d079a43e9a0ba463a4ed1b639cefd084cec9/include/flashinfer/pos_enc.cuh#L568
 #define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
   if (interleave) {                                      \
     const bool INTERLEAVE = true;                        \
     __VA_ARGS__                                          \
   } else {                                               \
     const bool INTERLEAVE = false;                       \
     __VA_ARGS__                                          \
   }
 
 template <typename scalar_t_in, typename scalar_t_cache>
 void launchFusedQKNormRope(void* qkv, int const num_tokens,
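
DISPATCH_INTERLEAVE promotes a runtime bool into a compile-time constant so the kernel can be instantiated for both layouts. A hedged usage sketch; the kernel, launch shape, and argument names below are illustrative, not launchFusedQKNormRope's real ones:

// Illustrative dispatch only; shows how the macro is meant to be consumed.
template <typename T, bool INTERLEAVE>
__global__ void ropeExampleKernel(T* qkv, int num_tokens) {
  // A real kernel would branch on INTERLEAVE at compile time (if constexpr).
}

template <typename T>
void launchRopeExample(T* qkv, int num_tokens, bool interleave,
                       cudaStream_t stream) {
  DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
    ropeExampleKernel<T, INTERLEAVE>
        <<<num_tokens, 32, 0, stream>>>(qkv, num_tokens);
  });
}
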
@@ -413,6 +415,4 @@ void fused_qk_norm_rope(
       stream);
     });
   });
-}
-
-#endif  // not USE_ROCM
+}

@@ -175,7 +175,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "float epsilon) -> ()");
   ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
 
-#ifndef USE_ROCM
   // Function for fused QK Norm and RoPE
   ops.def(
       "fused_qk_norm_rope(Tensor! qkv, int num_heads_q, "
@@ -183,7 +182,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "Tensor q_weight, Tensor k_weight, Tensor cos_sin_cache, "
       "bool is_neox, Tensor position_ids) -> ()");
   ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope);
-#endif
 
   // Apply repetition penalties to logits in-place
   ops.def(

@@ -67,9 +67,9 @@ struct _typeConvert<c10::Half> {
   }
 };
 
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800) || defined(USE_ROCM)
 // CUDA_ARCH < 800 does not have BF16 support
-// TODO: Add in ROCm support once public headers handle bf16 maturely
+// ROCm 7.0+ supports bfloat16
 template <>
 struct _typeConvert<c10::BFloat16> {
   static constexpr bool exists = true;
@@ -89,7 +89,8 @@ struct _typeConvert<c10::BFloat16> {
     return __float22bfloat162_rn(x);
   }
 };
-#endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#endif  // (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800) ||
+        // defined(USE_ROCM)
 #endif  // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
         // 12000))

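The _typeConvert<c10::BFloat16> specialization guarded above supplies packed bf16 <-> float2 conversions; the relaxed guard admits ROCm because, per the comment, ROCm 7.0+ ships usable bf16 device support. A minimal sketch of a trait with that shape (the struct name is illustrative, and on ROCm it assumes the HIP headers resolve the same bf16 types and intrinsics, which is exactly what the guard change relies on):

// Sketch only; assumes the bf16 device header (cuda_bf16.h, or its HIP
// counterpart under USE_ROCM) is already included, as in the real file.
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800) || defined(USE_ROCM)
struct Bf16PackedConvert {
  static constexpr bool exists = true;
  using packed_t = __nv_bfloat162;  // two bf16 values

  __device__ static inline float2 to_float2(packed_t x) {
    return __bfloat1622float2(x);
  }
  __device__ static inline packed_t from_float2(float2 x) {
    return __float22bfloat162_rn(x);
  }
};
#endif
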
@@ -113,8 +113,8 @@ class QKNormRoPETestModel(torch.nn.Module):
 @pytest.mark.parametrize("enable_rope_custom_op", [True])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
 @pytest.mark.skipif(
-    not current_platform.is_cuda(),
-    reason="Only test on cuda platform",
+    not current_platform.is_cuda_alike(),
+    reason="Only test on cuda and rocm platform",
 )
 def test_qk_norm_rope_fusion(
     eps, is_neox, enable_rms_norm_custom_op, enable_rope_custom_op, dtype

@@ -44,8 +44,8 @@ def _apply_qk_norm_rope(
 
 
 @pytest.mark.skipif(
-    not current_platform.is_cuda(),
-    reason="fused_qk_norm_rope custom op requires cuda platform",
+    not current_platform.is_cuda_alike(),
+    reason="fused_qk_norm_rope custom op requires cuda and rocm platform",
 )
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("dtype", DTYPES)

@@ -184,10 +184,10 @@ class PassConfig:
                 "Fusion enabled but reshape elimination disabled. "
                 "Allreduce + rms norm + quant (fp8) fusion might not work"
             )
-        if self.enable_qk_norm_rope_fusion and not current_platform.is_cuda():
+        if self.enable_qk_norm_rope_fusion and not current_platform.is_cuda_alike():
             logger.warning_once(
                 "QK Norm + RoPE fusion enabled but the current platform is not "
-                "CUDA. The fusion will be disabled."
+                "CUDA or ROCm. The fusion will be disabled."
             )
             self.enable_qk_norm_rope_fusion = False