[Kernel] Add cuda kernel for gpt_oss activation (#22951)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Author: Jee Jee Li
Date: 2025-08-17 13:03:24 +08:00
Committed by: GitHub
Parent: 87f48623a5
Commit: 4d4061b6e7
9 changed files with 157 additions and 42 deletions

View File

@@ -128,6 +128,45 @@ __global__ void act_and_mul_kernel_with_param(
}
}
template <typename T>
__device__ __forceinline__ T swigluoai_and_mul(const T& gate, const T& up,
float alpha, float limit) {
// clamp gate: min=None, max=limit
const float gate_f = (float)gate;
const float clamped_gate = gate_f > limit ? limit : gate_f;
// clamp up: min=-limit, max=limit
const float up_f = (float)up;
const float clamped_up =
up_f > limit ? limit : (up_f < -limit ? -limit : up_f);
// glu = gate * sigmoid(gate * alpha)
const float sigmoid_val = 1.0f / (1.0f + expf(-clamped_gate * alpha));
const float glu = clamped_gate * sigmoid_val;
// (up + 1) * glu
return (T)((clamped_up + 1.0f) * glu);
}
template <typename scalar_t,
scalar_t (*ACT_FN)(const scalar_t&, const scalar_t&, const float,
const float)>
__global__ void swigluoai_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., d, 2] (gate/up interleaved)
const int d, const float alpha, const float limit) {
const int64_t token_idx = blockIdx.x;
// TODO: Vectorize loads and stores.
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
// gate = x[..., ::2] (even indices)
const scalar_t gate = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx]);
// up = x[..., 1::2] (odd indices)
const scalar_t up = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx + 1]);
out[token_idx * d + idx] = ACT_FN(gate, up, alpha, limit);
}
}
} // namespace vllm
#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \
@@ -145,11 +184,31 @@ __global__ void act_and_mul_kernel_with_param(
PARAM); \
});
#define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "clamp_swiglu_kernel_with_params", [&] { \
vllm::swigluoai_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d, ALPHA, \
LIMIT); \
});
void fatrelu_and_mul(torch::Tensor& out, // [..., d],
torch::Tensor& input, // [..., 2 * d]
double threshold) {
LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold);
}
void swigluoai_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., 2 * d]
double alpha, double limit) {
LAUNCH_SIGLUOAI_AND_MUL(vllm::swigluoai_and_mul, alpha, limit);
}
namespace vllm {
// Element-wise activation kernel template.

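For readers unfamiliar with the interleaved layout the new kernel assumes, here is a minimal PyTorch sketch of the same computation (illustrative only, not part of the commit; the function name and shapes are made up). It mirrors the device code above: gate sits at even positions of the last dimension, up at odd positions, gate is clamped to at most limit, up to [-limit, limit], and the output is (up + 1) * gate * sigmoid(alpha * gate).

    import torch

    def swigluoai_reference(x: torch.Tensor,
                            alpha: float = 1.702,
                            limit: float = 7.0) -> torch.Tensor:
        # x: [..., 2 * d] with gate/up interleaved along the last dimension.
        gate = x[..., 0::2]                       # even indices, as read by the kernel
        up = x[..., 1::2]                         # odd indices
        gate = gate.clamp(max=limit)              # clamp gate: min=None, max=limit
        up = up.clamp(min=-limit, max=limit)      # clamp up to [-limit, limit]
        glu = gate * torch.sigmoid(gate * alpha)  # gate * sigmoid(alpha * gate)
        return (up + 1.0) * glu                   # [..., d]
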
View File

@@ -138,6 +138,8 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
void fatrelu_and_mul(torch::Tensor& out, torch::Tensor& input,
double threshold);
void swigluoai_and_mul(torch::Tensor& out, torch::Tensor& input,
double alpha = 1.702, double limit = 7.0);
void gelu_new(torch::Tensor& out, torch::Tensor& input);

View File

@@ -130,6 +130,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()");
ops.impl("fatrelu_and_mul", torch::kCUDA, &fatrelu_and_mul);
ops.def(
"swigluoai_and_mul(Tensor! out, Tensor input, float alpha=1.702, float "
"limit=7.0) "
"-> ()");
ops.impl("swigluoai_and_mul", torch::kCUDA, &swigluoai_and_mul);
// GELU implementation used in GPT-2.
ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_new", torch::kCUDA, &gelu_new);
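Once the _C extension is built, the swigluoai_and_mul op registered above can be invoked directly through torch.ops; the schema defaults (alpha=1.702, limit=7.0) apply when the trailing arguments are omitted. A minimal sketch, assuming a CUDA device and hypothetical shapes:

    import torch

    num_tokens, d = 8, 128
    x = torch.randn(num_tokens, 2 * d, dtype=torch.float16, device="cuda")  # [..., 2 * d]
    out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)        # [..., d]
    torch.ops._C.swigluoai_and_mul(out, x)              # uses alpha=1.702, limit=7.0
    torch.ops._C.swigluoai_and_mul(out, x, 1.702, 7.0)  # explicit values, same result
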

View File

@@ -11,7 +11,7 @@ from tests.kernels.utils import opcheck
from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul,
GeluAndMul, MulAndSilu,
NewGELU, QuickGELU,
SiluAndMul)
SiluAndMul, SwigluOAIAndMul)
from vllm.platforms import current_platform
DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -25,7 +25,15 @@ CUDA_DEVICES = [
@pytest.mark.parametrize(
"activation",
["silu_and_mul", "mul_and_silu", "gelu", "gelu_tanh", "fatrelu"])
[
"silu_and_mul",
"mul_and_silu",
"gelu",
"gelu_tanh",
"fatrelu",
"swigluoai_and_mul",
],
)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@@ -59,18 +67,43 @@ def test_act_and_mul(
threshold = random.uniform(0, 1)
layer = FatreluAndMul(threshold)
fn = torch.ops._C.fatrelu_and_mul
elif activation == "swigluoai_and_mul":
layer = SwigluOAIAndMul()
fn = torch.ops._C.swigluoai_and_mul
out = layer(x)
ref_out = layer.forward_native(x)
# The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are
# equivalent to the native PyTorch implementations, so we can do exact
# comparison.
torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
if activation == "swigluoai_and_mul":
rtol = {
# For fp16, change the relative tolerance from 1e-3 to 2e-3
torch.float16: 2e-3,
torch.bfloat16: 2e-2,
torch.float: 1.3e-6,
}
def _get_rtol(output) -> float:
return rtol[output.dtype]
torch.testing.assert_close(out,
ref_out,
atol=get_default_atol(out),
rtol=_get_rtol(out))
else:
# The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are
# equivalent to the native PyTorch implementations, so we can do exact
# comparison.
torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
if activation == "fatrelu":
opcheck(fn, (out, x, threshold))
elif activation == "swigluoai_and_mul":
opcheck(fn, (out, x, layer.alpha, layer.limit))
else:
opcheck(fn, (out, x))

View File

@@ -239,6 +239,35 @@ class GeluAndMul(CustomOp):
return f'approximate={repr(self.approximate)}'
@CustomOp.register("swigluoai_and_mul")
class SwigluOAIAndMul(CustomOp):
# https://github.com/huggingface/transformers/blob/v4.55.0/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L106-L110
def __init__(self, alpha: float = 1.702, limit: float = 7.0):
super().__init__()
self.alpha = alpha
self.limit = limit
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
gate, up = x[..., ::2], x[..., 1::2]
gate = gate.clamp(min=None, max=self.limit)
up = up.clamp(min=-self.limit, max=self.limit)
glu = gate * torch.sigmoid(gate * self.alpha)
gated_output = (up + 1) * glu
return gated_output
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
torch.ops._C.swigluoai_and_mul(out, x, self.alpha, self.limit)
return out
def extra_repr(self) -> str:
return f"alpha={repr(self.alpha)}, limit={repr(self.limit)}"
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
@@ -330,6 +359,7 @@ class ReLUSquaredActivation(CustomOp):
return torch.square(F.relu(x))
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
# TODO: implement CUDA kernels
return self.forward_native(x)
@@ -406,9 +436,14 @@ def get_act_fn(act_fn_name: str) -> nn.Module:
_ACTIVATION_AND_MUL_REGISTRY = LazyDict({
"gelu": lambda: GeluAndMul(),
"silu": lambda: SiluAndMul(),
"geglu": lambda: GeluAndMul(),
"swigluoai": lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs),
})
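As a usage sketch (illustrative, not part of the commit): with a CUDA build of vLLM the layer's forward dispatches to the new swigluoai_and_mul kernel, while forward_native is the slicing-based PyTorch reference the tests compare against. Shapes and tolerances below are made up for illustration.

    import torch
    from vllm.model_executor.layers.activation import SwigluOAIAndMul

    layer = SwigluOAIAndMul()  # alpha=1.702, limit=7.0 by default
    x = torch.randn(4, 2 * 256, dtype=torch.bfloat16, device="cuda")  # [..., 2 * d]
    out = layer(x)                 # CUDA path: torch.ops._C.swigluoai_and_mul
    ref = layer.forward_native(x)  # pure-PyTorch reference
    torch.testing.assert_close(out, ref, atol=1e-2, rtol=2e-2)  # illustrative tolerances
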

View File

@@ -161,25 +161,13 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
if activation == "silu":
torch.ops._C.silu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, 2 * N))
elif activation == "swiglu_oai":
# NOTE: in gpt-oss, the gate_proj and up_proj are interleaved
# - interleaved: gate, up = gate_up[..., ::2], gate_up[..., 1::2]
# - origin: gate, up = gate_up[..., :N], gate_up[..., N:]
@torch.compile(dynamic=True)
def swiglu_oai(gate_up):
alpha = 1.702
limit = 7.0
gate, up = gate_up[..., ::2], gate_up[..., 1::2]
gate = gate.clamp(min=None, max=limit)
up = up.clamp(min=-limit, max=limit)
glu = gate * torch.sigmoid(gate * alpha)
return (up + 1) * glu
intermediate_cache2 = swiglu_oai(intermediate_cache1)
elif activation == "swigluoai":
# alpha = 1.702, limit = 7.0
torch.ops._C.swigluoai_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, 2 * N))
else:
raise ValueError(f"Unsupported activation: {activation}. "
"Only silu and swiglu_oai activations are supported.")
"Only silu and swigluoai activations are supported.")
if expert_map is not None:
intermediate_cache3.zero_()
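The NOTE in the removed code is worth keeping in mind: the kernel expects gate and up interleaved along the last dimension, not concatenated. A hedged sketch (hypothetical tensors, assuming the _C extension is built and a CUDA device is available) of how a split gate/up pair maps onto the layout swigluoai_and_mul reads:

    import torch

    num_tokens, N = 8, 256
    # Hypothetical non-interleaved projections: gate first, then up.
    gate = torch.randn(num_tokens, N, dtype=torch.half, device="cuda")
    up = torch.randn(num_tokens, N, dtype=torch.half, device="cuda")

    # Interleave so that x[..., ::2] == gate and x[..., 1::2] == up,
    # which is what the swigluoai_and_mul kernel reads.
    x = torch.stack((gate, up), dim=-1).reshape(num_tokens, 2 * N)

    out = torch.empty(num_tokens, N, dtype=x.dtype, device=x.device)
    torch.ops._C.swigluoai_and_mul(out, x)  # alpha=1.702, limit=7.0 defaults
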

View File

@@ -1621,17 +1621,6 @@ def fused_experts_impl(
block_shape=block_shape,
B_bias=w1_bias)
# TODO fused kernel
def swiglu_oai(gate_up):
alpha = 1.702
limit = 7.0
gate, up = gate_up[..., ::2], gate_up[..., 1::2]
gate = gate.clamp(min=None, max=limit)
up = up.clamp(min=-limit, max=limit)
glu = gate * torch.sigmoid(gate * alpha)
gated_output = (up + 1) * glu
return gated_output
# Activation function with multiplication
if activation == "silu" and is_act_and_mul:
torch.ops._C.silu_and_mul(intermediate_cache2,
@@ -1639,13 +1628,16 @@ def fused_experts_impl(
elif activation == "gelu" and is_act_and_mul:
torch.ops._C.gelu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, N))
elif activation == "swigluoai" and is_act_and_mul:
# alpha = 1.702, limit = 7.0
torch.ops._C.swigluoai_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, N))
# Activation function without multiplication
elif activation == "silu":
intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
elif activation == "gelu":
intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))
elif activation == "swiglu_oai":
intermediate_cache2 = swiglu_oai(intermediate_cache1.view(-1, N))
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}, "
f"with is_act_and_mul={is_act_and_mul}.")

View File

@@ -61,14 +61,14 @@ def _can_support_mxfp4(use_grouped_topk: bool = False,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
scoring_func: str = "softmax",
activation: str = "swiglu_oai",
activation: str = "swigluoai",
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None):
return not (use_grouped_topk or topk_group or num_expert_group
or expert_map or custom_routing_function
or e_score_correction_bias or apply_router_weight_on_input
or scoring_func != "softmax" or activation != "swiglu_oai"
or scoring_func != "softmax" or activation != "swigluoai"
or expert_load_view or logical_to_physical_map
or logical_replica_count)

View File

@@ -159,7 +159,7 @@ class MLPBlock(torch.nn.Module):
prefix=f"{prefix}.experts",
apply_router_weight_on_input=False,
has_bias=True,
activation="swiglu_oai")
activation="swigluoai")
def forward(self, x: torch.Tensor) -> torch.Tensor:
t = self.norm(x)