[Performance][DP/EP] Add silu_mul_per_token_group_quant_fp8_colmajor kernel (#29470)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Varun Sundar Rabindranath
2025-12-03 13:04:59 -05:00
committed by GitHub
parent dd5d1ef780
commit 19bee6d12d
4 changed files with 496 additions and 81 deletions
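For orientation, here is a minimal unfused PyTorch sketch of what the new fused kernel computes, based on the Triton kernel and wrapper added in fp8_utils.py below. The helper name is illustrative and not part of this commit; the fused kernel matches this up to dtype rounding and additionally returns the scales in a column-major layout.

import torch

GROUP_SIZE = 128
FP8 = torch.float8_e4m3fn

def silu_mul_group_quant_reference(x: torch.Tensor, use_ue8m0: bool = False, eps: float = 1e-10):
    # x: [T, N] bfloat16. SiLU of the first half, multiplied by the second half.
    T, N = x.shape
    half = N // 2
    y = torch.nn.functional.silu(x[:, :half].float()) * x[:, half:].float()
    # One absmax-based scale per token per 128-column group.
    groups = y.view(T, half // GROUP_SIZE, GROUP_SIZE)
    finfo = torch.finfo(FP8)
    scales = groups.abs().amax(dim=-1).clamp_min(eps) / finfo.max
    if use_ue8m0:
        # DeepGEMM E8M0 mode: round each scale up to the next power of two.
        scales = torch.exp2(torch.ceil(torch.log2(scales)))
    y_q = (groups / scales.unsqueeze(-1)).clamp(finfo.min, finfo.max).view(T, half).to(FP8)
    return y_q, scales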


@@ -0,0 +1,244 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from enum import Enum
from itertools import product
from typing import Any
import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
_per_token_group_quant_fp8_colmajor,
silu_mul_per_token_group_quant_fp8_colmajor,
)
from vllm.triton_utils import triton
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
from .utils import ArgPool, Bench, CudaGraphBenchParams
GROUP_SIZE = 128
FLOAT8_T = torch.float8_e4m3fn
def print_timers(timers: list[TMeasurement], cuda_graph_nops: int):
print(
f"Note : The timings reported above is for {cuda_graph_nops} "
"consecutive invocations of the benchmarking functions. "
f"Please divide by {cuda_graph_nops} for single invocation "
"timings."
)
compare = TBenchmark.Compare(timers)
compare.print()
class ImplType(Enum):
SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR = 1
REFERENCE = 2
def get_impl(self):
if self == ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR:
return silu_mul_per_token_group_quant_fp8_colmajor
elif self == ImplType.REFERENCE:
return reference
raise ValueError(f"Unrecognized ImplType {self}")
@dataclass
class BenchmarkTensors:
input: torch.Tensor
output: torch.Tensor
# Reference act output tensor
ref_act_out: torch.Tensor
ref_quant_out: torch.Tensor
@staticmethod
def make(T: int, N: int) -> "BenchmarkTensors":
assert T % GROUP_SIZE == 0
assert N % (GROUP_SIZE * 2) == 0
input = torch.rand((T, N), dtype=torch.bfloat16, device="cuda")
# silu_mul_per_token_group_quant_fp8_colmajor output.
output = torch.rand((T, N // 2), dtype=torch.bfloat16, device="cuda").to(
FLOAT8_T
)
# reference output.
ref_act_out = torch.empty((T, N // 2), dtype=torch.bfloat16, device="cuda")
ref_quant_out = torch.empty(
(T, N // 2), dtype=torch.bfloat16, device="cuda"
).to(FLOAT8_T)
return BenchmarkTensors(
input=input,
output=output,
ref_act_out=ref_act_out,
ref_quant_out=ref_quant_out,
)
@property
def T(self):
return self.input.size(0)
@property
def N(self):
return self.input.size(1)
def make_impl_kwargs(self, impl_type: ImplType) -> dict[str, Any]:
if impl_type == ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR:
return {
"input": self.input,
"output": self.output,
"use_ue8m0": is_deep_gemm_e8m0_used(),
}
elif impl_type == ImplType.REFERENCE:
return {
"input": self.input,
"act_out": self.ref_act_out,
"quant_out": self.ref_quant_out,
"use_ue8m0": is_deep_gemm_e8m0_used(),
}
raise ValueError(f"Unrecognized impl_type {impl_type}")
def reference_quant(x: torch.Tensor, quant_out: torch.Tensor, use_ue8m0: bool):
"""
Reference Triton quant kernel from
vllm.model_executor.layers.quantization.utils.fp8_utils
"""
assert quant_out.size() == x.size()
# Allocate the scale tensor in column-major format.
shape = (x.shape[-1] // GROUP_SIZE,) + x.shape[:-1]
x_q = quant_out
x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
M = x.numel() // GROUP_SIZE
N = GROUP_SIZE
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
num_stages = 1
finfo = torch.finfo(FLOAT8_T)
fp8_min = finfo.min
fp8_max = finfo.max
_per_token_group_quant_fp8_colmajor[(M,)](
x,
x_q,
x_s,
GROUP_SIZE,
x.shape[1],
x.stride(0),
x_s.stride(1),
eps=1e-10,
fp8_min=fp8_min,
fp8_max=fp8_max,
use_ue8m0=use_ue8m0,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
return x_q, x_s
def reference(
input: torch.Tensor,
act_out: torch.Tensor,
quant_out: torch.Tensor,
use_ue8m0: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
torch.ops._C.silu_and_mul(act_out, input)
return reference_quant(act_out, quant_out, use_ue8m0)
def bench_impl(
bench_tensors: list[BenchmarkTensors], impl_type: ImplType
) -> TMeasurement:
T = bench_tensors[0].T
N = bench_tensors[0].N
arg_pool_size = len(bench_tensors)
kwargs_list = [bt.make_impl_kwargs(impl_type) for bt in bench_tensors]
# warmup
for kwargs in kwargs_list:
impl_type.get_impl()(**kwargs)
torch.cuda.synchronize()
# Merge into a single kwargs and qualify arguments as ArgPool
kwargs = {k: ArgPool([]) for k in kwargs_list[0]}
for _kwargs in kwargs_list:
for k, v in _kwargs.items():
kwargs[k].values.append(v)
cuda_graph_params = CudaGraphBenchParams(arg_pool_size)
timer = None
with Bench(
cuda_graph_params,
"silu-mul-quant",
f"num_tokens={T}, N={N}",
impl_type.name,
impl_type.get_impl(),
**kwargs,
) as bench:
timer = bench.run()
return timer
def test_correctness(T: int, N: int):
print(f"Testing num_tokens={T}, N={N} ...")
bench_tensor = BenchmarkTensors.make(T, N)
def output_from_impl(impl: ImplType) -> tuple[torch.Tensor, torch.Tensor]:
return impl.get_impl()(**bench_tensor.make_impl_kwargs(impl))
# reference output
ref_out_q, ref_out_s = output_from_impl(ImplType.REFERENCE)
# test output
out_q, out_s = output_from_impl(
ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR
)
torch.testing.assert_close(ref_out_q.to(torch.float32), out_q.to(torch.float32))
torch.testing.assert_close(ref_out_s, out_s)
def run(Ts: list[int], Ns: list[int], arg_pool_size: int) -> list[TMeasurement]:
timers = []
for N, T in product(Ns, Ts):
test_correctness(T, N)
bench_tensors: list[BenchmarkTensors] = [
BenchmarkTensors.make(T, N) for _ in range(arg_pool_size)
]
silu_mul_quant_timer = bench_impl(
bench_tensors, ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR
)
timers.append(silu_mul_quant_timer)
reference_timer = bench_impl(bench_tensors, ImplType.REFERENCE)
timers.append(reference_timer)
print_timers(
[silu_mul_quant_timer, reference_timer], cuda_graph_nops=arg_pool_size
)
print_timers(timers, cuda_graph_nops=arg_pool_size)
return timers
if __name__ == "__main__":
T = [128 * i for i in range(1, 16)] + [2048 * i for i in range(1, 65)]
N = [2048, 4096, 8192]
print(f"T = {T}, N = {N}")
run(T, N, arg_pool_size=8)
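# For reference: the Bench/ArgPool/CudaGraphBenchParams helpers used above come
# from the local benchmarks .utils module. Conceptually, they capture
# `arg_pool_size` consecutive invocations of the benchmarked function (one per
# kwargs entry in the ArgPool) into a single CUDA graph and time its replay,
# which is why print_timers asks to divide by cuda_graph_nops. A rough sketch
# of that idea (illustration only, not the actual .utils implementation):
def _capture_pool_in_one_graph(fn, kwargs_pool: list[dict]) -> torch.cuda.CUDAGraph:
    # CUDA graphs require a warmup pass on a side stream before capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for kwargs in kwargs_pool:
            fn(**kwargs)
    torch.cuda.current_stream().wait_stream(s)
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        for kwargs in kwargs_pool:  # one graph holds len(kwargs_pool) launches
            fn(**kwargs)
    return g  # timing g.replay() measures all launches together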


@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
_per_token_group_quant_fp8_colmajor,
silu_mul_per_token_group_quant_fp8_colmajor,
)
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
FLOAT8_DTYPE = torch.float8_e4m3fn
GROUP_SIZE = 128
def reference_quant(x: torch.Tensor, use_ue8m0: bool):
"""
Reference Triton quant kernel from
vllm.model_executor.layers.quantization.utils.fp8_utils
"""
x_q = torch.empty_like(x, device=x.device, dtype=FLOAT8_DTYPE)
# Allocate the scale tensor in column-major format.
shape = (x.shape[-1] // GROUP_SIZE,) + x.shape[:-1]
x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
M = x.numel() // GROUP_SIZE
N = GROUP_SIZE
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
num_stages = 1
finfo = torch.finfo(FLOAT8_DTYPE)
fp8_min = finfo.min
fp8_max = finfo.max
_per_token_group_quant_fp8_colmajor[(M,)](
x,
x_q,
x_s,
GROUP_SIZE,
x.shape[1],
x.stride(0),
x_s.stride(1),
eps=1e-10,
fp8_min=fp8_min,
fp8_max=fp8_max,
use_ue8m0=use_ue8m0,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
return x_q, x_s
def reference(x: torch.Tensor, use_ue8m0: bool) -> tuple[torch.Tensor, torch.Tensor]:
T, N = x.size()
ref_act_out = torch.empty((T, N // 2), dtype=torch.bfloat16, device="cuda")
torch.ops._C.silu_and_mul(ref_act_out, x)
return reference_quant(ref_act_out, use_ue8m0)
@pytest.mark.parametrize("T", [128, 256, 512])
@pytest.mark.parametrize("N", [128 * 2, 256 * 2, 768 * 2, 2048 * 2, 7168 * 2])
def test_silu_mul_fp8_quant_deep_gemm(T: int, N: int):
current_platform.seed_everything(42)
input = torch.rand((T, N), dtype=torch.bfloat16, device="cuda")
use_ue8m0 = is_deep_gemm_e8m0_used()
# Test
output, output_scales = silu_mul_per_token_group_quant_fp8_colmajor(
input, use_ue8m0=use_ue8m0
)
# Reference
ref_output, ref_output_scales = reference(input, use_ue8m0)
torch.testing.assert_close(output.to(torch.float32), ref_output.to(torch.float32))
torch.testing.assert_close(output_scales, ref_output_scales)


@@ -2,9 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from tqdm import tqdm
import vllm.envs as env
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
@@ -25,12 +23,12 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
silu_mul_per_token_group_quant_fp8_colmajor,
)
from vllm.utils.deep_gemm import (
get_mk_alignment_for_contiguous_layout,
m_grouped_fp8_gemm_nt_contiguous,
)
from vllm.utils.func_utils import run_once
from vllm.utils.import_utils import has_deep_gemm
logger = init_logger(__name__)
@@ -108,70 +106,6 @@ def _valid_deep_gemm(
return True
@run_once
def warmup_deepgemm_gg_contiguous_kernels(
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
num_topk: int,
):
"""
DeepGemm JITs the grouped-gemm kernels. The JIT'ing happens based on the
input tensor shapes. In this function, we construct all possible input
tensor shapes so all the kernels are JIT'ed and cached.
Note that this warmup is expected to happen during the model profile
call and not during actual model inference.
"""
assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts"
block_m = get_mk_alignment_for_contiguous_layout()[0]
num_experts = w1.size(0)
device = w1.device
# This is the maximum GroupedGemm M size that we expect to run
# the grouped_gemm with.
MAX_M = compute_aligned_M(
env.VLLM_FUSED_MOE_CHUNK_SIZE,
num_topk,
num_experts,
block_m,
expert_tokens_meta=None,
)
# Distribute expert-ids evenly.
MAX_BLOCKS = MAX_M // block_m
expert_ids_block = torch.randint(
low=0, high=num_experts, size=(MAX_BLOCKS,), device=device, dtype=torch.int32
)
expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0)
def _warmup(w: torch.Tensor, w_scale: torch.Tensor):
_, n, k = w.size()
a1q = torch.empty((MAX_M, k), device=device).to(torch.float8_e4m3fn)
a1q_scales = torch.empty(
(MAX_M, k // block_m), device=device, dtype=torch.float32
)
out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16)
pbar = tqdm(
total=MAX_BLOCKS, desc=f"DeepGemmExperts GEMM warmup (MAX_M={MAX_M})"
)
num_tokens = MAX_M
while num_tokens > 0:
m_grouped_fp8_gemm_nt_contiguous(
(a1q[:num_tokens], a1q_scales[:num_tokens]),
(w, w_scale),
out[:num_tokens],
expert_ids[:num_tokens],
)
pbar.update(1)
num_tokens = num_tokens - block_m
_warmup(w1, w1_scale)
_warmup(w2, w2_scale)
class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self, quant_config: FusedMoEQuantConfig):
super().__init__(quant_config)
@@ -215,11 +149,32 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
)
assert M_sum % block_m == 0
workspace1 = (M_sum, N)
workspace2 = (M_sum, max(N // 2, K))
workspace1 = (M_sum, max(N // 2, K))
workspace2 = (M_sum, max(N, K))
output = (M, K)
return (workspace1, workspace2, output)
def _act_mul_quant(
self, input: torch.Tensor, output: torch.Tensor, activation: str
) -> tuple[torch.Tensor, torch.Tensor]:
if activation == "silu":
return silu_mul_per_token_group_quant_fp8_colmajor(
input=input, output=output
)
else:
# This is a fallback path. If we find ourselves using any activation other
# than silu, we should add that activation to the
# silu_mul_per_token_group_quant_fp8_colmajor kernel, as the fused kernel is much faster.
M_sum, N = input.size()
act_out = torch.empty(
(M_sum, N // 2), dtype=input.dtype, device=input.device
)
self.activation(activation, act_out, input)
assert self.block_shape is not None
return per_token_group_quant_fp8(
act_out, self.block_shape[1], column_major_scales=True, out_q=output
)
def apply(
self,
output: torch.Tensor,
@@ -261,14 +216,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta=expert_tokens_meta,
)
a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn), (M_sum, K))
mm1_out = _resize_cache(workspace13, (M_sum, N))
act_out = _resize_cache(workspace2, (M_sum, N // 2))
quant_out = _resize_cache(
workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, N // 2)
a1q_perm = _resize_cache(
workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, K)
)
mm2_out = _resize_cache(workspace2, (M_sum, K))
a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute(
aq=a1q,
aq_scale=a1q_scale,
@@ -280,17 +230,19 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
)
assert a1q.size(0) == M_sum
mm1_out = _resize_cache(workspace2, (M_sum, N))
m_grouped_fp8_gemm_nt_contiguous(
(a1q, a1q_scale), (w1, self.w1_scale), mm1_out, expert_ids
)
self.activation(activation, act_out, mm1_out.view(-1, N))
a2q_scale: torch.Tensor | None = None
a2q, a2q_scale = per_token_group_quant_fp8(
act_out, self.block_shape[1], column_major_scales=True, out_q=quant_out
quant_out = _resize_cache(
workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, N // 2)
)
a2q, a2q_scale = self._act_mul_quant(
input=mm1_out.view(-1, N), output=quant_out, activation=activation
)
mm2_out = _resize_cache(workspace2, (M_sum, K))
m_grouped_fp8_gemm_nt_contiguous(
(a2q, a2q_scale), (w2, self.w2_scale), mm2_out, expert_ids
)
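# Workspace reuse in the updated apply() above (sketch; assumes the modular-kernel
# framework allocates workspace13 / workspace2 in the activation dtype with the
# workspace1 / workspace2 shapes returned by workspace_shapes()):
#   * workspace13, viewed as fp8, backs a1q_perm (M_sum, K) and later quant_out (M_sum, N // 2)
#   * workspace2 backs mm1_out (M_sum, N) and later mm2_out (M_sum, K)
# The fused silu_mul_per_token_group_quant_fp8_colmajor writes quant_out directly
# from mm1_out, so the old intermediate act_out buffer (and its extra read and
# write of the activation) is gone. Illustrative size check with hypothetical shapes:
def _example_workspace_sizes(M_sum: int = 512, N: int = 4096, K: int = 7168) -> None:
    ws13 = torch.empty((M_sum, max(N // 2, K)), dtype=torch.bfloat16)
    ws2 = torch.empty((M_sum, max(N, K)), dtype=torch.bfloat16)
    # A bf16 buffer viewed as fp8 doubles its last dimension, so it can hold
    # both the (M_sum, K) permuted input and the (M_sum, N // 2) quant output.
    assert ws13.view(torch.float8_e4m3fn).size(1) >= max(K, N // 2)
    # ws2 is large enough for both mm1_out (M_sum, N) and mm2_out (M_sum, K).
    assert ws2.size(1) >= max(N, K)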


@@ -492,6 +492,139 @@ def _per_token_group_quant_fp8(
tl.store(y_s_ptr, y_s)
@triton.jit
def _silu_mul_per_token_group_quant_fp8_colmajor(
y_ptr, # [M, N]
y_q_ptr, # [M, N // 2]
y_s_ptr, # [M, (N // 2) // GROUP_SIZE]
M, # num tokens
N, # intermediate size
# Stride
y_s_col_stride: tl.int64,
# Information for float8
eps,
fp8_min,
fp8_max,
use_ue8m0: tl.constexpr,
# Meta-parameters
GROUP_SIZE: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
# TODO(varun) : Add expert_ids so we may early-exit no-op thread blocks.
"""
Each thread block computes a [BLOCK_M, GROUP_SIZE] (i.e. [BLOCK_M, BLOCK_N]) tile
of act-mul outputs. The thread block then quantizes that tile of values and fills
the output tensors at the right positions.
"""
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
N_2 = N // 2
m_offset = pid_m * BLOCK_M
n_offset = pid_n * BLOCK_N
if m_offset >= M:
return
offs_n = tl.arange(0, BLOCK_N).to(tl.int64)
offs_m = tl.arange(0, BLOCK_M).to(tl.int64)
base_y_ptr = y_ptr + m_offset * N + n_offset
act_in_ptrs = base_y_ptr + offs_m[:, None] * N + offs_n[None, :]
act_in = tl.load(act_in_ptrs)
mul_in = tl.load(act_in_ptrs + N_2)
# silu & mul
act_in = act_in.to(tl.float32)
one_f32 = tl.cast(1, tl.float32)
silu_out = (act_in / (one_f32 + tl.exp(-act_in))).to(y_ptr.dtype.element_ty)
y = (silu_out * mul_in).to(tl.float32)
# quant
_absmax = tl.maximum(tl.max(tl.abs(y), axis=1), eps)
scale_raw = _absmax / fp8_max
y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw
y_s = tl.reshape(y_s, (BLOCK_M, 1))
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
# store y_q
base_y_q_ptr = y_q_ptr + m_offset * N_2 + n_offset
y_q_ptrs = base_y_q_ptr + offs_m[:, None] * N_2 + offs_n[None, :]
tl.store(y_q_ptrs, y_q)
# store y_s
group_id = n_offset // GROUP_SIZE
base_y_s_ptr = y_s_ptr + group_id * y_s_col_stride + m_offset
y_s_ptrs = base_y_s_ptr + offs_m
y_s = tl.reshape(y_s, (BLOCK_M,))
tl.store(y_s_ptrs, y_s)
def silu_mul_per_token_group_quant_fp8_colmajor(
input: torch.Tensor, # [M, N]
output: torch.Tensor | None = None, # [M, N // 2]
use_ue8m0: bool | None = None,
eps: float = 1e-10,
):
"""
Fused SiLU-and-mul followed by block-FP8 quantization (group size 128) with column-major scales.
"""
GROUP_SIZE = 128
assert input.ndim == 2
if output is not None:
assert output.ndim == 2
assert input.size(0) % GROUP_SIZE == 0
assert input.size(1) % (GROUP_SIZE * 2) == 0
if use_ue8m0 is None:
use_ue8m0 = is_deep_gemm_e8m0_used()
M, N = input.size()
N_2 = N // 2
if output is None:
output = torch.empty((M, N_2), dtype=torch.float8_e4m3fn, device=input.device)
output_scales = torch.empty(
((N_2 // GROUP_SIZE), M), dtype=torch.float32, device=input.device
).transpose(0, 1)
BLOCK_M = 8
BLOCK_N = GROUP_SIZE
finfo = torch.finfo(torch.float8_e4m3fn)
fp8_min = finfo.min
fp8_max = finfo.max
# Force even division so we can avoid edge cases within the kernel.
assert M % BLOCK_M == 0
assert N_2 % BLOCK_N == 0
grid = (M // BLOCK_M, N_2 // BLOCK_N)
_silu_mul_per_token_group_quant_fp8_colmajor[grid](
input,
output,
output_scales,
M,
N,
output_scales.stride(-1),
eps,
fp8_min,
fp8_max,
use_ue8m0,
GROUP_SIZE,
BLOCK_M,
BLOCK_N,
)
return output, output_scales
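# Usage sketch for silu_mul_per_token_group_quant_fp8_colmajor (illustrative only;
# the helper below is hypothetical and simply restates the contract of the function
# above: fp8 output of shape [M, N // 2] plus per-group scales of shape
# [M, (N // 2) // 128]).
def _example_silu_mul_quant_colmajor():
    x = torch.rand((256, 4096), dtype=torch.bfloat16, device="cuda")
    y_q, y_s = silu_mul_per_token_group_quant_fp8_colmajor(x)
    assert y_q.shape == (256, 2048) and y_q.dtype == torch.float8_e4m3fn
    # One scale per token per 128-wide group, laid out column-major
    # (unit stride along the token dimension).
    assert y_s.shape == (256, 2048 // 128)
    assert y_s.stride() == (1, 256)
    return y_q, y_s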
@triton.jit
def _per_token_group_quant_fp8_colmajor(
# Pointers to inputs and output