[Kernel] Add cuda kernel for gpt_oss activation (#22951)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
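For context, the activation this commit implements in CUDA is the clamped SwiGLU variant used by gpt-oss. A minimal PyTorch reference, mirroring the forward_native implementation added later in this diff (the function and variable names here are illustrative only):

import torch

def swigluoai_reference(x: torch.Tensor,
                        alpha: float = 1.702,
                        limit: float = 7.0) -> torch.Tensor:
    # gate and up are interleaved along the last dimension: [..., 2 * d] -> [..., d]
    gate, up = x[..., ::2], x[..., 1::2]
    gate = gate.clamp(max=limit)              # clamp gate: min=None, max=limit
    up = up.clamp(min=-limit, max=limit)      # clamp up to [-limit, limit]
    glu = gate * torch.sigmoid(gate * alpha)  # SiLU-style gating scaled by alpha
    return (up + 1) * glu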
@@ -128,6 +128,45 @@ __global__ void act_and_mul_kernel_with_param(
   }
 }

+template <typename T>
+__device__ __forceinline__ T swigluoai_and_mul(const T& gate, const T& up,
+                                               float alpha, float limit) {
+  // clamp gate: min=None, max=limit
+  const float gate_f = (float)gate;
+  const float clamped_gate = gate_f > limit ? limit : gate_f;
+
+  // clamp up: min=-limit, max=limit
+  const float up_f = (float)up;
+  const float clamped_up =
+      up_f > limit ? limit : (up_f < -limit ? -limit : up_f);
+
+  // glu = gate * sigmoid(gate * alpha)
+  const float sigmoid_val = 1.0f / (1.0f + expf(-clamped_gate * alpha));
+  const float glu = clamped_gate * sigmoid_val;
+
+  // (up + 1) * glu
+  return (T)((clamped_up + 1.0f) * glu);
+}
+
+template <typename scalar_t,
+          scalar_t (*ACT_FN)(const scalar_t&, const scalar_t&, const float,
+                             const float)>
+__global__ void swigluoai_and_mul_kernel(
+    scalar_t* __restrict__ out,          // [..., d]
+    const scalar_t* __restrict__ input,  // [..., 2, d]
+    const int d, const float alpha, const float limit) {
+  const int64_t token_idx = blockIdx.x;
+  // TODO: Vectorize loads and stores.
+  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+    // gate = x[..., ::2] (even indices)
+    const scalar_t gate = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx]);
+    // up = x[..., 1::2] (odd indices)
+    const scalar_t up = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx + 1]);
+
+    out[token_idx * d + idx] = ACT_FN(gate, up, alpha, limit);
+  }
+}
+
 }  // namespace vllm

 #define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \
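The kernel above assigns one thread block per token and reads the interleaved layout directly: output element idx comes from input offsets token_idx * 2 * d + 2 * idx (gate) and token_idx * 2 * d + 2 * idx + 1 (up). A small illustrative check of that index mapping in plain PyTorch (values are hypothetical):

import torch

d = 4
row = torch.arange(2 * d)          # one token's interleaved gate/up values
gate, up = row[::2], row[1::2]     # what the Python-side slicing produces
for idx in range(d):
    assert gate[idx] == row[2 * idx]      # matches the kernel's gate load
    assert up[idx] == row[2 * idx + 1]    # matches the kernel's up load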
@@ -145,11 +184,31 @@ __global__ void act_and_mul_kernel_with_param(
           PARAM); \
       });

+#define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT) \
+  int d = input.size(-1) / 2; \
+  int64_t num_tokens = input.numel() / input.size(-1); \
+  dim3 grid(num_tokens); \
+  dim3 block(std::min(d, 1024)); \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
+  VLLM_DISPATCH_FLOATING_TYPES( \
+      input.scalar_type(), "clamp_swiglu_kernel_with_params", [&] { \
+        vllm::swigluoai_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
+            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
+                                         input.data_ptr<scalar_t>(), d, ALPHA, \
+                                         LIMIT); \
+      });
+
 void fatrelu_and_mul(torch::Tensor& out,    // [..., d],
                      torch::Tensor& input,  // [..., 2 * d]
                      double threshold) {
   LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold);
 }
+void swigluoai_and_mul(torch::Tensor& out,    // [..., d]
+                       torch::Tensor& input,  // [..., 2 * d]
+                       double alpha, double limit) {
+  LAUNCH_SIGLUOAI_AND_MUL(vllm::swigluoai_and_mul, alpha, limit);
+}
 namespace vllm {

 // Element-wise activation kernel template.
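The launch macro derives its shape from the input tensor: one block per token, with up to 1024 threads striding over the d output elements. A sketch of that arithmetic in Python (the tensor shape is an arbitrary example, not taken from the commit):

import torch

x = torch.empty(8, 16, 5760)              # [..., 2 * d], example shape only
d = x.size(-1) // 2                       # 2880 output elements per token
num_tokens = x.numel() // x.size(-1)      # 128 -> grid.x (one block per token)
block = min(d, 1024)                      # block.x, capped at 1024 threads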
@@ -138,6 +138,8 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);

 void fatrelu_and_mul(torch::Tensor& out, torch::Tensor& input,
                      double threshold);
+void swigluoai_and_mul(torch::Tensor& out, torch::Tensor& input,
+                       double alpha = 1.702, double limit = 7.0);

 void gelu_new(torch::Tensor& out, torch::Tensor& input);

@@ -130,6 +130,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()");
   ops.impl("fatrelu_and_mul", torch::kCUDA, &fatrelu_and_mul);

+  ops.def(
+      "swigluoai_and_mul(Tensor! out, Tensor input, float alpha=1.702, float "
+      "limit=7.0) "
+      "-> ()");
+  ops.impl("swigluoai_and_mul", torch::kCUDA, &swigluoai_and_mul);
+
   // GELU implementation used in GPT-2.
   ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
   ops.impl("gelu_new", torch::kCUDA, &gelu_new);
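With the binding registered, the op can be called directly through torch.ops. A usage sketch (it assumes a CUDA build of vLLM and that importing vllm has loaded the _C extension; the tensor shapes simply follow the schema, input [..., 2 * d] -> out [..., d]):

import torch
import vllm  # noqa: F401  (loads the _C extension and its registered ops)

x = torch.randn(32, 1024, dtype=torch.float16, device="cuda")   # [..., 2 * d]
out = torch.empty(32, 512, dtype=x.dtype, device=x.device)      # [..., d]
torch.ops._C.swigluoai_and_mul(out, x)              # schema defaults: alpha=1.702, limit=7.0
torch.ops._C.swigluoai_and_mul(out, x, 1.702, 7.0)  # or pass both explicitly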
@@ -11,7 +11,7 @@ from tests.kernels.utils import opcheck
 from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul,
                                                     GeluAndMul, MulAndSilu,
                                                     NewGELU, QuickGELU,
-                                                    SiluAndMul)
+                                                    SiluAndMul, SwigluOAIAndMul)
 from vllm.platforms import current_platform

 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -25,7 +25,15 @@ CUDA_DEVICES = [

 @pytest.mark.parametrize(
     "activation",
-    ["silu_and_mul", "mul_and_silu", "gelu", "gelu_tanh", "fatrelu"])
+    [
+        "silu_and_mul",
+        "mul_and_silu",
+        "gelu",
+        "gelu_tanh",
+        "fatrelu",
+        "swigluoai_and_mul",
+    ],
+)
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
@@ -59,18 +67,43 @@ def test_act_and_mul(
         threshold = random.uniform(0, 1)
         layer = FatreluAndMul(threshold)
         fn = torch.ops._C.fatrelu_and_mul
+    elif activation == "swigluoai_and_mul":
+        layer = SwigluOAIAndMul()
+        fn = torch.ops._C.swigluoai_and_mul
     out = layer(x)
     ref_out = layer.forward_native(x)
-    # The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are
-    # equivalent to the native PyTorch implementations, so we can do exact
-    # comparison.
-    torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
+    if activation == "swigluoai_and_mul":
+
+        rtol = {
+            # For fp16, change the relative tolerance from 1e-3 to 2e-3
+            torch.float16:
+            2e-3,
+            torch.bfloat16:
+            2e-2,
+            torch.float:
+            1.3e-6
+        }
+
+        def _get_rtol(output) -> float:
+            return rtol[output.dtype]
+
+        torch.testing.assert_close(out,
+                                   ref_out,
+                                   atol=get_default_atol(out),
+                                   rtol=_get_rtol(out))
+    else:
+        # The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are
+        # equivalent to the native PyTorch implementations, so we can do exact
+        # comparison.
+        torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)

     d = x.shape[-1] // 2
     output_shape = (x.shape[:-1] + (d, ))
     out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
     if activation == "fatrelu":
         opcheck(fn, (out, x, threshold))
+    elif activation == "swigluoai_and_mul":
+        opcheck(fn, (out, x, layer.alpha, layer.limit))
     else:
         opcheck(fn, (out, x))

@@ -239,6 +239,35 @@ class GeluAndMul(CustomOp):
         return f'approximate={repr(self.approximate)}'


+@CustomOp.register("swigluoai_and_mul")
+class SwigluOAIAndMul(CustomOp):
+    # https://github.com/huggingface/transformers/blob/v4.55.0/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L106-L110
+    def __init__(self, alpha: float = 1.702, limit: float = 7.0):
+        super().__init__()
+        self.alpha = alpha
+        self.limit = limit
+
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward()."""
+
+        gate, up = x[..., ::2], x[..., 1::2]
+        gate = gate.clamp(min=None, max=self.limit)
+        up = up.clamp(min=-self.limit, max=self.limit)
+        glu = gate * torch.sigmoid(gate * self.alpha)
+        gated_output = (up + 1) * glu
+        return gated_output
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = (x.shape[:-1] + (d, ))
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        torch.ops._C.swigluoai_and_mul(out, x, self.alpha, self.limit)
+        return out
+
+    def extra_repr(self) -> str:
+        return f"alpha={repr(self.alpha)}, limit={repr(self.limit)}"
+
+
 @CustomOp.register("gelu_new")
 class NewGELU(CustomOp):

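A minimal usage sketch of the new layer (shapes are illustrative): calling the module dispatches to forward_cuda on CUDA builds, while forward_native gives the same result elsewhere.

import torch
from vllm.model_executor.layers.activation import SwigluOAIAndMul

layer = SwigluOAIAndMul()            # alpha=1.702, limit=7.0 by default
x = torch.randn(4, 256)              # [..., 2 * d]
y = layer.forward_native(x)          # reference path, shape [4, 128]
assert y.shape == (4, 128)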
@@ -330,6 +359,7 @@ class ReLUSquaredActivation(CustomOp):
         return torch.square(F.relu(x))

     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        # TODO: implement cuda kernels
         return self.forward_native(x)


@@ -406,9 +436,14 @@ def get_act_fn(act_fn_name: str) -> nn.Module:


 _ACTIVATION_AND_MUL_REGISTRY = LazyDict({
-    "gelu": lambda: GeluAndMul(),
-    "silu": lambda: SiluAndMul(),
-    "geglu": lambda: GeluAndMul(),
+    "gelu":
+    lambda: GeluAndMul(),
+    "silu":
+    lambda: SiluAndMul(),
+    "geglu":
+    lambda: GeluAndMul(),
+    "swigluoai":
+    lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs),
 })


@@ -161,25 +161,13 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
     if activation == "silu":
         torch.ops._C.silu_and_mul(intermediate_cache2,
                                   intermediate_cache1.view(-1, 2 * N))
-    elif activation == "swiglu_oai":
-        # NOTE: in gpt-oss, the gate_proj and up_proj is interleaved
-        # - interleaved: gate, up = gate_up[..., ::2], gate_up[..., 1::2]
-        # - origin: gate, up = gate_up[..., :N], gate_up[..., N:]
-
-        @torch.compile(dynamic=True)
-        def swiglu_oai(gate_up):
-            alpha = 1.702
-            limit = 7.0
-            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
-            gate = gate.clamp(min=None, max=limit)
-            up = up.clamp(min=-limit, max=limit)
-            glu = gate * torch.sigmoid(gate * alpha)
-            return (up + 1) * glu
-
-        intermediate_cache2 = swiglu_oai(intermediate_cache1)
+    elif activation == "swigluoai":
+        # alpha = 1.702, limit = 7.0
+        torch.ops._C.swigluoai_and_mul(intermediate_cache2,
+                                       intermediate_cache1.view(-1, 2 * N))
     else:
         raise ValueError(f"Unsupported activation: {activation}. "
-                         "Only silu and swiglu_oai activations are supported.")
+                         "Only silu and swigluoai activations are supported.")

     if expert_map is not None:
         intermediate_cache3.zero_()
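The NOTE in the removed code is worth keeping in mind: gpt-oss stores the gate and up projections interleaved along the last dimension rather than split into two contiguous halves. An illustrative comparison of the two layouts (hypothetical values):

import torch

N = 4
gate = torch.arange(N, dtype=torch.float32)
up = gate + 100
interleaved = torch.stack([gate, up], dim=-1).reshape(-1)  # [g0, u0, g1, u1, ...]
split = torch.cat([gate, up])                              # [g0..g3, u0..u3]
assert torch.equal(interleaved[::2], split[:N])   # gate recovered either way
assert torch.equal(interleaved[1::2], split[N:])  # up recovered either way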
@@ -1621,17 +1621,6 @@ def fused_experts_impl(
                                 block_shape=block_shape,
                                 B_bias=w1_bias)

-    # TODO fused kernel
-    def swiglu_oai(gate_up):
-        alpha = 1.702
-        limit = 7.0
-        gate, up = gate_up[..., ::2], gate_up[..., 1::2]
-        gate = gate.clamp(min=None, max=limit)
-        up = up.clamp(min=-limit, max=limit)
-        glu = gate * torch.sigmoid(gate * alpha)
-        gated_output = (up + 1) * glu
-        return gated_output
-
     # Activation function with multiplication
     if activation == "silu" and is_act_and_mul:
         torch.ops._C.silu_and_mul(intermediate_cache2,
@@ -1639,13 +1628,16 @@ def fused_experts_impl(
     elif activation == "gelu" and is_act_and_mul:
         torch.ops._C.gelu_and_mul(intermediate_cache2,
                                   intermediate_cache1.view(-1, N))
+    elif activation == "swigluoai" and is_act_and_mul:
+        # alpha = 1.702, limit = 7.0
+        torch.ops._C.swigluoai_and_mul(intermediate_cache2,
+                                       intermediate_cache1.view(-1, N))
     # Activation function without multiplication
     elif activation == "silu":
         intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
     elif activation == "gelu":
         intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))
-    elif activation == "swiglu_oai":
-        intermediate_cache2 = swiglu_oai(intermediate_cache1.view(-1, N))
+
     else:
         raise ValueError(f"Unsupported FusedMoe activation: {activation}, "
                          f"with is_act_and_mul={is_act_and_mul}.")
@@ -61,14 +61,14 @@ def _can_support_mxfp4(use_grouped_topk: bool = False,
                        e_score_correction_bias: Optional[torch.Tensor] = None,
                        apply_router_weight_on_input: bool = False,
                        scoring_func: str = "softmax",
-                       activation: str = "swiglu_oai",
+                       activation: str = "swigluoai",
                        expert_load_view: Optional[torch.Tensor] = None,
                        logical_to_physical_map: Optional[torch.Tensor] = None,
                        logical_replica_count: Optional[torch.Tensor] = None):
     return not (use_grouped_topk or topk_group or num_expert_group
                 or expert_map or custom_routing_function
                 or e_score_correction_bias or apply_router_weight_on_input
-                or scoring_func != "softmax" or activation != "swiglu_oai"
+                or scoring_func != "softmax" or activation != "swigluoai"
                 or expert_load_view or logical_to_physical_map
                 or logical_replica_count)
@@ -159,7 +159,7 @@ class MLPBlock(torch.nn.Module):
             prefix=f"{prefix}.experts",
             apply_router_weight_on_input=False,
             has_bias=True,
-            activation="swiglu_oai")
+            activation="swigluoai")

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         t = self.norm(x)