Mirror of https://github.com/vllm-project/vllm.git, synced 2025-12-06 15:04:47 +08:00
[Performance][DP/EP] Add silu_mul_per_token_group_quant_fp8_colmajor kernel (#29470)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Committed by GitHub. Parent: dd5d1ef780. Commit: 19bee6d12d.
benchmarks/kernels/benchmark_2d_silu_mul_fp8_quant.py (new file, 244 lines)
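The commit fuses the SiLU-and-mul activation with the per-token-group (group size 128) FP8 quantization whose column-major scales DeepGEMM's grouped GEMM consumes, replacing the previous two launches plus an intermediate bf16 activation buffer. A minimal usage sketch of the new op, mirroring the call in the new test (the wrapper asserts that the token count is a multiple of 128 and that the last dimension is a multiple of 256):

    import torch
    from vllm.model_executor.layers.quantization.utils.fp8_utils import (
        silu_mul_per_token_group_quant_fp8_colmajor,
    )

    x = torch.rand((256, 4096), dtype=torch.bfloat16, device="cuda")  # [T, N]
    # y_q: float8_e4m3fn tensor of shape [T, N // 2] holding
    #      silu(x[:, : N // 2]) * x[:, N // 2 :], quantized per 128-wide group.
    # y_s: float32 scales of shape [T, (N // 2) // 128], column-major
    #      (token dimension contiguous). use_ue8m0 defaults to
    #      is_deep_gemm_e8m0_used() when not passed.
    y_q, y_s = silu_mul_per_token_group_quant_fp8_colmajor(x)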
@@ -0,0 +1,244 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from enum import Enum
from itertools import product
from typing import Any

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    _per_token_group_quant_fp8_colmajor,
    silu_mul_per_token_group_quant_fp8_colmajor,
)
from vllm.triton_utils import triton
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used

from .utils import ArgPool, Bench, CudaGraphBenchParams

GROUP_SIZE = 128
FLOAT8_T = torch.float8_e4m3fn


def print_timers(timers: list[TMeasurement], cuda_graph_nops: int):
    print(
        f"Note: The timings reported above are for {cuda_graph_nops} "
        "consecutive invocations of the benchmarking functions. "
        f"Please divide by {cuda_graph_nops} for single invocation "
        "timings."
    )
    compare = TBenchmark.Compare(timers)
    compare.print()


class ImplType(Enum):
    SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR = 1
    REFERENCE = 2

    def get_impl(self):
        if self == ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR:
            return silu_mul_per_token_group_quant_fp8_colmajor
        elif self == ImplType.REFERENCE:
            return reference
        raise ValueError(f"Unrecognized ImplType {self}")


@dataclass
class BenchmarkTensors:
    input: torch.Tensor
    output: torch.Tensor

    # Reference act output tensor
    ref_act_out: torch.Tensor
    ref_quant_out: torch.Tensor

    @staticmethod
    def make(T: int, N: int) -> "BenchmarkTensors":
        assert T % GROUP_SIZE == 0
        assert N % (GROUP_SIZE * 2) == 0

        input = torch.rand((T, N), dtype=torch.bfloat16, device="cuda")

        # silu_mul_per_token_group_quant_fp8_colmajor output.
        output = torch.rand((T, N // 2), dtype=torch.bfloat16, device="cuda").to(
            FLOAT8_T
        )

        # reference output.
        ref_act_out = torch.empty((T, N // 2), dtype=torch.bfloat16, device="cuda")
        ref_quant_out = torch.empty(
            (T, N // 2), dtype=torch.bfloat16, device="cuda"
        ).to(FLOAT8_T)

        return BenchmarkTensors(
            input=input,
            output=output,
            ref_act_out=ref_act_out,
            ref_quant_out=ref_quant_out,
        )

    @property
    def T(self):
        return self.input.size(0)

    @property
    def N(self):
        return self.input.size(1)

    def make_impl_kwargs(self, impl_type: ImplType) -> dict[str, Any]:
        if impl_type == ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR:
            return {
                "input": self.input,
                "output": self.output,
                "use_ue8m0": is_deep_gemm_e8m0_used(),
            }
        elif impl_type == ImplType.REFERENCE:
            return {
                "input": self.input,
                "act_out": self.ref_act_out,
                "quant_out": self.ref_quant_out,
                "use_ue8m0": is_deep_gemm_e8m0_used(),
            }
        raise ValueError(f"Unrecognized impl_type {impl_type}")


def reference_quant(x: torch.Tensor, quant_out: torch.Tensor, use_ue8m0: bool):
    """
    Reference triton quant kernel from
    vllm.model_executor.layers.quantization.utils.fp8_utils
    """
    assert quant_out.size() == x.size()
    # Allocate the scale tensor in column-major format.
    shape = (x.shape[-1] // GROUP_SIZE,) + x.shape[:-1]
    x_q = quant_out
    x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)

    M = x.numel() // GROUP_SIZE
    N = GROUP_SIZE
    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1

    finfo = torch.finfo(FLOAT8_T)
    fp8_min = finfo.min
    fp8_max = finfo.max

    _per_token_group_quant_fp8_colmajor[(M,)](
        x,
        x_q,
        x_s,
        GROUP_SIZE,
        x.shape[1],
        x.stride(0),
        x_s.stride(1),
        eps=1e-10,
        fp8_min=fp8_min,
        fp8_max=fp8_max,
        use_ue8m0=use_ue8m0,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return x_q, x_s


def reference(
    input: torch.Tensor,
    act_out: torch.Tensor,
    quant_out: torch.Tensor,
    use_ue8m0: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    torch.ops._C.silu_and_mul(act_out, input)
    return reference_quant(act_out, quant_out, use_ue8m0)


def bench_impl(
    bench_tensors: list[BenchmarkTensors], impl_type: ImplType
) -> TMeasurement:
    T = bench_tensors[0].T
    N = bench_tensors[0].N

    arg_pool_size = len(bench_tensors)
    kwargs_list = [bt.make_impl_kwargs(impl_type) for bt in bench_tensors]

    # warmup
    for kwargs in kwargs_list:
        impl_type.get_impl()(**kwargs)
    torch.cuda.synchronize()

    # Merge into a single kwargs and qualify arguments as ArgPool
    kwargs = {k: ArgPool([]) for k in kwargs_list[0]}
    for _kwargs in kwargs_list:
        for k, v in _kwargs.items():
            kwargs[k].values.append(v)

    cuda_graph_params = None
    cuda_graph_params = CudaGraphBenchParams(arg_pool_size)
    timer = None
    with Bench(
        cuda_graph_params,
        "silu-mul-quant",
        f"num_tokens={T}, N={N}",
        impl_type.name,
        impl_type.get_impl(),
        **kwargs,
    ) as bench:
        timer = bench.run()
    return timer


def test_correctness(T: int, N: int):
    print(f"Testing num_tokens={T}, N={N} ...")

    bench_tensor = BenchmarkTensors.make(T, N)

    def output_from_impl(impl: ImplType) -> tuple[torch.Tensor, torch.Tensor]:
        return impl.get_impl()(**bench_tensor.make_impl_kwargs(impl))

    # reference output
    ref_out_q, ref_out_s = output_from_impl(ImplType.REFERENCE)

    # test output
    out_q, out_s = output_from_impl(
        ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR
    )

    torch.testing.assert_close(ref_out_q.to(torch.float32), out_q.to(torch.float32))
    torch.testing.assert_close(ref_out_s, out_s)


def run(Ts: list[int], Ns: list[int], arg_pool_size: int) -> list[TMeasurement]:
    timers = []
    for N, T in product(Ns, Ts):
        test_correctness(T, N)

        bench_tensors: list[BenchmarkTensors] = [
            BenchmarkTensors.make(T, N) for _ in range(arg_pool_size)
        ]

        silu_mul_quant_timer = bench_impl(
            bench_tensors, ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR
        )
        timers.append(silu_mul_quant_timer)
        reference_timer = bench_impl(bench_tensors, ImplType.REFERENCE)
        timers.append(reference_timer)

        print_timers(
            [silu_mul_quant_timer, reference_timer], cuda_graph_nops=arg_pool_size
        )

    print_timers(timers, cuda_graph_nops=arg_pool_size)

    return timers


if __name__ == "__main__":
    T = [128 * i for i in range(1, 16)] + [2048 * i for i in range(1, 65)]
    N = [2048, 4096, 8192]

    print(f"T = {T}, N = {N}")
    run(T, N, arg_pool_size=8)
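Since the benchmark pulls its harness helpers through a relative import (from .utils import ArgPool, Bench, CudaGraphBenchParams), it presumably has to be launched as a module from the repository root, along the lines of:

    python -m benchmarks.kernels.benchmark_2d_silu_mul_fp8_quant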
@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    _per_token_group_quant_fp8_colmajor,
    silu_mul_per_token_group_quant_fp8_colmajor,
)
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used

FLOAT8_DTYPE = torch.float8_e4m3fn
GROUP_SIZE = 128


def reference_quant(x: torch.Tensor, use_ue8m0: bool):
    """
    Reference triton quant kernel from
    vllm.model_executor.layers.quantization.utils.fp8_utils
    """

    x_q = torch.empty_like(x, device=x.device, dtype=FLOAT8_DTYPE)

    # Allocate the scale tensor in column-major format.
    shape = (x.shape[-1] // GROUP_SIZE,) + x.shape[:-1]
    x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)

    M = x.numel() // GROUP_SIZE
    N = GROUP_SIZE
    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1

    finfo = torch.finfo(FLOAT8_DTYPE)
    fp8_min = finfo.min
    fp8_max = finfo.max

    _per_token_group_quant_fp8_colmajor[(M,)](
        x,
        x_q,
        x_s,
        GROUP_SIZE,
        x.shape[1],
        x.stride(0),
        x_s.stride(1),
        eps=1e-10,
        fp8_min=fp8_min,
        fp8_max=fp8_max,
        use_ue8m0=use_ue8m0,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return x_q, x_s


def reference(x: torch.Tensor, use_ue8m0: bool) -> tuple[torch.Tensor, torch.Tensor]:
    T, N = x.size()
    ref_act_out = torch.empty((T, N // 2), dtype=torch.bfloat16, device="cuda")
    torch.ops._C.silu_and_mul(ref_act_out, x)
    return reference_quant(ref_act_out, use_ue8m0)


@pytest.mark.parametrize("T", [128, 256, 512])
@pytest.mark.parametrize("N", [128 * 2, 256 * 2, 768 * 2, 2048 * 2, 7168 * 2])
def test_silu_mul_fp8_quant_deep_gemm(T: int, N: int):
    current_platform.seed_everything(42)

    input = torch.rand((T, N), dtype=torch.bfloat16, device="cuda")

    use_ue8m0 = is_deep_gemm_e8m0_used()

    # Test
    output, output_scales = silu_mul_per_token_group_quant_fp8_colmajor(
        input, use_ue8m0=use_ue8m0
    )

    # Reference
    ref_output, ref_output_scales = reference(input, use_ue8m0)

    torch.testing.assert_close(output.to(torch.float32), ref_output.to(torch.float32))
    torch.testing.assert_close(output_scales, ref_output_scales)
@@ -2,9 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import torch
-from tqdm import tqdm
 
-import vllm.envs as env
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import (
@@ -25,12 +23,12 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8,
+    silu_mul_per_token_group_quant_fp8_colmajor,
 )
 from vllm.utils.deep_gemm import (
     get_mk_alignment_for_contiguous_layout,
     m_grouped_fp8_gemm_nt_contiguous,
 )
-from vllm.utils.func_utils import run_once
 from vllm.utils.import_utils import has_deep_gemm
 
 logger = init_logger(__name__)
@@ -108,70 +106,6 @@ def _valid_deep_gemm(
     return True
 
 
-@run_once
-def warmup_deepgemm_gg_contiguous_kernels(
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    w1_scale: torch.Tensor,
-    w2_scale: torch.Tensor,
-    num_topk: int,
-):
-    """
-    DeepGemm JITs the grouped-gemm kernels. The JIT'ing happens based on the
-    input tensor shapes. In this function, we construct all possible input
-    tensor shapes so all the kernels are JIT'ed and cached.
-    Note that this warmup is expected to happen during the model profile
-    call and not during actual model inference.
-    """
-
-    assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts"
-
-    block_m = get_mk_alignment_for_contiguous_layout()[0]
-    num_experts = w1.size(0)
-    device = w1.device
-
-    # This is the maximum GroupedGemm M size that we expect to run
-    # the grouped_gemm with.
-    MAX_M = compute_aligned_M(
-        env.VLLM_FUSED_MOE_CHUNK_SIZE,
-        num_topk,
-        num_experts,
-        block_m,
-        expert_tokens_meta=None,
-    )
-    # Distribute expert-ids evenly.
-    MAX_BLOCKS = MAX_M // block_m
-    expert_ids_block = torch.randint(
-        low=0, high=num_experts, size=(MAX_BLOCKS,), device=device, dtype=torch.int32
-    )
-    expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0)
-
-    def _warmup(w: torch.Tensor, w_scale: torch.Tensor):
-        _, n, k = w.size()
-        a1q = torch.empty((MAX_M, k), device=device).to(torch.float8_e4m3fn)
-        a1q_scales = torch.empty(
-            (MAX_M, k // block_m), device=device, dtype=torch.float32
-        )
-        out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16)
-
-        pbar = tqdm(
-            total=MAX_BLOCKS, desc=f"DeepGemmExperts GEMM warmup (MAX_M={MAX_M})"
-        )
-        num_tokens = MAX_M
-        while num_tokens > 0:
-            m_grouped_fp8_gemm_nt_contiguous(
-                (a1q[:num_tokens], a1q_scales[:num_tokens]),
-                (w, w_scale),
-                out[:num_tokens],
-                expert_ids[:num_tokens],
-            )
-            pbar.update(1)
-            num_tokens = num_tokens - block_m
-
-    _warmup(w1, w1_scale)
-    _warmup(w2, w2_scale)
-
-
 class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
     def __init__(self, quant_config: FusedMoEQuantConfig):
         super().__init__(quant_config)
@@ -215,11 +149,32 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         )
         assert M_sum % block_m == 0
 
-        workspace1 = (M_sum, N)
-        workspace2 = (M_sum, max(N // 2, K))
+        workspace1 = (M_sum, max(N // 2, K))
+        workspace2 = (M_sum, max(N, K))
         output = (M, K)
         return (workspace1, workspace2, output)
 
+    def _act_mul_quant(
+        self, input: torch.Tensor, output: torch.Tensor, activation: str
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if activation == "silu":
+            return silu_mul_per_token_group_quant_fp8_colmajor(
+                input=input, output=output
+            )
+        else:
+            # This is a fallback path. If we find ourselves using any activation other
+            # than silu, we should add that activation to the
+            # silu_mul_per_token_group_quant_fp8_colmajor kernel as it is much faster.
+            M_sum, N = input.size()
+            act_out = torch.empty(
+                (M_sum, N // 2), dtype=input.dtype, device=input.device
+            )
+            self.activation(activation, act_out, input)
+            assert self.block_shape is not None
+            return per_token_group_quant_fp8(
+                act_out, self.block_shape[1], column_major_scales=True, out_q=output
+            )
+
     def apply(
         self,
         output: torch.Tensor,
@@ -261,14 +216,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
             expert_tokens_meta=expert_tokens_meta,
         )
 
-        a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn), (M_sum, K))
-        mm1_out = _resize_cache(workspace13, (M_sum, N))
-        act_out = _resize_cache(workspace2, (M_sum, N // 2))
-        quant_out = _resize_cache(
-            workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, N // 2)
+        a1q_perm = _resize_cache(
+            workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, K)
         )
-        mm2_out = _resize_cache(workspace2, (M_sum, K))
 
         a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute(
             aq=a1q,
             aq_scale=a1q_scale,
@@ -280,17 +230,19 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         )
         assert a1q.size(0) == M_sum
 
+        mm1_out = _resize_cache(workspace2, (M_sum, N))
         m_grouped_fp8_gemm_nt_contiguous(
             (a1q, a1q_scale), (w1, self.w1_scale), mm1_out, expert_ids
         )
 
-        self.activation(activation, act_out, mm1_out.view(-1, N))
-
-        a2q_scale: torch.Tensor | None = None
-        a2q, a2q_scale = per_token_group_quant_fp8(
-            act_out, self.block_shape[1], column_major_scales=True, out_q=quant_out
+        quant_out = _resize_cache(
+            workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, N // 2)
         )
+        a2q, a2q_scale = self._act_mul_quant(
+            input=mm1_out.view(-1, N), output=quant_out, activation=activation
+        )
 
+        mm2_out = _resize_cache(workspace2, (M_sum, K))
         m_grouped_fp8_gemm_nt_contiguous(
             (a2q, a2q_scale), (w2, self.w2_scale), mm2_out, expert_ids
         )
@@ -492,6 +492,139 @@ def _per_token_group_quant_fp8(
    tl.store(y_s_ptr, y_s)


@triton.jit
def _silu_mul_per_token_group_quant_fp8_colmajor(
    y_ptr,  # [M, N]
    y_q_ptr,  # [M, N // 2]
    y_s_ptr,  # [M, (N // 2) // GROUP_SIZE]
    M,  # num tokens
    N,  # intermediate size
    # Stride
    y_s_col_stride: tl.int64,
    # Information for float8
    eps,
    fp8_min,
    fp8_max,
    use_ue8m0: tl.constexpr,
    # Meta-parameters
    GROUP_SIZE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # TODO(varun) : Add expert_ids so we may early-exit no-op thread blocks.
    """
    Each thread block (BLOCK_N == GROUP_SIZE) computes [BLOCK_M, GROUP_SIZE]
    act-mul outputs. The thread block then quantizes that [BLOCK_M, GROUP_SIZE]
    block of values and fills the output tensors at the right positions.
    """

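    # Grid note (see the wrapper below): the kernel is launched on a
    # (M // BLOCK_M, N_2 // BLOCK_N) grid with BLOCK_N == GROUP_SIZE, so program
    # (pid_m, pid_n) owns one [BLOCK_M, GROUP_SIZE] tile of the act-mul output.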
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    N_2 = N // 2

    m_offset = pid_m * BLOCK_M
    n_offset = pid_n * BLOCK_N
    if m_offset >= M:
        return

    offs_n = tl.arange(0, BLOCK_N).to(tl.int64)
    offs_m = tl.arange(0, BLOCK_M).to(tl.int64)

    base_y_ptr = y_ptr + m_offset * N + n_offset

    act_in_ptrs = base_y_ptr + offs_m[:, None] * N + offs_n[None, :]

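    # Input row layout: the first N_2 columns of each row feed the SiLU and the
    # second N_2 columns are the elementwise multiplicand, so each multiplicand
    # is loaded N_2 elements to the right of its SiLU input.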
    act_in = tl.load(act_in_ptrs)
    mul_in = tl.load(act_in_ptrs + N_2)

    # silu & mul
    act_in = act_in.to(tl.float32)
    one_f32 = tl.cast(1, tl.float32)
    silu_out = (act_in / (one_f32 + tl.exp(-act_in))).to(y_ptr.dtype.element_ty)
    y = (silu_out * mul_in).to(tl.float32)

    # quant
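    # One scale per (token, group): the absmax over each GROUP_SIZE-wide row of
    # the tile (floored at eps) divided by fp8_max. With use_ue8m0 the scale is
    # rounded up to the next power of two, exp2(ceil(log2(s))), so it is exactly
    # representable as an E8M0 (power-of-two) scale factor.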
    _absmax = tl.maximum(tl.max(tl.abs(y), axis=1), eps)
    scale_raw = _absmax / fp8_max
    y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw
    y_s = tl.reshape(y_s, (BLOCK_M, 1))
    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    # store y_q
    base_y_q_ptr = y_q_ptr + m_offset * N_2 + n_offset
    y_q_ptrs = base_y_q_ptr + offs_m[:, None] * N_2 + offs_n[None, :]
    tl.store(y_q_ptrs, y_q)

    # store y_s
    group_id = n_offset // GROUP_SIZE
    base_y_s_ptr = y_s_ptr + group_id * y_s_col_stride + m_offset
    y_s_ptrs = base_y_s_ptr + offs_m
    y_s = tl.reshape(y_s, (BLOCK_M,))
    tl.store(y_s_ptrs, y_s)


def silu_mul_per_token_group_quant_fp8_colmajor(
    input: torch.Tensor,  # [M, N]
    output: torch.Tensor | None = None,  # [M, N // 2]
    use_ue8m0: bool | None = None,
    eps: float = 1e-10,
):
    """
    silu+mul + block-fp8 quant with group size 128.
    """
    GROUP_SIZE = 128
    assert input.ndim == 2
    if output is not None:
        assert output.ndim == 2
    assert input.size(0) % GROUP_SIZE == 0
    assert input.size(1) % (GROUP_SIZE * 2) == 0

    if use_ue8m0 is None:
        use_ue8m0 = is_deep_gemm_e8m0_used()

    M, N = input.size()
    N_2 = N // 2

    if output is None:
        output = torch.empty((M, N_2), dtype=torch.float8_e4m3fn, device=input.device)

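    # Scales are allocated as [(N_2 // GROUP_SIZE), M] and handed back as the
    # transposed [M, N_2 // GROUP_SIZE] view, i.e. column-major with the token
    # dimension contiguous; this matches the layout the kernel writes via
    # y_s_col_stride.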
    output_scales = torch.empty(
        ((N_2 // GROUP_SIZE), M), dtype=torch.float32, device=input.device
    ).transpose(0, 1)

    BLOCK_M = 8
    BLOCK_N = GROUP_SIZE
    assert M % BLOCK_M == 0
    assert N_2 % BLOCK_N == 0

    finfo = torch.finfo(torch.float8_e4m3fn)
    fp8_min = finfo.min
    fp8_max = finfo.max

    # Force even division so we can avoid edge cases within the kernel.
    assert M % BLOCK_M == 0
    assert N_2 % BLOCK_N == 0
    grid = (M // BLOCK_M, N_2 // BLOCK_N)

    _silu_mul_per_token_group_quant_fp8_colmajor[grid](
        input,
        output,
        output_scales,
        M,
        N,
        output_scales.stride(-1),
        eps,
        fp8_min,
        fp8_max,
        use_ue8m0,
        GROUP_SIZE,
        BLOCK_M,
        BLOCK_N,
    )

    return output, output_scales


@triton.jit
def _per_token_group_quant_fp8_colmajor(
    # Pointers to inputs and output