[Performance][DP/EP] Add silu_mul_per_token_group_quant_fp8_colmajor kernel (#29470)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
parent dd5d1ef780
commit 19bee6d12d
committed by GitHub
benchmarks/kernels/benchmark_2d_silu_mul_fp8_quant.py (new file, 244 lines)
@@ -0,0 +1,244 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from enum import Enum
from itertools import product
from typing import Any

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    _per_token_group_quant_fp8_colmajor,
    silu_mul_per_token_group_quant_fp8_colmajor,
)
from vllm.triton_utils import triton
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used

from .utils import ArgPool, Bench, CudaGraphBenchParams

GROUP_SIZE = 128
FLOAT8_T = torch.float8_e4m3fn


def print_timers(timers: list[TMeasurement], cuda_graph_nops: int):
    print(
        f"Note : The timings reported above are for {cuda_graph_nops} "
        "consecutive invocations of the benchmarking functions. "
        f"Please divide by {cuda_graph_nops} for single invocation "
        "timings."
    )
    compare = TBenchmark.Compare(timers)
    compare.print()


class ImplType(Enum):
    SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR = 1
    REFERENCE = 2

    def get_impl(self):
        if self == ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR:
            return silu_mul_per_token_group_quant_fp8_colmajor
        elif self == ImplType.REFERENCE:
            return reference
        raise ValueError(f"Unrecognized ImplType {self}")


@dataclass
class BenchmarkTensors:
    input: torch.Tensor
    output: torch.Tensor

    # Reference act output tensor
    ref_act_out: torch.Tensor
    ref_quant_out: torch.Tensor

    @staticmethod
    def make(T: int, N: int) -> "BenchmarkTensors":
        assert T % GROUP_SIZE == 0
        assert N % (GROUP_SIZE * 2) == 0

        input = torch.rand((T, N), dtype=torch.bfloat16, device="cuda")

        # silu_mul_per_token_group_quant_fp8_colmajor output.
        output = torch.rand((T, N // 2), dtype=torch.bfloat16, device="cuda").to(
            FLOAT8_T
        )

        # reference output.
        ref_act_out = torch.empty((T, N // 2), dtype=torch.bfloat16, device="cuda")
        ref_quant_out = torch.empty(
            (T, N // 2), dtype=torch.bfloat16, device="cuda"
        ).to(FLOAT8_T)

        return BenchmarkTensors(
            input=input,
            output=output,
            ref_act_out=ref_act_out,
            ref_quant_out=ref_quant_out,
        )

    @property
    def T(self):
        return self.input.size(0)

    @property
    def N(self):
        return self.input.size(1)

    def make_impl_kwargs(self, impl_type: ImplType) -> dict[str, Any]:
        if impl_type == ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR:
            return {
                "input": self.input,
                "output": self.output,
                "use_ue8m0": is_deep_gemm_e8m0_used(),
            }
        elif impl_type == ImplType.REFERENCE:
            return {
                "input": self.input,
                "act_out": self.ref_act_out,
                "quant_out": self.ref_quant_out,
                "use_ue8m0": is_deep_gemm_e8m0_used(),
            }
        raise ValueError(f"Unrecognized impl_type {impl_type}")


def reference_quant(x: torch.Tensor, quant_out: torch.Tensor, use_ue8m0: bool):
    """
    Reference triton quant kernel from
    vllm.model_executor.layers.quantization.utils.fp8_utils
    """
    assert quant_out.size() == x.size()
    # Allocate the scale tensor in column-major format.
    shape = (x.shape[-1] // GROUP_SIZE,) + x.shape[:-1]
    x_q = quant_out
    x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)

    M = x.numel() // GROUP_SIZE
    N = GROUP_SIZE
    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1

    finfo = torch.finfo(FLOAT8_T)
    fp8_min = finfo.min
    fp8_max = finfo.max

    _per_token_group_quant_fp8_colmajor[(M,)](
        x,
        x_q,
        x_s,
        GROUP_SIZE,
        x.shape[1],
        x.stride(0),
        x_s.stride(1),
        eps=1e-10,
        fp8_min=fp8_min,
        fp8_max=fp8_max,
        use_ue8m0=use_ue8m0,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return x_q, x_s


def reference(
    input: torch.Tensor,
    act_out: torch.Tensor,
    quant_out: torch.Tensor,
    use_ue8m0: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    torch.ops._C.silu_and_mul(act_out, input)
    return reference_quant(act_out, quant_out, use_ue8m0)


def bench_impl(
    bench_tensors: list[BenchmarkTensors], impl_type: ImplType
) -> TMeasurement:
    T = bench_tensors[0].T
    N = bench_tensors[0].N

    arg_pool_size = len(bench_tensors)
    kwargs_list = [bt.make_impl_kwargs(impl_type) for bt in bench_tensors]

    # warmup
    for kwargs in kwargs_list:
        impl_type.get_impl()(**kwargs)
    torch.cuda.synchronize()

    # Merge into a single kwargs and qualify arguments as ArgPool
    kwargs = {k: ArgPool([]) for k in kwargs_list[0]}
    for _kwargs in kwargs_list:
        for k, v in _kwargs.items():
            kwargs[k].values.append(v)

    cuda_graph_params = None
    cuda_graph_params = CudaGraphBenchParams(arg_pool_size)
    timer = None
    with Bench(
        cuda_graph_params,
        "silu-mul-quant",
        f"num_tokens={T}, N={N}",
        impl_type.name,
        impl_type.get_impl(),
        **kwargs,
    ) as bench:
        timer = bench.run()
    return timer


def test_correctness(T: int, N: int):
    print(f"Testing num_tokens={T}, N={N} ...")

    bench_tensor = BenchmarkTensors.make(T, N)

    def output_from_impl(impl: ImplType) -> tuple[torch.Tensor, torch.Tensor]:
        return impl.get_impl()(**bench_tensor.make_impl_kwargs(impl))

    # reference output
    ref_out_q, ref_out_s = output_from_impl(ImplType.REFERENCE)

    # test output
    out_q, out_s = output_from_impl(
        ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR
    )

    torch.testing.assert_close(ref_out_q.to(torch.float32), out_q.to(torch.float32))
    torch.testing.assert_close(ref_out_s, out_s)


def run(Ts: list[int], Ns: list[int], arg_pool_size: int) -> list[TMeasurement]:
    timers = []
    for N, T in product(Ns, Ts):
        test_correctness(T, N)

        bench_tensors: list[BenchmarkTensors] = [
            BenchmarkTensors.make(T, N) for _ in range(arg_pool_size)
        ]

        silu_mul_quant_timer = bench_impl(
            bench_tensors, ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR
        )
        timers.append(silu_mul_quant_timer)
        reference_timer = bench_impl(bench_tensors, ImplType.REFERENCE)
        timers.append(reference_timer)

        print_timers(
            [silu_mul_quant_timer, reference_timer], cuda_graph_nops=arg_pool_size
        )

    print_timers(timers, cuda_graph_nops=arg_pool_size)

    return timers


if __name__ == "__main__":
    T = [128 * i for i in range(1, 16)] + [2048 * i for i in range(1, 65)]
    N = [2048, 4096, 8192]

    print(f"T = {T}, N = {N}")
    run(T, N, arg_pool_size=8)
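As a quick smoke test of the driver above, a reduced sweep can be run before launching the full shape grid. This is a minimal sketch: it assumes a CUDA device and that the script is launched so the relative `from .utils import ...` resolves (for example, run as a module from the repository root); the shape values are illustrative only.

# Hypothetical reduced sweep reusing run()/print_timers() defined above.
# Shapes respect the kernel's divisibility requirements:
#   T is a multiple of GROUP_SIZE (128), N is a multiple of 2 * GROUP_SIZE (256).
small_T = [128, 1024]
small_N = [2048]
timers = run(small_T, small_N, arg_pool_size=4)
print_timers(timers, cuda_graph_nops=4)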
@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    _per_token_group_quant_fp8_colmajor,
    silu_mul_per_token_group_quant_fp8_colmajor,
)
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used

FLOAT8_DTYPE = torch.float8_e4m3fn
GROUP_SIZE = 128


def reference_quant(x: torch.Tensor, use_ue8m0: bool):
    """
    Reference triton quant kernel from
    vllm.model_executor.layers.quantization.utils.fp8_utils
    """

    x_q = torch.empty_like(x, device=x.device, dtype=FLOAT8_DTYPE)

    # Allocate the scale tensor in column-major format.
    shape = (x.shape[-1] // GROUP_SIZE,) + x.shape[:-1]
    x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)

    M = x.numel() // GROUP_SIZE
    N = GROUP_SIZE
    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1

    finfo = torch.finfo(FLOAT8_DTYPE)
    fp8_min = finfo.min
    fp8_max = finfo.max

    _per_token_group_quant_fp8_colmajor[(M,)](
        x,
        x_q,
        x_s,
        GROUP_SIZE,
        x.shape[1],
        x.stride(0),
        x_s.stride(1),
        eps=1e-10,
        fp8_min=fp8_min,
        fp8_max=fp8_max,
        use_ue8m0=use_ue8m0,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return x_q, x_s


def reference(x: torch.Tensor, use_ue8m0: bool) -> tuple[torch.Tensor, torch.Tensor]:
    T, N = x.size()
    ref_act_out = torch.empty((T, N // 2), dtype=torch.bfloat16, device="cuda")
    torch.ops._C.silu_and_mul(ref_act_out, x)
    return reference_quant(ref_act_out, use_ue8m0)


@pytest.mark.parametrize("T", [128, 256, 512])
@pytest.mark.parametrize("N", [128 * 2, 256 * 2, 768 * 2, 2048 * 2, 7168 * 2])
def test_silu_mul_fp8_quant_deep_gemm(T: int, N: int):
    current_platform.seed_everything(42)

    input = torch.rand((T, N), dtype=torch.bfloat16, device="cuda")

    use_ue8m0 = is_deep_gemm_e8m0_used()

    # Test
    output, output_scales = silu_mul_per_token_group_quant_fp8_colmajor(
        input, use_ue8m0=use_ue8m0
    )

    # Reference
    ref_output, ref_output_scales = reference(input, use_ue8m0)

    torch.testing.assert_close(output.to(torch.float32), ref_output.to(torch.float32))
    torch.testing.assert_close(output_scales, ref_output_scales)
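Beyond the elementwise comparison above, the contract of the fused op can be sanity-checked directly. The sketch below assumes a CUDA device with float8_e4m3fn support; the concrete sizes are arbitrary. It illustrates the output shapes and the column-major scale layout produced by the wrapper (token index contiguous in memory).

import torch

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    silu_mul_per_token_group_quant_fp8_colmajor,
)

T, N, GROUP = 256, 4096, 128
x = torch.rand((T, N), dtype=torch.bfloat16, device="cuda")
q, s = silu_mul_per_token_group_quant_fp8_colmajor(x, use_ue8m0=False)

assert q.shape == (T, N // 2) and q.dtype == torch.float8_e4m3fn
assert s.shape == (T, (N // 2) // GROUP) and s.dtype == torch.float32
# Scales are allocated as ((N // 2) // GROUP, T) and transposed, so the token
# dimension is the contiguous one (the column-major layout DeepGEMM expects).
assert s.stride() == (1, T)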
@@ -2,9 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import torch
-from tqdm import tqdm
 
-import vllm.envs as env
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import (
@@ -25,12 +23,12 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8,
+    silu_mul_per_token_group_quant_fp8_colmajor,
 )
 from vllm.utils.deep_gemm import (
     get_mk_alignment_for_contiguous_layout,
     m_grouped_fp8_gemm_nt_contiguous,
 )
-from vllm.utils.func_utils import run_once
 from vllm.utils.import_utils import has_deep_gemm
 
 logger = init_logger(__name__)
@@ -108,70 +106,6 @@ def _valid_deep_gemm(
     return True
 
 
-@run_once
-def warmup_deepgemm_gg_contiguous_kernels(
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    w1_scale: torch.Tensor,
-    w2_scale: torch.Tensor,
-    num_topk: int,
-):
-    """
-    DeepGemm JITs the grouped-gemm kernels. The JIT'ing happens based on the
-    input tensor shapes. In this function, we construct all possible input
-    tensor shapes so all the kernels are JIT'ed and cached.
-    Note that this warmup is expected to happen during the model profile
-    call and not during actual model inference.
-    """
-
-    assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts"
-
-    block_m = get_mk_alignment_for_contiguous_layout()[0]
-    num_experts = w1.size(0)
-    device = w1.device
-
-    # This is the maximum GroupedGemm M size that we expect to run
-    # the grouped_gemm with.
-    MAX_M = compute_aligned_M(
-        env.VLLM_FUSED_MOE_CHUNK_SIZE,
-        num_topk,
-        num_experts,
-        block_m,
-        expert_tokens_meta=None,
-    )
-    # Distribute expert-ids evenly.
-    MAX_BLOCKS = MAX_M // block_m
-    expert_ids_block = torch.randint(
-        low=0, high=num_experts, size=(MAX_BLOCKS,), device=device, dtype=torch.int32
-    )
-    expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0)
-
-    def _warmup(w: torch.Tensor, w_scale: torch.Tensor):
-        _, n, k = w.size()
-        a1q = torch.empty((MAX_M, k), device=device).to(torch.float8_e4m3fn)
-        a1q_scales = torch.empty(
-            (MAX_M, k // block_m), device=device, dtype=torch.float32
-        )
-        out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16)
-
-        pbar = tqdm(
-            total=MAX_BLOCKS, desc=f"DeepGemmExperts GEMM warmup (MAX_M={MAX_M})"
-        )
-        num_tokens = MAX_M
-        while num_tokens > 0:
-            m_grouped_fp8_gemm_nt_contiguous(
-                (a1q[:num_tokens], a1q_scales[:num_tokens]),
-                (w, w_scale),
-                out[:num_tokens],
-                expert_ids[:num_tokens],
-            )
-            pbar.update(1)
-            num_tokens = num_tokens - block_m
-
-    _warmup(w1, w1_scale)
-    _warmup(w2, w2_scale)
-
-
 class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
     def __init__(self, quant_config: FusedMoEQuantConfig):
         super().__init__(quant_config)
@@ -215,11 +149,32 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         )
         assert M_sum % block_m == 0
 
-        workspace1 = (M_sum, N)
-        workspace2 = (M_sum, max(N // 2, K))
+        workspace1 = (M_sum, max(N // 2, K))
+        workspace2 = (M_sum, max(N, K))
         output = (M, K)
         return (workspace1, workspace2, output)
 
+    def _act_mul_quant(
+        self, input: torch.Tensor, output: torch.Tensor, activation: str
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if activation == "silu":
+            return silu_mul_per_token_group_quant_fp8_colmajor(
+                input=input, output=output
+            )
+        else:
+            # This is a fallback path. If we find ourselves using any activation other
+            # than silu, we should add that activation to
+            # silu_mul_per_token_group_quant_fp8_colmajor kernel as it is much faster.
+            M_sum, N = input.size()
+            act_out = torch.empty(
+                (M_sum, N // 2), dtype=input.dtype, device=input.device
+            )
+            self.activation(activation, act_out, input)
+            assert self.block_shape is not None
+            return per_token_group_quant_fp8(
+                act_out, self.block_shape[1], column_major_scales=True, out_q=output
+            )
+
     def apply(
         self,
         output: torch.Tensor,
@@ -261,14 +216,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
             expert_tokens_meta=expert_tokens_meta,
         )
 
-        a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn), (M_sum, K))
-        mm1_out = _resize_cache(workspace13, (M_sum, N))
-        act_out = _resize_cache(workspace2, (M_sum, N // 2))
-        quant_out = _resize_cache(
-            workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, N // 2)
+        a1q_perm = _resize_cache(
+            workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, K)
         )
-        mm2_out = _resize_cache(workspace2, (M_sum, K))
-
         a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute(
             aq=a1q,
             aq_scale=a1q_scale,
@@ -280,17 +230,19 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         )
         assert a1q.size(0) == M_sum
 
+        mm1_out = _resize_cache(workspace2, (M_sum, N))
         m_grouped_fp8_gemm_nt_contiguous(
             (a1q, a1q_scale), (w1, self.w1_scale), mm1_out, expert_ids
         )
 
-        self.activation(activation, act_out, mm1_out.view(-1, N))
-
-        a2q_scale: torch.Tensor | None = None
-        a2q, a2q_scale = per_token_group_quant_fp8(
-            act_out, self.block_shape[1], column_major_scales=True, out_q=quant_out
+        quant_out = _resize_cache(
+            workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, N // 2)
+        )
+        a2q, a2q_scale = self._act_mul_quant(
+            input=mm1_out.view(-1, N), output=quant_out, activation=activation
         )
 
+        mm2_out = _resize_cache(workspace2, (M_sum, K))
         m_grouped_fp8_gemm_nt_contiguous(
             (a2q, a2q_scale), (w2, self.w2_scale), mm2_out, expert_ids
         )
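The workspace changes above reorder how apply() carves its two scratch buffers: the fp8 tensors (permuted activations and the fused silu-mul-quant output) now live in workspace13, while the bf16 GEMM outputs live in workspace2, so each step reads one buffer and writes the other. The following is a minimal sketch of that aliasing pattern, with illustrative sizes and simple slicing standing in for the actual _resize_cache helper; it assumes a CUDA device.

import torch

M_sum, N, K = 512, 4096, 2048  # illustrative, already block-aligned

# Shapes as returned by the updated workspace_shapes().
workspace13 = torch.empty((M_sum, max(N // 2, K)), dtype=torch.bfloat16, device="cuda")
workspace2 = torch.empty((M_sum, max(N, K)), dtype=torch.bfloat16, device="cuda")

# 1) permuted fp8 activations: carved out of workspace13 (read by GEMM 1)
a1q_perm = workspace13.view(torch.float8_e4m3fn)[:, :K]
# 2) GEMM 1 output (bf16): written into workspace2
mm1_out = workspace2[:, :N]
# 3) fused silu-mul + fp8 quant: reads workspace2, writes back into workspace13
quant_out = workspace13.view(torch.float8_e4m3fn)[:, : N // 2]
# 4) GEMM 2 output (bf16): written into workspace2 again
mm2_out = workspace2[:, :K]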
@@ -492,6 +492,139 @@ def _per_token_group_quant_fp8(
     tl.store(y_s_ptr, y_s)
 
 
+@triton.jit
+def _silu_mul_per_token_group_quant_fp8_colmajor(
+    y_ptr,  # [M, N]
+    y_q_ptr,  # [M, N // 2]
+    y_s_ptr,  # [M, (N // 2) // GROUP_SIZE]
+    M,  # num tokens
+    N,  # intermediate size
+    # Stride
+    y_s_col_stride: tl.int64,
+    # Information for float8
+    eps,
+    fp8_min,
+    fp8_max,
+    use_ue8m0: tl.constexpr,
+    # Meta-parameters
+    GROUP_SIZE: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    # TODO(varun) : Add expert_ids so we may early-exit no-op thread blocks.
+    """
+    Each thread block (BLOCK_N) computes [BLOCK_M, GROUP_SIZE] act-mul outputs. Then
+    the thread block quantizes the [BLOCK_M, GROUP_SIZE] block of values and fills
+    the output tensors at the right positions.
+    """
+
+    pid_m = tl.program_id(0)
+    pid_n = tl.program_id(1)
+    N_2 = N // 2
+
+    m_offset = pid_m * BLOCK_M
+    n_offset = pid_n * BLOCK_N
+    if m_offset >= M:
+        return
+
+    offs_n = tl.arange(0, BLOCK_N).to(tl.int64)
+    offs_m = tl.arange(0, BLOCK_M).to(tl.int64)
+
+    base_y_ptr = y_ptr + m_offset * N + n_offset
+
+    act_in_ptrs = base_y_ptr + offs_m[:, None] * N + offs_n[None, :]
+
+    act_in = tl.load(act_in_ptrs)
+    mul_in = tl.load(act_in_ptrs + N_2)
+
+    # silu & mul
+    act_in = act_in.to(tl.float32)
+    one_f32 = tl.cast(1, tl.float32)
+    silu_out = (act_in / (one_f32 + tl.exp(-act_in))).to(y_ptr.dtype.element_ty)
+    y = (silu_out * mul_in).to(tl.float32)
+
+    # quant
+    _absmax = tl.maximum(tl.max(tl.abs(y), axis=1), eps)
+    scale_raw = _absmax / fp8_max
+    y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw
+    y_s = tl.reshape(y_s, (BLOCK_M, 1))
+    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+    # store y_q
+    base_y_q_ptr = y_q_ptr + m_offset * N_2 + n_offset
+    y_q_ptrs = base_y_q_ptr + offs_m[:, None] * N_2 + offs_n[None, :]
+    tl.store(y_q_ptrs, y_q)
+
+    # store y_s
+    group_id = n_offset // GROUP_SIZE
+    base_y_s_ptr = y_s_ptr + group_id * y_s_col_stride + m_offset
+    y_s_ptrs = base_y_s_ptr + offs_m
+    y_s = tl.reshape(y_s, (BLOCK_M,))
+    tl.store(y_s_ptrs, y_s)
+
+
+def silu_mul_per_token_group_quant_fp8_colmajor(
+    input: torch.Tensor,  # [M, N]
+    output: torch.Tensor | None = None,  # [M, N // 2]
+    use_ue8m0: bool | None = None,
+    eps: float = 1e-10,
+):
+    """
+    silu+mul + block-fp8 quant with group size 128.
+    """
+    GROUP_SIZE = 128
+    assert input.ndim == 2
+    if output is not None:
+        assert output.ndim == 2
+    assert input.size(0) % GROUP_SIZE == 0
+    assert input.size(1) % (GROUP_SIZE * 2) == 0
+
+    if use_ue8m0 is None:
+        use_ue8m0 = is_deep_gemm_e8m0_used()
+
+    M, N = input.size()
+    N_2 = N // 2
+
+    if output is None:
+        output = torch.empty((M, N_2), dtype=torch.float8_e4m3fn, device=input.device)
+
+    output_scales = torch.empty(
+        ((N_2 // GROUP_SIZE), M), dtype=torch.float32, device=input.device
+    ).transpose(0, 1)
+
+    BLOCK_M = 8
+    BLOCK_N = GROUP_SIZE
+    assert M % BLOCK_M == 0
+    assert N_2 % BLOCK_N == 0
+
+    finfo = torch.finfo(torch.float8_e4m3fn)
+    fp8_min = finfo.min
+    fp8_max = finfo.max
+
+    # Force even division so we can avoid edgecases within the kernel.
+    assert M % BLOCK_M == 0
+    assert N_2 % BLOCK_N == 0
+    grid = (M // BLOCK_M, N_2 // BLOCK_N)
+
+    _silu_mul_per_token_group_quant_fp8_colmajor[grid](
+        input,
+        output,
+        output_scales,
+        M,
+        N,
+        output_scales.stride(-1),
+        eps,
+        fp8_min,
+        fp8_max,
+        use_ue8m0,
+        GROUP_SIZE,
+        BLOCK_M,
+        BLOCK_N,
+    )
+
+    return output, output_scales
+
+
 @triton.jit
 def _per_token_group_quant_fp8_colmajor(
     # Pointers to inputs and output
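For readers who want the kernel's math without Triton, the following is an unfused PyTorch sketch of the transformation the kernel applies: silu on the first half of the features, multiply by the second half, then a per-token, per-128-column absmax quantization to float8_e4m3fn. It is an illustration only; the scales are returned row-major here rather than in the column-major layout the real wrapper produces, and intermediate rounding differs slightly from the fused kernel.

import torch

GROUP_SIZE = 128
FP8 = torch.float8_e4m3fn


def silu_mul_group_quant_reference(x: torch.Tensor, use_ue8m0: bool = False):
    """Unfused sketch: x is [T, N]; returns (q [T, N // 2] fp8, scales [T, (N // 2) // 128])."""
    gate, up = x.float().chunk(2, dim=-1)      # first half gates the second half
    y = torch.nn.functional.silu(gate) * up    # [T, N // 2]

    T, half = y.shape
    groups = y.view(T, half // GROUP_SIZE, GROUP_SIZE)
    finfo = torch.finfo(FP8)
    absmax = groups.abs().amax(dim=-1).clamp(min=1e-10)
    scales = absmax / finfo.max
    if use_ue8m0:
        # Optional power-of-two (UE8M0) scale rounding, as in the fused kernel.
        scales = torch.exp2(torch.ceil(torch.log2(scales)))
    q = (groups / scales.unsqueeze(-1)).clamp(finfo.min, finfo.max)
    return q.view(T, half).to(FP8), scales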