[ROCm] [Bugfix] Fix fused_qknorm_rope_kernel rocm compatibility (#28500)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
TJian
2025-11-12 05:01:14 -08:00
committed by GitHub
parent c5f10cc139
commit edb59a9470
6 changed files with 37 additions and 38 deletions

View File

@@ -35,10 +35,12 @@
CHECK_TH_CUDA(x); \
CHECK_CONTIGUOUS(x)
#define FINAL_MASK 0xffffffff
#ifdef USE_ROCM
#define FINAL_MASK 0xffffffffffffffffULL
#else
#define FINAL_MASK 0xffffffff
#endif
// TODO: suport for AMD ROCM platform
#ifndef USE_ROCM
namespace tensorrt_llm::common {
template <typename T, int num>
struct packed_as;
@@ -60,7 +62,7 @@ struct packed_as<uint, 4> {
template <typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
return val;
@@ -97,12 +99,12 @@ __global__ void fusedQKNormRopeKernel(
int64_t const* position_ids, // Position IDs for RoPE
int const num_tokens // Number of tokens
) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
if constexpr ((std::is_same_v<scalar_t_in, c10::BFloat16>) ||
std::is_same_v<scalar_t_cache, c10::BFloat16>) {
return;
} else {
#endif
#endif
using Converter = vllm::_typeConvert<scalar_t_in>;
static_assert(Converter::exists,
@@ -179,7 +181,7 @@ __global__ void fusedQKNormRopeKernel(
{
vec_T vec = *reinterpret_cast<vec_T const*>(&qkv[offsetThread]);
constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
#pragma unroll
#pragma unroll
for (int i = 0; i < num_packed_elems; i++) {
// Interpret the generic vector chunk as the specific packed type
T2_in packed_val = *(reinterpret_cast<T2_in*>(&vec) + i);
@@ -200,7 +202,7 @@ __global__ void fusedQKNormRopeKernel(
float rms_rcp = rsqrtf(sumOfSquares / static_cast<float>(head_dim) + eps);
// Normalize elements
#pragma unroll
#pragma unroll
for (int i = 0; i < numElemsPerThread; i++) {
int dim = laneId * numElemsPerThread + i;
float weight = isQ ? Converter::convert(q_weight[dim])
@@ -222,7 +224,7 @@ __global__ void fusedQKNormRopeKernel(
if constexpr (interleave) {
// Perform interleaving. Use pre-computed cos/sin values.
#pragma unroll
#pragma unroll
for (int i = 0; i < numElemsPerThread / 2; ++i) {
int const idx0 = 2 * i;
int const idx1 = 2 * i + 1;
@@ -245,9 +247,9 @@ __global__ void fusedQKNormRopeKernel(
__syncwarp();
// Get the data from the other half of the warp. Use pre-computed cos/sin
// values.
#pragma unroll
#pragma unroll
for (int i = 0; i < numElemsPerThread; i++) {
elements2[i] = __shfl_xor_sync(0xffffffff, elements[i], 16);
elements2[i] = __shfl_xor_sync(FINAL_MASK, elements[i], 16);
if (laneId < 16) {
elements2[i] = -elements2[i];
}
@@ -269,7 +271,7 @@ __global__ void fusedQKNormRopeKernel(
{
vec_T vec;
constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
#pragma unroll
#pragma unroll
for (int i = 0; i < num_packed_elems; i++) {
// Convert from float2 back to the specific packed type
T2_in packed_val = Converter::convert(
@@ -280,21 +282,21 @@ __global__ void fusedQKNormRopeKernel(
*reinterpret_cast<vec_T*>(&qkv[offsetThread]) = vec;
}
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
}
#endif
#endif
}
// Borrowed from
// https://github.com/flashinfer-ai/flashinfer/blob/8125d079a43e9a0ba463a4ed1b639cefd084cec9/include/flashinfer/pos_enc.cuh#L568
#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
if (interleave) { \
const bool INTERLEAVE = true; \
__VA_ARGS__ \
} else { \
const bool INTERLEAVE = false; \
__VA_ARGS__ \
}
// Borrowed from
// https://github.com/flashinfer-ai/flashinfer/blob/8125d079a43e9a0ba463a4ed1b639cefd084cec9/include/flashinfer/pos_enc.cuh#L568
#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
if (interleave) { \
const bool INTERLEAVE = true; \
__VA_ARGS__ \
} else { \
const bool INTERLEAVE = false; \
__VA_ARGS__ \
}
template <typename scalar_t_in, typename scalar_t_cache>
void launchFusedQKNormRope(void* qkv, int const num_tokens,
@@ -413,6 +415,4 @@ void fused_qk_norm_rope(
stream);
});
});
}
#endif // not USE_ROCM
}

View File

@@ -175,7 +175,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"float epsilon) -> ()");
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
#ifndef USE_ROCM
// Function for fused QK Norm and RoPE
ops.def(
"fused_qk_norm_rope(Tensor! qkv, int num_heads_q, "
@@ -183,7 +182,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor q_weight, Tensor k_weight, Tensor cos_sin_cache, "
"bool is_neox, Tensor position_ids) -> ()");
ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope);
#endif
// Apply repetition penalties to logits in-place
ops.def(

View File

@@ -67,9 +67,9 @@ struct _typeConvert<c10::Half> {
}
};
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800) || defined(USE_ROCM)
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
// ROCm 7.0+ supports bfloat16
template <>
struct _typeConvert<c10::BFloat16> {
static constexpr bool exists = true;
@@ -89,7 +89,8 @@ struct _typeConvert<c10::BFloat16> {
return __float22bfloat162_rn(x);
}
};
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif // (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800) ||
// defined(USE_ROCM)
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
// 12000))

View File

@@ -113,8 +113,8 @@ class QKNormRoPETestModel(torch.nn.Module):
@pytest.mark.parametrize("enable_rope_custom_op", [True])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.skipif(
not current_platform.is_cuda(),
reason="Only test on cuda platform",
not current_platform.is_cuda_alike(),
reason="Only test on cuda and rocm platform",
)
def test_qk_norm_rope_fusion(
eps, is_neox, enable_rms_norm_custom_op, enable_rope_custom_op, dtype

View File

@@ -44,8 +44,8 @@ def _apply_qk_norm_rope(
@pytest.mark.skipif(
not current_platform.is_cuda(),
reason="fused_qk_norm_rope custom op requires cuda platform",
not current_platform.is_cuda_alike(),
reason="fused_qk_norm_rope custom op requires cuda and rocm platform",
)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("dtype", DTYPES)

View File

@@ -184,10 +184,10 @@ class PassConfig:
"Fusion enabled but reshape elimination disabled. "
"Allreduce + rms norm + quant (fp8) fusion might not work"
)
if self.enable_qk_norm_rope_fusion and not current_platform.is_cuda():
if self.enable_qk_norm_rope_fusion and not current_platform.is_cuda_alike():
logger.warning_once(
"QK Norm + RoPE fusion enabled but the current platform is not "
"CUDA. The fusion will be disabled."
"CUDA or ROCm. The fusion will be disabled."
)
self.enable_qk_norm_rope_fusion = False