2146 lines
72 KiB
Python
2146 lines
72 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
"""Fused MoE Triton kernels."""
|
|
|
|
import functools
|
|
import json
|
|
import os
|
|
from collections.abc import Callable
|
|
from typing import Any
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
import vllm.envs as envs
|
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
|
from vllm import _custom_ops as ops
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.layers.batch_invariant import (
|
|
vllm_is_batch_invariant,
|
|
)
|
|
from vllm.model_executor.layers.fused_moe.config import (
|
|
FUSED_MOE_UNQUANTIZED_CONFIG,
|
|
FusedMoEQuantConfig,
|
|
_get_config_dtype_str,
|
|
)
|
|
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
|
_valid_cutlass_block_scaled_grouped_gemm,
|
|
run_cutlass_block_scaled_fused_experts,
|
|
)
|
|
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
|
_valid_deep_gemm,
|
|
deep_gemm_moe_fp8,
|
|
)
|
|
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
|
moe_align_block_size,
|
|
)
|
|
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
|
MoEPrepareAndFinalizeNoEP,
|
|
)
|
|
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
|
TopKWeightAndReduceNoOP,
|
|
)
|
|
from vllm.model_executor.layers.fused_moe.utils import (
|
|
_resize_cache,
|
|
activation_without_mul,
|
|
disable_inplace,
|
|
moe_kernel_quantize_input,
|
|
)
|
|
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
|
|
from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6
|
|
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme
|
|
from vllm.platforms import current_platform
|
|
from vllm.triton_utils import tl, triton
|
|
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
|
|
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
|
|
|
|
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
@triton.jit
def write_zeros_to_output(
    c_ptr,
    stride_cm,
    stride_cn,
    pid_n,
    N,
    offs_token,
    token_mask,
    BLOCK_SIZE_M,
    BLOCK_SIZE_N,
    compute_type,
):
    """Store a zero tile into C for this program's token rows and N-block.

    The MoE kernels in this file call it for blocks whose expert id is -1,
    so those output rows are explicitly zeroed rather than left untouched.
    Rows are masked by ``token_mask`` and columns by ``< N``.
    """
    col_idx = tl.arange(0, BLOCK_SIZE_N) + pid_n * BLOCK_SIZE_N
    out_ptrs = c_ptr + offs_token[:, None] * stride_cm + col_idx[None, :] * stride_cn
    out_mask = token_mask[:, None] & (col_idx[None, :] < N)
    zero_tile = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
    tl.store(out_ptrs, zero_tile, mask=out_mask)
|
|
|
|
|
|
@triton.jit
def fused_moe_kernel_gptq_awq(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    b_scale_ptr,
    b_zp_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N: tl.constexpr,
    K: tl.constexpr,
    EM,
    num_valid_tokens,
    # The stride variables represent how much to increase the ptr by when
    # moving by 1 element in a particular dimension. E.g. `stride_am` is
    # how much to increase `a_ptr` by to get the element one row down
    # (A has M rows).
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bse,
    stride_bsk,
    stride_bsn,
    stride_bze,
    stride_bzk,
    stride_bzn,
    # True when K is a multiple of BLOCK_SIZE_K; the scale / zero-point
    # loads in the K loop are then issued without a K mask.
    block_k_diviable: tl.constexpr,
    # Number of consecutive K elements that share one scale / zero point.
    group_size: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    has_zp: tl.constexpr,
    use_int4_w4a16: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
):
    """
    Implements the fused computation for a Mixture of Experts (MOE) using
    token and expert matrices, for group-quantized integer weights
    (int4 or int8) that are dequantized on the fly inside the K loop.

    Key Parameters:
    - A: The input tensor representing tokens with shape (*, K), where '*' can
        be any shape representing batches and K is the feature dimension of
        each token.
    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
        the number of experts, K is the input feature dimension, and N is
        the output feature dimension. With `use_int4_w4a16`, two 4-bit
        values are packed per stored element along K (even k in the low
        nibble, odd k in the high nibble).
    - b_scale / b_zp: per-(expert, K-group, N) scales and optional zero
        points; the K-group index is `k // group_size`. 4-bit zero points
        are packed two per stored element along N. When `has_zp` is False a
        constant offset (8 for int4, 128 for int8) is subtracted instead.
    - C: The output cache tensor with shape (M, topk, N), where M is the
        total number of tokens post padding, topk is the number of times
        each token is repeated, and N is the output feature dimension.
    - sorted_token_ids: A tensor containing the sorted indices of tokens,
        repeated topk times and arranged by the expert index they are
        assigned to.
    - expert_ids: A tensor containing the indices of the expert for each
        block. It determines which expert matrix from B should be used for
        each block in A. A value of -1 makes the block write zeros.
    This kernel performs the multiplication of a token by its corresponding
    expert matrix as determined by `expert_ids`. The sorting of
    `sorted_token_ids` by expert index and padding ensures divisibility by
    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
    multiplication across different blocks processed by the same expert.
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    # M-blocks that start past the padded token count have no work.
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    # Ids >= num_valid_tokens are masked out of every load/store below
    # (presumably padding entries from the alignment step - confirm).
    token_mask = offs_token < num_valid_tokens

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    if off_experts == -1:
        # -----------------------------------------------------------
        # Write back zeros to the output when the expert is not
        # in the current expert parallel rank.
        write_zeros_to_output(
            c_ptr,
            stride_cm,
            stride_cn,
            pid_n,
            N,
            offs_token,
            token_mask,
            BLOCK_SIZE_M,
            BLOCK_SIZE_N,
            compute_type,
        )
        return

    # `% N` wraps out-of-range columns of the last N-block back into range so
    # the B loads stay in bounds; the final store is masked on `< N` instead.
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # offs_token enumerates (token, top_k) pairs; `// top_k` gives the A row.
    a_ptrs = a_ptr + (
        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
    )

    if use_int4_w4a16:
        # Packed int4: logical row k lives in stored row k // 2 and is
        # extracted with a shift of 0 (even k) or 4 (odd k).
        b_ptrs = (
            b_ptr
            + off_experts * stride_be
            + (offs_k[:, None] // 2) * stride_bk
            + offs_bn[None, :] * stride_bn
        )
        b_shifter = (offs_k[:, None] % 2) * 4
    elif use_int8_w8a16:
        b_ptrs = (
            b_ptr
            + off_experts * stride_be
            + offs_k[:, None] * stride_bk
            + offs_bn[None, :] * stride_bn
        )

    # Without a zero-point tensor a constant offset is subtracted from the
    # raw weights instead: 8 for 4-bit, 128 for 8-bit.
    if not has_zp and use_int4_w4a16:
        b_zp_num = 8
    if not has_zp and use_int8_w8a16:
        b_zp_num = 128
    elif has_zp and use_int4_w4a16:
        # 4-bit zero points are packed two per stored element along N.
        b_zp_shifter = (offs_bn[None, :] % 2) * 4

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the
        # K dimension.

        # Mask for the per-group scale / zero-point loads; only needed when
        # the last K block is partial.
        if not block_k_diviable:
            k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
            k_other = 0.0
        else:
            k_mask = None
            k_other = None

        a = tl.load(
            a_ptrs,
            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        # B is loaded unmasked; for a partial last K block the matching A
        # entries are 0 and the masked scale loads return 0.0.
        b = tl.load(b_ptrs)
        if use_int4_w4a16:
            b = (b >> b_shifter) & 0xF

        # Scale for K-group (global k index) // group_size.
        b_scale_ptrs = (
            b_scale_ptr
            + off_experts * stride_bse
            + offs_bn[None, :] * stride_bsn
            + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
        )
        b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
        b_scale = b_scale.to(tl.float32)

        if has_zp and use_int4_w4a16:
            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
            b_zp_ptrs = (
                b_zp_ptr
                + off_experts * stride_bze
                + (offs_bn[None, :] // 2) * stride_bzn
                + offs_k_true * stride_bzk
            )
            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
            b_zp = (b_zp >> b_zp_shifter) & 0xF
            b_zp = b_zp.to(tl.float32)
        elif has_zp and use_int8_w8a16:
            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
            b_zp_ptrs = (
                b_zp_ptr
                + off_experts * stride_bze
                + offs_bn[None, :] * stride_bzn
                + offs_k_true * stride_bzk
            )
            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
            b_zp = b_zp.to(tl.float32)

        # We accumulate along the K dimension.
        # Dequantize: (raw - zero_point) * scale, then cast for the dot.
        if has_zp:
            b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
        else:
            b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
        accumulator = tl.dot(a, b, acc=accumulator)

        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        # Packed int4 covers BLOCK_SIZE_K logical rows in half as many
        # stored rows.
        if use_int4_w4a16:
            b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
        else:
            b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]

    accumulator = accumulator.to(compute_type)
    # -----------------------------------------------------------
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)
|
|
|
|
|
|
@triton.jit
def fused_moe_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    b_bias_ptr,
    a_scale_ptr,
    b_scale_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N,
    K,
    EM,
    num_valid_tokens,
    # The stride variables represent how much to increase the ptr by when
    # moving by 1 element in a particular dimension. E.g. `stride_am` is
    # how much to increase `a_ptr` by to get the element one row down
    # (A has M rows).
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_asm,
    stride_ask,
    stride_bse,
    stride_bsk,
    stride_bsn,
    stride_bbe,  # bias expert stride
    stride_bbn,  # bias N stride
    # Block size for block-wise quantization
    group_n: tl.constexpr,
    group_k: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    use_int8_w8a8: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
    per_channel_quant: tl.constexpr,
    HAS_BIAS: tl.constexpr,
):
    """
    Implements the fused computation for a Mixture of Experts (MOE) using
    token and expert matrices.

    Key Parameters:
    - A: The input tensor representing tokens with shape (*, K), where '*' can
        be any shape representing batches and K is the feature dimension of
        each token.
    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
        the number of experts, K is the input feature dimension, and N is
        the output feature dimension.
    - C: The output cache tensor with shape (M, topk, N), where M is the
        total number of tokens post padding, topk is the number of times
        each token is repeated, and N is the output feature dimension.
    - sorted_token_ids: A tensor containing the sorted indices of tokens,
        repeated topk times and arranged by the expert index they are
        assigned to.
    - expert_ids: A tensor containing the indices of the expert for each
        block. It determines which expert matrix from B should be used for
        each block in A. A value of -1 makes the block write zeros.
    This kernel performs the multiplication of a token by its corresponding
    expert matrix as determined by `expert_ids`. The sorting of
    `sorted_token_ids` by expert index and padding ensures divisibility by
    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
    multiplication across different blocks processed by the same expert.

    Quantization modes:
    - use_int8_w8a16: per-output-column weight scale, applied after the loop.
    - fp8/int8 w8a8 block-wise (group_k > 0 and group_n > 0): per-token,
      per-K-group activation scales and per-(N-group, K-group) weight scales,
      applied inside the K loop.
    - fp8/int8 w8a8 per-channel: per-column weight scale and per-token
      activation scale, applied after the loop.
    - fp8/int8 w8a8 tensor-wise: scalar activation scale and per-expert
      scalar weight scale, applied after the loop.

    Epilogue order: dequantize -> add bias -> multiply by routed weight ->
    cast to compute_type. The bias is added after dequantization so it is
    never multiplied by quantization scales.
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    # M-blocks that start past the padded token count have no work.
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    token_mask = offs_token < num_valid_tokens

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    if off_experts == -1:
        # -----------------------------------------------------------
        # Write back zeros to the output when the expert is not
        # in the current expert parallel rank.
        write_zeros_to_output(
            c_ptr,
            stride_cm,
            stride_cn,
            pid_n,
            N,
            offs_token,
            token_mask,
            BLOCK_SIZE_M,
            BLOCK_SIZE_N,
            compute_type,
        )
        return

    # `% N` keeps B column indices in bounds for the last N-block; the final
    # store is masked on `< N` separately.
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # offs_token enumerates (token, top_k) pairs; `// top_k` gives the A row.
    a_ptrs = a_ptr + (
        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
    )

    b_ptrs = (
        b_ptr
        + off_experts * stride_be
        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    )
    if use_int8_w8a16:
        b_scale_ptrs = (
            b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
        )
        b_scale = tl.load(b_scale_ptrs)

    if use_fp8_w8a8 or use_int8_w8a8:
        # block-wise
        if group_k > 0 and group_n > 0:
            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
            offs_bsn = offs_bn // group_n
            b_scale_ptrs = (
                b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
            )
        # channel-wise
        elif per_channel_quant:
            b_scale_ptrs = (
                b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
            )
            b_scale = tl.load(b_scale_ptrs)
            # Load per-token scale for activations
            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
            a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
        # tensor-wise
        else:
            a_scale = tl.load(a_scale_ptr)
            b_scale = tl.load(b_scale_ptr + off_experts)
    if HAS_BIAS:
        # bias shape: [num_experts, N]
        bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn
        bias = tl.load(bias_ptrs, mask=(offs_bn < N), other=0.0)
    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to compute_type after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the
        # K dimension.
        a = tl.load(
            a_ptrs,
            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        if use_int8_w8a16:
            accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
        elif use_fp8_w8a8 or use_int8_w8a8:
            if group_k > 0 and group_n > 0:
                # Block-wise scales depend on the K-group, so they are
                # applied per iteration.
                k_start = k * BLOCK_SIZE_K
                offs_ks = k_start // group_k
                a_scale = tl.load(
                    a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
                )
                b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)

                accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
            else:
                if use_fp8_w8a8:
                    # acc used to enable fp8_fast_accum
                    accumulator = tl.dot(a, b, acc=accumulator)
                else:
                    accumulator += tl.dot(a, b)
        else:
            accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Dequantize first. The block-wise path already applied its scales inside
    # the K loop; every other quantized path applies them here, before the
    # bias, so the (unquantized) bias is not multiplied by the scales.
    if use_int8_w8a16:
        accumulator = accumulator * b_scale
    elif use_fp8_w8a8 or use_int8_w8a8:
        if group_k <= 0 or group_n <= 0:
            accumulator = accumulator * a_scale * b_scale
    if HAS_BIAS:
        accumulator = accumulator + bias[None, :]
    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]
    accumulator = accumulator.to(compute_type)

    # -----------------------------------------------------------
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)
|
|
|
|
|
|
def invoke_fused_moe_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: torch.Tensor | None,
    B_scale: torch.Tensor | None,
    B_zp: torch.Tensor | None,
    topk_weights: torch.Tensor | None,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a8: bool,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    per_channel_quant: bool,
    block_shape: list[int] | None = None,
    B_bias: torch.Tensor | None = None,
) -> None:
    """Launch one expert-grouped GEMM of the fused MoE computation into ``C``.

    Backend selection:
      * group-quantized int4/int8 weights (``use_int4_w4a16`` or
        ``use_int8_w8a16`` with ``block_shape[1] > 0``): the CUDA
        ``ops.moe_wna16_gemm`` when ``should_moe_wna16_use_cuda`` allows it,
        otherwise the Triton ``fused_moe_kernel_gptq_awq``;
      * everything else: the Triton ``fused_moe_kernel``.

    Args:
        A: Activations; ``A.size(0)`` tokens (M), ``A.size(1)`` features (K).
        B: Stacked expert weights; ``B.size(0)`` experts, ``B.size(1)`` is N.
        C: Output, written in place through ``C.stride(1)``/``C.stride(2)``.
        A_scale, B_scale, B_zp: Quantization scales / zero points (or None).
        topk_weights: Routing weights, required when ``mul_routed_weight``.
        sorted_token_ids, expert_ids, num_tokens_post_padded: Expert-sorted
            token layout consumed by the kernels.
        mul_routed_weight: Multiply outputs by the routing weights.
        top_k: Number of experts per token.
        config: Kernel launch config. Not mutated (a copy is modified).
        compute_type: Triton dtype used for the dot inputs / output.
        use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
            per_channel_quant: Quantization mode flags.
        block_shape: ``[block_n, block_k]`` for block/group quantization.
        B_bias: Optional per-expert bias (only used by ``fused_moe_kernel``).

    Raises:
        AssertionError: on inconsistent argument combinations.
    """
    assert topk_weights is not None or not mul_routed_weight
    assert topk_weights is None or topk_weights.stride(1) == 1
    assert sorted_token_ids.stride(0) == 1

    if use_fp8_w8a8 or use_int8_w8a8:
        assert B_scale is not None
        assert block_shape is None or triton.cdiv(
            B.size(-2), block_shape[0]
        ) == B_scale.size(-2)
        assert block_shape is None or triton.cdiv(
            B.size(-1), block_shape[1]
        ) == B_scale.size(-1)

    elif use_int8_w8a16 or use_int4_w4a16:
        assert B_scale is not None
        assert block_shape is None or block_shape[0] == 0
    else:
        assert A_scale is None
        assert B_scale is None

    M = A.size(0)
    num_tokens = M * top_k

    EM = sorted_token_ids.size(0)
    if A.size(0) < config["BLOCK_SIZE_M"]:
        # optimize for small batch_size.
        # We assume that top_ids of each token is unique,
        # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
        # and we can skip some invalid blocks.
        EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])
    grid = lambda META: (
        triton.cdiv(EM, META["BLOCK_SIZE_M"])
        * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
    )
    HAS_BIAS = B_bias is not None
    if (
        (use_int8_w8a16 or use_int4_w4a16)
        and block_shape is not None
        and block_shape[1] > 0
    ):
        assert B_scale is not None and B_scale.ndim == 3
        assert B_zp is None or B_zp.ndim == 3

        bit = 4 if use_int4_w4a16 else 8
        # Experts are stacked along dim 0 of B (dim 1 is N), so the same
        # expert count must feed both the backend choice and the tile
        # heuristic below.
        num_experts = B.size(0)
        use_moe_wna16_cuda = should_moe_wna16_use_cuda(
            num_valid_tokens=num_tokens,
            group_size=block_shape[1],
            num_experts=num_experts,
            bit=bit,
        )
        config = config.copy()
        config.update(
            get_moe_wna16_block_config(
                config=config,
                use_moe_wna16_cuda=use_moe_wna16_cuda,
                num_valid_tokens=num_tokens,
                size_k=A.size(1),
                size_n=B.size(1),
                num_experts=num_experts,
                group_size=block_shape[1],
                real_top_k=top_k,
                block_size_m=config["BLOCK_SIZE_M"],
            )
        )

        if use_moe_wna16_cuda:
            ops.moe_wna16_gemm(
                A,
                C,
                B,
                B_scale,
                B_zp,
                topk_weights if mul_routed_weight else None,
                sorted_token_ids,
                expert_ids,
                num_tokens_post_padded,
                top_k,
                config["BLOCK_SIZE_M"],
                config["BLOCK_SIZE_N"],
                config["BLOCK_SIZE_K"],
                bit,
            )
            return

        fused_moe_kernel_gptq_awq[grid](
            A,
            B,
            C,
            B_scale,
            B_zp,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            B.size(1),
            A.size(1),
            EM,
            num_tokens,
            A.stride(0),
            A.stride(1),
            B.stride(0),
            B.stride(2),
            B.stride(1),
            C.stride(1),
            C.stride(2),
            B_scale.stride(0),
            B_scale.stride(2),
            B_scale.stride(1),
            B_zp.stride(0) if B_zp is not None else 0,
            B_zp.stride(2) if B_zp is not None else 0,
            B_zp.stride(1) if B_zp is not None else 0,
            block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
            group_size=block_shape[1],
            MUL_ROUTED_WEIGHT=mul_routed_weight,
            top_k=top_k,
            compute_type=compute_type,
            has_zp=B_zp is not None,
            use_int4_w4a16=use_int4_w4a16,
            use_int8_w8a16=use_int8_w8a16,
            **config,
        )
    else:
        config = config.copy()
        BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
        # With block quantization a K tile must not straddle quant blocks.
        if block_shape is not None:
            BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1]))
        fused_moe_kernel[grid](
            A,
            B,
            C,
            B_bias,
            A_scale,
            B_scale,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            B.size(1),
            B.size(2),
            EM,
            num_tokens,
            A.stride(0),
            A.stride(1),
            B.stride(0),
            B.stride(2),
            B.stride(1),
            C.stride(1),
            C.stride(2),
            A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
            A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
            B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
            B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
            B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
            B_bias.stride(0) if B_bias is not None else 0,
            B_bias.stride(1) if B_bias is not None else 0,
            0 if block_shape is None else block_shape[0],
            0 if block_shape is None else block_shape[1],
            MUL_ROUTED_WEIGHT=mul_routed_weight,
            top_k=top_k,
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            per_channel_quant=per_channel_quant,
            HAS_BIAS=HAS_BIAS,
            BLOCK_SIZE_K=BLOCK_SIZE_K,
            **config,
        )
|
|
|
|
|
|
@triton.jit
def compute_identity_kernel(
    top_k: int,
    hidden_states_ptr: tl.tensor,
    expert_scales_ptr: tl.tensor,
    num_tokens: int,
    output_ptr: tl.tensor,
    hidden_dim: int,
    scales_stride: int,
    BLOCK_SIZE: tl.constexpr,
) -> None:
    """Compute output[t, d] = sum_i hidden_states[t, d] * expert_scales[t, i].

    Each program handles one BLOCK_SIZE-wide chunk of one token row. The
    launch grid must contain ``num_tokens * cdiv(hidden_dim, BLOCK_SIZE)``
    programs. The last chunk of a row is masked, so ``hidden_dim`` need not
    be a multiple of ``BLOCK_SIZE``.

    Rows of hidden_states/output are addressed as ``row * hidden_dim`` and
    scales as ``row * scales_stride + i``. This presumes row-contiguous
    hidden_states/output and unit stride along the top-k axis (confirm at
    call sites).
    """
    pid = tl.program_id(0)

    # Chunks per token row, rounded up so a partial tail chunk is processed.
    blocks_per_row = tl.cdiv(hidden_dim, BLOCK_SIZE)
    batch_id = pid // blocks_per_row
    dim_offset = (pid % blocks_per_row) * BLOCK_SIZE

    if batch_id >= num_tokens or dim_offset >= hidden_dim:
        return

    cols = dim_offset + tl.arange(0, BLOCK_SIZE)
    col_mask = cols < hidden_dim

    h = tl.load(hidden_states_ptr + batch_id * hidden_dim + cols, mask=col_mask)

    result = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for i in range(top_k):
        scale = tl.load(expert_scales_ptr + batch_id * scales_stride + i)
        result += h * scale

    tl.store(output_ptr + batch_id * hidden_dim + cols, result, mask=col_mask)


def zero_experts_compute_triton(
    expert_indices: torch.Tensor,
    expert_scales: torch.Tensor,
    num_experts: int,
    zero_expert_type: str,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Compute the output of "zero" experts and remove them from routing.

    Expert ids ``>= num_experts`` denote zero experts. For the only supported
    type, ``"identity"``, the returned tensor is, per token,
    ``hidden_states * sum(routing weights of that token's zero experts)``.

    Side effects (in place): zero-expert slots in ``expert_indices`` are set
    to 0 and the matching ``expert_scales`` to 0.0, so the regular MoE path
    sees only in-range ids with no weight on those slots.

    Returns:
        A tensor like ``hidden_states`` holding the zero-expert contribution.

    Raises:
        ValueError: if ``zero_expert_type`` is not ``"identity"``.
    """
    # Validate before mutating the caller's tensors below.
    if zero_expert_type != "identity":
        raise ValueError(
            f"Unsupported zero_expert_type: {zero_expert_type!r} "
            "(only 'identity' is supported)"
        )

    top_k = expert_indices.size(-1)

    # Routing weights restricted to the zero experts (regular ones zeroed).
    zero_expert_scales = expert_scales.clone()
    zero_expert_scales[expert_indices < num_experts] = 0.0

    # In place: neutralise zero-expert slots for the regular-experts path.
    zero_slot_mask = expert_indices >= num_experts
    expert_indices[zero_slot_mask] = 0
    expert_scales[zero_slot_mask] = 0.0

    output = torch.zeros_like(hidden_states)
    hidden_dim = hidden_states.size(-1)
    num_tokens = hidden_states.size(0)

    # Must match the program -> (token, chunk) mapping in the kernel.
    grid = lambda meta: (num_tokens * triton.cdiv(hidden_dim, meta["BLOCK_SIZE"]),)
    compute_identity_kernel[grid](
        top_k,
        hidden_states,
        zero_expert_scales,
        num_tokens,
        output,
        hidden_dim,
        zero_expert_scales.stride(0),
        BLOCK_SIZE=256,
    )

    return output
|
|
|
|
|
|
# Adapted from: https://github.com/sgl-project/sglang/pull/2628
def get_config_file_name(
    E: int, N: int, dtype: str | None, block_shape: list[int] | None = None
) -> str:
    """Return the JSON file name holding tuned fused-MoE configs.

    Format: ``E=<E>,N=<N>,device_name=<dev>[,dtype=<dtype>][,block_shape=[n,k]].json``
    where spaces in the device name become underscores. The dtype part is
    omitted when ``dtype`` is falsy. The block-shape part is omitted when
    ``block_shape`` is empty/None or contains a falsy entry, and is written
    without spaces.
    """
    device_name = current_platform.get_device_name().replace(" ", "_")
    fields = [f"E={E}", f"N={N}", f"device_name={device_name}"]
    if dtype:
        fields.append(f"dtype={dtype}")
    if block_shape and all(block_shape):
        fields.append(f"block_shape={block_shape}".replace(" ", ""))
    return ",".join(fields) + ".json"
|
|
|
|
|
|
# Adapted from: https://github.com/sgl-project/sglang/pull/2628
@functools.lru_cache
def get_moe_configs(
    E: int,
    N: int,
    dtype: str | None,
    block_n: int | None = None,
    block_k: int | None = None,
) -> dict[int, Any] | None:
    """
    Return optimized configurations for the fused MoE kernel.

    The result maps an irregular grid of batch sizes (keys converted to
    ``int``) to kernel configurations. To run the kernel at batch size
    ``bs``, the closest batch size in the grid should be picked.

    Lookup order: the folder in ``envs.VLLM_TUNED_CONFIG_FOLDER`` (when set),
    then the ``configs`` directory next to this file. The first existing
    file wins, and its optional ``"triton_version"`` entry is dropped.

    Returns None when batch-invariant mode is on or no file is found (a
    warning lists the searched paths). Results are memoized per argument
    tuple via ``functools.lru_cache``.
    """
    # Batch-invariant mode never uses tuned configs.
    if vllm_is_batch_invariant():
        return None

    shape = [block_n, block_k] if block_n and block_k else None
    file_name = get_config_file_name(E, N, dtype, shape)

    # User-provided folder takes precedence over the bundled configs.
    user_folder = envs.VLLM_TUNED_CONFIG_FOLDER
    candidates = [] if user_folder is None else [os.path.join(user_folder, file_name)]
    candidates.append(
        os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs", file_name)
    )

    found = next((path for path in candidates if os.path.exists(path)), None)
    if found is None:
        logger.warning(
            (
                "Using default MoE config. Performance might be sub-optimal! "
                "Config file not found at %s"
            ),
            candidates,
        )
        return None

    with open(found) as f:
        logger.info("Using configuration from %s for MoE layer.", found)
        tuned = json.load(f)
    tuned.pop("triton_version", None)
    return {int(batch): cfg for batch, cfg in tuned.items()}
|
|
|
|
|
|
def get_moe_wna16_block_config(
    config: dict[str, int],
    use_moe_wna16_cuda: bool,
    num_valid_tokens: int,
    size_k: int,
    size_n: int,
    num_experts: int,
    group_size: int,
    real_top_k: int,
    block_size_m: int,
):
    """Choose BLOCK_SIZE_N / BLOCK_SIZE_K for the wna16 MoE kernels.

    Returns an empty dict when ``config`` already pins both sizes. Otherwise
    returns ``{"BLOCK_SIZE_N": ..., "BLOCK_SIZE_K": ...}`` to merge into it:
    fixed values for the Triton kernel, or a heuristic based on an estimated
    tile count for the CUDA kernel.
    """
    # A tuned config that already fixes both tile sizes is left as is.
    if "BLOCK_SIZE_N" in config and "BLOCK_SIZE_K" in config:
        return {}

    if not use_moe_wna16_cuda:
        # Triton kernel: a narrower N tile when exactly one token is routed.
        if num_valid_tokens // real_top_k == 1:
            return {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64}
        return {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}

    # CUDA kernel: start from 128x128 tiles (K tile at least one group) and
    # grow them while the estimated number of tiles stays large.
    tile_n = 128
    tile_k = max(128, group_size)

    # NOTE(review): both factors divide by tile_k, and the size_k one feeds
    # what is used as the "n" factor. Kept as-is because the thresholds
    # below were tuned against this estimate.
    n_factor = size_k // tile_k
    k_factor = size_n // tile_k
    m_blocks = (num_valid_tokens + block_size_m - 1) / block_size_m + num_experts
    if num_valid_tokens // real_top_k <= block_size_m:
        m_blocks = min(m_blocks, num_valid_tokens)
    tile_count = m_blocks * n_factor * k_factor

    if size_k % 256 == 0 and tile_count >= 256 and tile_k < 256:
        tile_k = 256
        # NOTE(review): tile_k is already 256 here, so the divisor is 1 and
        # this only floors a possibly-fractional count.
        tile_count = tile_count // (256 // tile_k)

    if (
        m_blocks <= 16
        and size_k % (tile_k * 2) == 0
        and tile_k <= 512
        and tile_count >= 512
    ):
        tile_k *= 2
        tile_count //= 2

    if tile_count > 1024:
        tile_n = 256
        tile_count //= 2

    if size_n <= 1024 and tile_count >= 1024:
        # The kernel performance got much better with BLOCK_SIZE_N=1024
        # when num_blocks is large, even when N is small.
        tile_n = 1024

    return {"BLOCK_SIZE_N": tile_n, "BLOCK_SIZE_K": tile_k}
|
|
|
|
|
|
def should_moe_wna16_use_cuda(
    num_valid_tokens: int, group_size: int, num_experts: int, bit: int
):
    """Return True if the CUDA wna16 MoE GEMM should be used.

    Requires a CUDA platform, 4-bit weights, a group size of 32, 64 or 128,
    and at most 6 routed tokens per expert on average
    (``num_valid_tokens / num_experts <= 6``).
    """
    if not current_platform.is_cuda():
        return False
    if bit != 4 or group_size not in (32, 64, 128):
        return False
    tokens_per_expert = num_valid_tokens / num_experts
    return tokens_per_expert <= 6
|
|
|
|
|
|
def get_default_config(
    M: int,
    E: int,
    N: int,
    K: int,
    topk: int,
    dtype: str | None,
    block_shape: list[int] | None = None,
) -> dict[str, int]:
    """Build a heuristic Triton launch config for the fused MoE kernel.

    Args:
        M: number of tokens in the (chunk of the) batch.
        E: number of experts.
        N: intermediate size (unused by the heuristics below).
        K: hidden size (unused by the heuristics below).
        topk: experts selected per token.
        dtype: config dtype string such as ``"fp8_w8a8"`` or ``"int4_w4a16"``.
        block_shape: block-quantization shape ``[block_n, block_k]``, if any.

    Returns:
        A dict of ``BLOCK_SIZE_*`` / ``GROUP_SIZE_M`` (and, for block-wise
        FP8, ``num_warps`` / ``num_stages``) entries. For the moe wna16
        kernels only ``BLOCK_SIZE_M`` (plus ``GROUP_SIZE_M`` on the Triton
        path) is set; the N/K block sizes are filled in later.
    """
    if vllm_is_batch_invariant():
        return {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": 64,
            "BLOCK_SIZE_K": 32,
            "GROUP_SIZE_M": 8,
        }

    if dtype == "fp8_w8a8" and block_shape is not None:
        # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
        # BLOCK_SIZE_K must be divisible by block_shape[1]
        # num_stages=3 can cause triton.runtime.errors.OutOfResources
        # on ROCm, set it to 2 instead.
        return {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": block_shape[0],
            "BLOCK_SIZE_K": block_shape[1],
            "GROUP_SIZE_M": 32,
            "num_warps": 4,
            "num_stages": 2 if current_platform.is_rocm() else 3,
        }

    if dtype in ("int4_w4a16", "int8_w8a16") and block_shape is not None:
        # moe wna16 kernels: only BLOCK_SIZE_M is chosen here.
        weight_bits = 4 if dtype == "int4_w4a16" else 8
        if should_moe_wna16_use_cuda(M * topk, block_shape[1], E, weight_bits):
            return {"BLOCK_SIZE_M": min(16, M)}
        if M <= 20:
            block_m = 16
        elif M <= 40:
            block_m = 32
        else:
            block_m = 64
        return {"BLOCK_SIZE_M": block_m, "GROUP_SIZE_M": 1}

    if M <= E:
        # Few tokens per expert: use smaller M/N tiles.
        return {
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 32,
            "BLOCK_SIZE_K": 64,
            "GROUP_SIZE_M": 1,
        }
    return {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 32,
        "GROUP_SIZE_M": 8,
    }
|
|
|
|
|
|
def try_get_optimal_moe_config(
    w1_shape: tuple[int, ...],
    w2_shape: tuple[int, ...],
    top_k: int,
    dtype: str | None,
    M: int,
    block_shape: list[int] | None = None,
) -> dict[str, int]:
    """Pick the kernel config for a batch of ``M`` tokens.

    Resolution order:

    1. A user override returned by ``get_config()`` (when truthy).
    2. A tuned config loaded by ``get_moe_configs``; the entry whose batch
       size key is closest to ``M`` is used.
    3. ``get_default_config`` heuristics.

    For ``int4_w4a16`` the last dimension of ``w2_shape`` is doubled before
    the lookup.
    """
    from vllm.model_executor.layers.fused_moe import get_config

    override_config = get_config()
    if override_config:
        return override_config

    # First try to load optimal config from the file
    E, _, N = w2_shape
    if dtype == "int4_w4a16":
        N *= 2
    block_n, block_k = (block_shape[0], block_shape[1]) if block_shape else (0, 0)
    configs = get_moe_configs(E, N, dtype, block_n, block_k)

    if not configs:
        # No tuned table available: fall back to the heuristics.
        return get_default_config(M, E, N, w1_shape[2], top_k, dtype, block_shape)

    nearest_batch = min(configs.keys(), key=lambda batch: abs(batch - M))
    return configs[nearest_batch]
|
|
|
|
|
|
def vllm_topk_softmax(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool,
) -> tuple[torch.Tensor, ...]:
    """Top-k softmax routing using vLLM's ``ops.topk_softmax`` kernel.

    The three output buffers are passed to the kernel, which presumably fills
    them in place (they are returned without reassignment). When
    ``renormalize`` is set, the selected weights are divided by their per-row
    sum, producing a new tensor.

    Returns:
        ``(topk_weights, topk_indices)``.
    """
    ops.topk_softmax(topk_weights, topk_indices, token_expert_indices, gating_output)
    if not renormalize:
        return topk_weights, topk_indices
    row_totals = topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights / row_totals, topk_indices
|
|
|
|
|
|
def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]:
    """Return the top-k softmax routing implementation for this platform.

    Uses the ROCm AITER implementation when ``is_rocm_aiter_moe_enabled()``
    is true, otherwise ``vllm_topk_softmax``.
    """
    if not is_rocm_aiter_moe_enabled():
        return vllm_topk_softmax

    # Imported lazily so the AITER module is only loaded when it is enabled.
    from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax

    return rocm_aiter_topk_softmax
|
|
|
|
|
|
def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    indices_type: torch.dtype | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Select the top-k experts per token from router logits.

    Args:
        hidden_states: 2-D token activations; only the number of rows and the
            device are used.
        gating_output: router logits with one row per token.
        topk: experts to select per token.
        renormalize: whether the routing implementation should renormalize
            the selected weights.
        indices_type: dtype for the returned expert ids (default int32).

    Returns:
        ``(topk_weights, topk_ids, token_expert_indices)``. The weights are
        float32, and ``token_expert_indices`` is an int32 buffer.
    """
    assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"

    num_rows, _ = hidden_states.size()
    device = hidden_states.device
    ids_dtype = torch.int32 if indices_type is None else indices_type

    weights_out = torch.empty(num_rows, topk, dtype=torch.float32, device=device)
    ids_out = torch.empty(num_rows, topk, dtype=ids_dtype, device=device)
    token_expert_indices = torch.empty(
        num_rows, topk, dtype=torch.int32, device=device
    )

    logits_fp32 = gating_output.float()  # TODO(woosuk): Optimize this.

    routing_impl = dispatch_topk_func()
    weights_out, ids_out = routing_impl(
        weights_out, ids_out, token_expert_indices, logits_fp32, renormalize
    )
    return weights_out, ids_out, token_expert_indices
|
|
|
|
|
|
def fused_topk_bias(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    e_score_correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    """Top-k routing where experts are chosen on bias-corrected scores.

    Softmax probabilities plus ``e_score_correction_bias`` decide which
    experts are selected. The returned weights are gathered from the
    unbiased probabilities and optionally renormalized per row.
    ``hidden_states`` is not used.

    Returns:
        ``(topk_weights, topk_ids)`` as float32 and int32 tensors.
    """
    num_experts = gating_output.shape[-1]
    probs = gating_output.softmax(dim=-1)
    biased_probs = probs.view(-1, num_experts) + e_score_correction_bias.unsqueeze(0)

    # For batch invariance, use sorted=True to ensure deterministic expert selection
    deterministic = vllm_is_batch_invariant()
    _, chosen = torch.topk(biased_probs, k=topk, dim=-1, sorted=deterministic)

    weights = probs.gather(1, chosen)
    if renormalize:
        weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights.to(torch.float32), chosen.to(torch.int32)
|
|
|
|
|
|
# This is used by the Deepseek-V2 and Deepseek-V3 model
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Group-limited top-k expert routing.

    Experts are split into ``num_expert_group`` equal groups.

    - Each group is scored per token. With ``e_score_correction_bias`` the
      score is the sum of the two largest biased scores in the group;
      otherwise it is the group's maximum score.
    - Only the ``topk_group`` best groups are kept.
    - ``topk`` experts are then chosen among the experts of the kept groups.

    When a correction bias is given, selection uses the biased scores but the
    returned weights come from the unbiased scores. The weights are then
    optionally renormalized per row and multiplied by
    ``routed_scaling_factor``.

    If ``VLLM_USE_FUSED_MOE_GROUPED_TOPK`` is set, the platform is CUDA,
    ``num_expert_group <= 32``, ``topk <= 32`` and a correction bias is
    provided, the work is delegated to ``fused_grouped_topk`` instead.

    Args:
        hidden_states: token activations; only the row count is checked.
        gating_output: router logits, one row per token.
        topk: experts selected per token.
        renormalize: divide the selected weights by their per-row sum.
        num_expert_group: number of expert groups. The ``view`` below assumes
            the expert count is divisible by it.
        topk_group: number of groups kept per token.
        scoring_func: ``"softmax"`` or ``"sigmoid"``.
        routed_scaling_factor: multiplier applied to the final weights.
        e_score_correction_bias: optional per-expert bias used only for
            selection.

    Returns:
        ``(topk_weights, topk_ids)`` as float32 and int32 tensors.

    Raises:
        ValueError: if ``scoring_func`` is not supported.
    """
    # Fast path: the fused CUDA kernel handles the biased case within its
    # group/topk limits.
    if (
        envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK
        and current_platform.is_cuda()
        and num_expert_group <= 32
        and topk <= 32
        and e_score_correction_bias is not None
    ):
        return fused_grouped_topk(
            hidden_states=hidden_states,
            gating_output=gating_output,
            topk=topk,
            renormalize=renormalize,
            e_score_correction_bias=e_score_correction_bias,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
        )

    assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"

    if scoring_func == "softmax":
        scores = torch.softmax(gating_output, dim=-1)
    elif scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    num_token = scores.size(0)
    if e_score_correction_bias is not None:
        # Store original scores before applying correction bias. We use biased
        # scores for expert selection but original scores for routing weights
        original_scores = scores
        scores = scores + e_score_correction_bias.unsqueeze(0)
        # Group score = sum of the top-2 biased expert scores in the group.
        group_scores = (
            scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
        )
    else:
        # Group score = best expert score in the group.
        group_scores = (
            scores.view(num_token, num_expert_group, -1).max(dim=-1).values
        )  # [n, n_group]

    # For batch invariance, use sorted=True to ensure deterministic expert selection
    use_sorted = vllm_is_batch_invariant()
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
        1
    ]  # [n, top_k_group]
    # 1.0 for kept groups, 0.0 elsewhere.
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    # Broadcast the group mask to every expert of its group.
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.size(-1) // num_expert_group)
        .reshape(num_token, -1)
    )  # [n, e]
    # Experts in dropped groups get -inf so topk can never pick them.
    tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))  # [n, e]

    if e_score_correction_bias is not None:
        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1]
        # Use original unbiased scores for the routing weights
        topk_weights = original_scores.gather(1, topk_ids)
    else:
        topk_weights, topk_ids = torch.topk(
            tmp_scores, k=topk, dim=-1, sorted=use_sorted
        )

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    if routed_scaling_factor != 1.0:
        topk_weights = topk_weights * routed_scaling_factor
    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
|
|
|
|
|
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def eplb_map_to_physical_and_record(
    topk_ids: torch.Tensor,
    expert_load_view: torch.Tensor,
    logical_to_physical_map: torch.Tensor,
    logical_replica_count: torch.Tensor,
    indices_type: torch.dtype | None = None,
) -> torch.Tensor:
    """
    Map the logical expert ids to physical expert ids
    and record the expert load metrics.

    Only used for EPLB.

    How the replica is chosen:
        Each entry of ``topk_ids`` picks one replica of its logical expert.
        The choice is deterministic: the replica index is the entry's
        flattened position in ``topk_ids`` modulo that expert's replica count.
        This spreads consecutive entries across the available replicas.

    Side effects:
        ``expert_load_view`` is updated in place. Each resulting physical id
        adds one (cast to ``expert_load_view``'s dtype) to its slot.

    Args:
        topk_ids: The logical expert ids.
        expert_load_view: Per-physical-expert load counters, updated in
            place.
        logical_to_physical_map: Indexed by logical id, then gathered along
            its last dim by replica index. It presumably has shape
            (num_logical_experts, max_replicas); confirm with the caller.
        logical_replica_count: Number of replicas per logical expert.
        indices_type: Optional dtype for the returned ids.

    Returns:
        The physical expert ids, with the same shape as ``topk_ids``. They
        are cast to ``indices_type`` when it is given.
    """

    # 1. Convert the logical expert ids to physical expert ids
    # Select a replica for each logical expert by position (see below)

    # In case `indices_type` is not `torch.long` or `torch.int`,
    # e.g. `torch.uint32` as required by dispatch/combine kernels
    topk_ids_long = topk_ids.long()
    # Use (token position) modulo (replica count)
    # to deterministically choose a replica
    replica_count = logical_replica_count[topk_ids_long]
    # Flatten-position based index, reshaped back to `topk_ids` shape
    pos_indices = torch.arange(
        topk_ids.numel(), device=topk_ids.device, dtype=torch.long
    ).reshape_as(topk_ids)
    # Compute deterministic replica indices by modulo
    replica_indices = (pos_indices % replica_count).unsqueeze(-1)
    physical_ids = (
        logical_to_physical_map[topk_ids_long].gather(-1, replica_indices).squeeze(-1)
    )

    topk_ids = physical_ids

    # 2. Record expert load metrics.

    # TODO(bowen): When using `FusedMoEModularKernel`, this
    # can be done in a more unified way, since
    # `FusedMoEPrepareAndFinalize` will return the expert
    # token count, in some cases directly from the kernel.
    # However, now there are many code paths not using
    # the modular kernel, e.g. calling `fused_experts`,
    # so we decide to keep the logic here.
    #
    # If later refactor moved all the MoE kernel calls
    # to the modular kernel, we can move this logic there
    # to achieve better efficiency.

    # `expert_load_view`: (num_physical_experts,)

    # `torch.bincount` is not compilable, so use `scatter_add_` instead.
    topk_ids_flatten = topk_ids.flatten()
    expert_load_view.scatter_add_(
        dim=0,
        index=topk_ids_flatten.long(),
        src=torch.ones_like(topk_ids_flatten).to(expert_load_view),
    )

    if indices_type is not None:
        topk_ids = topk_ids.to(dtype=indices_type)
    return topk_ids
|
|
|
|
|
|
def fused_grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    e_score_correction_bias: torch.Tensor,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Group-limited top-k routing via the fused ``ops.grouped_topk`` kernel.

    The router logits are scored with softmax or sigmoid. The kernel receives
    both the plain scores and the bias-corrected scores (cast back to the
    score dtype). Grouping, selection, renormalization and scaling are left
    to the kernel.

    Returns:
        ``(topk_weights, topk_ids)`` as float32 and int32 tensors.

    Raises:
        ValueError: if ``scoring_func`` is not supported.
    """
    assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"

    if scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    elif scoring_func == "softmax":
        scores = torch.softmax(gating_output, dim=-1)
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    biased_scores = (scores + e_score_correction_bias.unsqueeze(0)).to(scores.dtype)
    weights, expert_ids = ops.grouped_topk(
        scores,
        biased_scores,
        num_expert_group,
        topk_group,
        topk,
        renormalize,
        routed_scaling_factor,
    )
    return weights.to(torch.float32), expert_ids.to(torch.int32)
|
|
|
|
|
|
def inplace_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    ocp_mx_scheme: str | None = None,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    w1_zp: torch.Tensor | None = None,
    w2_zp: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    block_shape: list[int] | None = None,
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
) -> None:
    """Body of the ``inplace_fused_experts`` custom op.

    Forwards every argument to ``fused_experts_impl`` with ``inplace=True``
    and discards its return value. The result is expected in
    ``hidden_states``, which the op registration declares as mutated.
    """
    fused_experts_impl(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=True,
        activation=activation,
        apply_router_weight_on_input=apply_router_weight_on_input,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        ocp_mx_scheme=ocp_mx_scheme,
        per_channel_quant=per_channel_quant,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        w1_zp=w1_zp,
        w2_zp=w2_zp,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape,
        w1_bias=w1_bias,
        w2_bias=w2_bias,
    )
|
|
|
|
|
|
def inplace_fused_experts_fake(
|
|
hidden_states: torch.Tensor,
|
|
w1: torch.Tensor,
|
|
w2: torch.Tensor,
|
|
topk_weights: torch.Tensor,
|
|
topk_ids: torch.Tensor,
|
|
activation: str = "silu",
|
|
apply_router_weight_on_input: bool = False,
|
|
use_fp8_w8a8: bool = False,
|
|
use_int8_w8a8: bool = False,
|
|
use_int8_w8a16: bool = False,
|
|
use_int4_w4a16: bool = False,
|
|
ocp_mx_scheme: str | None = None,
|
|
per_channel_quant: bool = False,
|
|
global_num_experts: int = -1,
|
|
expert_map: torch.Tensor | None = None,
|
|
w1_scale: torch.Tensor | None = None,
|
|
w2_scale: torch.Tensor | None = None,
|
|
w1_zp: torch.Tensor | None = None,
|
|
w2_zp: torch.Tensor | None = None,
|
|
a1_scale: torch.Tensor | None = None,
|
|
a2_scale: torch.Tensor | None = None,
|
|
block_shape: list[int] | None = None,
|
|
w1_bias: torch.Tensor | None = None,
|
|
w2_bias: torch.Tensor | None = None,
|
|
) -> None:
|
|
pass
|
|
|
|
|
|
# Register `inplace_fused_experts` as the custom op
# `torch.ops.vllm.inplace_fused_experts` (invoked by
# `torch_vllm_inplace_fused_experts`). The op declares that it mutates
# `hidden_states` and uses `inplace_fused_experts_fake` as its fake
# implementation. On torch older than 2.7.0 it is additionally tagged with
# `torch.Tag.needs_fixed_stride_order`; newer versions get no tags.
direct_register_custom_op(
    op_name="inplace_fused_experts",
    op_func=inplace_fused_experts,
    mutates_args=["hidden_states"],
    fake_impl=inplace_fused_experts_fake,
    tags=(
        ()
        if is_torch_equal_or_newer("2.7.0")
        else (torch.Tag.needs_fixed_stride_order,)
    ),
)
|
|
|
|
|
|
def outplace_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    ocp_mx_scheme: str | None = None,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    w1_zp: torch.Tensor | None = None,
    w2_zp: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    block_shape: list[int] | None = None,
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """Body of the ``outplace_fused_experts`` custom op.

    Forwards every argument to ``fused_experts_impl`` with ``inplace=False``
    and returns the freshly allocated result tensor.
    """
    result = fused_experts_impl(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        activation=activation,
        apply_router_weight_on_input=apply_router_weight_on_input,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        ocp_mx_scheme=ocp_mx_scheme,
        per_channel_quant=per_channel_quant,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        w1_zp=w1_zp,
        w2_zp=w2_zp,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape,
        w1_bias=w1_bias,
        w2_bias=w2_bias,
    )
    return result
|
|
|
|
|
|
def outplace_fused_experts_fake(
|
|
hidden_states: torch.Tensor,
|
|
w1: torch.Tensor,
|
|
w2: torch.Tensor,
|
|
topk_weights: torch.Tensor,
|
|
topk_ids: torch.Tensor,
|
|
activation: str = "silu",
|
|
use_fp8_w8a8: bool = False,
|
|
use_int8_w8a8: bool = False,
|
|
use_int8_w8a16: bool = False,
|
|
use_int4_w4a16: bool = False,
|
|
ocp_mx_scheme: str | None = None,
|
|
per_channel_quant: bool = False,
|
|
global_num_experts: int = -1,
|
|
expert_map: torch.Tensor | None = None,
|
|
w1_scale: torch.Tensor | None = None,
|
|
w2_scale: torch.Tensor | None = None,
|
|
w1_zp: torch.Tensor | None = None,
|
|
w2_zp: torch.Tensor | None = None,
|
|
a1_scale: torch.Tensor | None = None,
|
|
a2_scale: torch.Tensor | None = None,
|
|
block_shape: list[int] | None = None,
|
|
w1_bias: torch.Tensor | None = None,
|
|
w2_bias: torch.Tensor | None = None,
|
|
) -> torch.Tensor:
|
|
return torch.empty_like(hidden_states)
|
|
|
|
|
|
# Register `outplace_fused_experts` as the custom op
# `torch.ops.vllm.outplace_fused_experts` (invoked by
# `torch_vllm_outplace_fused_experts`). No arguments are declared as mutated;
# `outplace_fused_experts_fake` is the fake implementation. On torch older
# than 2.7.0 it is additionally tagged with
# `torch.Tag.needs_fixed_stride_order`; newer versions get no tags.
direct_register_custom_op(
    op_name="outplace_fused_experts",
    op_func=outplace_fused_experts,
    fake_impl=outplace_fused_experts_fake,
    tags=(
        ()
        if is_torch_equal_or_newer("2.7.0")
        else (torch.Tag.needs_fixed_stride_order,)
    ),
)
|
|
|
|
|
|
def torch_vllm_inplace_fused_experts(**kwargs) -> torch.Tensor:
    """Call the in-place custom op and return the tensor it wrote into.

    The op itself returns nothing; the result lives in the ``hidden_states``
    keyword argument, which is returned here.
    """
    torch.ops.vllm.inplace_fused_experts(**kwargs)
    return kwargs["hidden_states"]
|
|
|
|
|
|
def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
    """Call the out-of-place custom op and return its freshly allocated output."""
    outplace_op = torch.ops.vllm.outplace_fused_experts
    return outplace_op(**kwargs)
|
|
|
|
|
|
def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
    """Choose the in-place or out-of-place fused-experts wrapper.

    The in-place wrapper is used only when ``inplace`` is requested and
    ``disable_inplace()`` does not veto it.
    """
    use_inplace = inplace and not disable_inplace()
    return (
        torch_vllm_inplace_fused_experts
        if use_inplace
        else torch_vllm_outplace_fused_experts
    )
|
|
|
|
|
|
# TODO (bnell): replace this with modular op. Can get rid of inplace/outplace
# torch ops.
def fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
    quant_config: FusedMoEQuantConfig | None = None,
    allow_deep_gemm: bool = False,
    allow_cutlass_block_scaled_grouped_gemm: bool = False,
) -> torch.Tensor:
    """Run a fused MoE layer on the first applicable backend.

    Backends are tried in this order:

    1. DeepGEMM FP8. Requires ``allow_deep_gemm``, an FP8 W8A8 config, and
       either E8M0 DeepGEMM scales in use or shapes accepted by
       ``_valid_deep_gemm``. ``apply_router_weight_on_input`` must be False.
    2. CUTLASS block-scaled grouped GEMM. Requires
       ``allow_cutlass_block_scaled_grouped_gemm``, an FP8 W8A8 config, and a
       call accepted by ``_valid_cutlass_block_scaled_grouped_gemm``.
    3. Otherwise the Triton path, through the in-place or out-of-place custom
       op chosen by ``dispatch_fused_experts_func``.

    ``quant_config`` defaults to ``FUSED_MOE_UNQUANTIZED_CONFIG``.
    """
    cfg = FUSED_MOE_UNQUANTIZED_CONFIG if quant_config is None else quant_config
    use_fp8_w8a8 = cfg.use_fp8_w8a8

    # For now, disable DeepGemm for small N (<= 512) until better
    # permute/unpermute ops are available.
    # However, on B200, we use DeepGemm for all cases because they only support
    # E8M0 scale, which means we requantize the weight and input to the specific
    # scale. Fallen back to cutlass or triton for some cases would cause
    # accuracy issue.
    deep_gemm_applicable = (
        allow_deep_gemm
        and use_fp8_w8a8
        and (is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2))
    )
    if deep_gemm_applicable:
        assert apply_router_weight_on_input is False
        return deep_gemm_moe_fp8(
            hidden_states=hidden_states,
            w1=w1,
            w2=w2,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=inplace,
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=cfg.w1_scale,
            w2_scale=cfg.w2_scale,
            a1_scale=cfg.a1_scale,
            a2_scale=cfg.a2_scale,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )

    if (
        allow_cutlass_block_scaled_grouped_gemm
        and use_fp8_w8a8
        and _valid_cutlass_block_scaled_grouped_gemm(
            w1, w2, inplace, activation, apply_router_weight_on_input, expert_map
        )
    ):
        return run_cutlass_block_scaled_fused_experts(
            a=hidden_states,
            w1=w1,
            w2=w2,
            w1_scale=cfg.w1_scale,
            w2_scale=cfg.w2_scale,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
        )

    triton_fused_experts = dispatch_fused_experts_func(inplace)
    return triton_fused_experts(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        activation=activation,
        apply_router_weight_on_input=apply_router_weight_on_input,
        use_fp8_w8a8=cfg.use_fp8_w8a8,
        use_int8_w8a8=cfg.use_int8_w8a8,
        use_int8_w8a16=cfg.use_int8_w8a16,
        use_int4_w4a16=cfg.use_int4_w4a16,
        ocp_mx_scheme=cfg.ocp_mx_scheme,
        per_channel_quant=cfg.per_act_token_quant,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        w1_scale=cfg.w1_scale,
        w2_scale=cfg.w2_scale,
        w1_zp=cfg.w1_zp,
        w2_zp=cfg.w2_zp,
        a1_scale=cfg.a1_scale,
        a2_scale=cfg.a2_scale,
        block_shape=cfg.block_shape,
        w1_bias=cfg.w1_bias,
        w2_bias=cfg.w2_bias,
    )
|
|
|
|
|
|
# Activation names for the variants without the "*_and_mul" gate
# multiplication, built by `activation_without_mul`. `fused_experts_impl`
# compares the `activation` argument against these.
SILU_NO_MUL: str = activation_without_mul("silu")
GELU_NO_MUL: str = activation_without_mul("gelu")
|
|
|
|
|
|
def _get_config_quant_dtype(
|
|
use_fp8_w8a8: bool,
|
|
use_int8_w8a8: bool,
|
|
ocp_mx_scheme: str | None,
|
|
) -> None | torch.dtype | str:
|
|
"""
|
|
Get the quantization type based on the quantization strategy flags.
|
|
We don't have a quant_config at this point so we need to work backwards.
|
|
A return type of None means no quantization is required because the
|
|
input is unquantized or has been quantized prior to calling
|
|
fused_experts_impl.
|
|
"""
|
|
if use_fp8_w8a8:
|
|
return torch.float8_e4m3fn
|
|
elif use_int8_w8a8:
|
|
return torch.int8
|
|
elif ocp_mx_scheme == "w_mxfp4_a_mxfp4":
|
|
return "mxfp4"
|
|
elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e3m2", "w_mxfp6_e3m2_a_mxfp6_e3m2"}:
|
|
return "mxfp6_e3m2"
|
|
elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e2m3", "w_mxfp6_e2m3_a_mxfp6_e2m3"}:
|
|
return "mxfp6_e2m3"
|
|
return None
|
|
|
|
|
|
def fused_experts_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    ocp_mx_scheme: str | None = None,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    w1_zp: torch.Tensor | None = None,
    w2_zp: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    block_shape: list[int] | None = None,
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """Triton fused MoE forward pass: GEMM1 -> activation -> GEMM2 -> sum.

    Tokens are processed in chunks of ``envs.VLLM_FUSED_MOE_CHUNK_SIZE``.
    For each chunk:

    1. The activations are quantized (if ``quant_dtype`` requires it).
    2. The routed (token, expert) pairs are aligned to ``BLOCK_SIZE_M``.
    3. The first fused MoE kernel multiplies by ``w1``.
    4. The activation is applied.
    5. The result is optionally re-quantized.
    6. The second kernel multiplies by ``w2``.
    7. ``ops.moe_sum`` reduces the per-expert outputs into the output rows.

    For OCP MX schemes the weights are first dequantized to
    ``hidden_states.dtype`` (emulation).

    Args:
        hidden_states: contiguous 2-D activations (float32/float16/bfloat16).
        w1, w2: expert weight stacks. ``w1.size()`` is unpacked as
            ``(E, N, _)`` and ``w2.size(1)`` is used as ``K``.
        topk_weights, topk_ids: routing weights/ids with identical shapes;
            ``topk_ids.size(1)`` is the number of experts per token.
        inplace: write the result into ``hidden_states`` unless
            ``disable_inplace()`` says otherwise.
        activation: ``"silu"``, ``"gelu"``, ``"swigluoai"``, ``SILU_NO_MUL``
            or ``GELU_NO_MUL``.
        apply_router_weight_on_input: multiply router weights in the first
            GEMM instead of the second.
        global_num_experts: total expert count; ``-1`` means ``E``.
        Remaining arguments select the quantization mode and supply its
            scales, zero points, block shape and optional biases.

    Returns:
        The output tensor: ``hidden_states`` itself when in-place, otherwise
        a new tensor shaped like ``hidden_states``.

    Raises:
        NotImplementedError: for an unsupported ``ocp_mx_scheme``.
        ValueError: for an unsupported activation or hidden_states dtype.
    """
    # Check constraints.
    if use_int4_w4a16:
        # int4 weights are packed two per element along the last dim.
        assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch"
    elif ocp_mx_scheme is not None:
        if ocp_mx_scheme in {
            "w_mxfp4_a_mxfp4",
            "w_mxfp4_a_mxfp6_e3m2",
            "w_mxfp4_a_mxfp6_e2m3",
        }:
            # 16bit activation and fp4x2 packed weight
            assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch"
        elif ocp_mx_scheme in {
            "w_mxfp6_e3m2_a_mxfp6_e3m2",
            "w_mxfp6_e2m3_a_mxfp6_e2m3",
        }:
            # fp6 weights: 4 values packed into 3 bytes along the last dim.
            assert hidden_states.size(1) == (w1.size(2) * 4) // 3, (
                "hidden size mismatch"
            )
        else:
            raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
    else:
        assert hidden_states.size(1) == w1.size(2), (
            f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}"
        )

    assert topk_weights.size() == topk_ids.size(), "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]

    num_tokens = hidden_states.size(0)
    E, N, _ = w1.size()
    K = w2.size(1)
    if global_num_experts == -1:
        global_num_experts = E
    top_k_num = topk_ids.size(1)
    # We execute the fused_moe kernel in chunks to circumvent this issue:
    # https://github.com/vllm-project/vllm/issues/5938
    CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
    # Buffers below are sized for the largest chunk.
    M = min(num_tokens, CHUNK_SIZE)

    config_dtype = _get_config_dtype_str(
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        ocp_mx_scheme=ocp_mx_scheme,
        dtype=hidden_states.dtype,
    )

    # Note: for use_int8_w8a16 or use_int4_w4a16, the activations are
    # quantized prior to calling fused_experts.
    quant_dtype = _get_config_quant_dtype(
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        ocp_mx_scheme=ocp_mx_scheme,
    )

    # Config lookup depends only on the token count once shapes are fixed.
    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w1.size(),
        w2.size(),
        top_k_num,
        config_dtype,
        block_shape=block_shape,
    )

    config = get_config_func(M)

    # We can reuse the memory between these because by the time we need
    # cache3, we're done with cache1
    cache13 = torch.empty(
        M * top_k_num * max(N, K),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache1 = cache13[: M * top_k_num * N].view(M, top_k_num, N)
    intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)

    # This needs separate memory since it's used concurrently with cache1
    # Width N // 2 fits the "*_and_mul" activations; the *_NO_MUL branches
    # below rebind intermediate_cache2 to a new full-width tensor instead.
    intermediate_cache2 = torch.empty(
        (M * top_k_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype
    )

    if hidden_states.dtype == torch.bfloat16:
        compute_type = tl.bfloat16
    elif hidden_states.dtype == torch.float16:
        compute_type = tl.float16
    elif hidden_states.dtype == torch.float32:
        compute_type = tl.float32
    else:
        raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")

    if inplace and not disable_inplace():
        out_hidden_states = hidden_states
    else:
        out_hidden_states = torch.empty_like(hidden_states)

    if ocp_mx_scheme is not None:
        # TODO: On platforms for which `current_platform.supports_mx()` is True
        # and for which we have a native OCP mx fused MOE kernel,
        # this dequantization step should not be done.
        if ocp_mx_scheme in {
            OCP_MX_Scheme.w_mxfp4_a_mxfp4,
            OCP_MX_Scheme.w_mxfp4_a_mxfp6_e3m2,
            OCP_MX_Scheme.w_mxfp4_a_mxfp6_e2m3,
        }:
            # Weight has to be dequantized for mxfp4 emulation.
            w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
            w1_scale = None
            w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
            w2_scale = None
        elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e3m2_a_mxfp6_e3m2:
            w1 = dequant_mxfp6(
                w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
            )
            w1_scale = None
            w2 = dequant_mxfp6(
                w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
            )
            w2_scale = None
        elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e2m3_a_mxfp6_e2m3:
            w1 = dequant_mxfp6(
                w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
            )
            w1_scale = None
            w2 = dequant_mxfp6(
                w2, w2_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
            )
            w2_scale = None
        else:
            raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")

    # One extra iteration covers the partial tail; an empty tail breaks out.
    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
        begin_chunk_idx, end_chunk_idx = (
            chunk * CHUNK_SIZE,
            min((chunk + 1) * CHUNK_SIZE, num_tokens),
        )
        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
        tokens_in_chunk, _ = curr_hidden_states.size()

        if tokens_in_chunk == 0:
            break

        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
            # Adjust the intermediate cache size and config for the last
            # chunk. Note that in most cases we only have one chunk
            # so the cache size and config are already set correctly and
            # do not need to be adjusted.
            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
            intermediate_cache2 = intermediate_cache2[
                : tokens_in_chunk * topk_ids.size(1)
            ]
            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
            config = get_config_func(tokens_in_chunk)

        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
        qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
            A=curr_hidden_states,
            A_scale=a1_scale,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_channel_quant,
            block_shape=block_shape,
        )

        # Group routed (token, expert) slots by expert, padded to BLOCK_SIZE_M.
        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
            curr_topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map
        )

        # GEMM1: hidden -> intermediate_cache1 (M, top_k, N).
        invoke_fused_moe_kernel(
            qcurr_hidden_states,
            w1,
            intermediate_cache1,
            a1q_scale,
            w1_scale,
            w1_zp,
            curr_topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            apply_router_weight_on_input,
            top_k_num,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            per_channel_quant=per_channel_quant,
            block_shape=block_shape,
            B_bias=w1_bias,
        )

        # Activation function with multiplication
        if activation == "silu":
            torch.ops._C.silu_and_mul(
                intermediate_cache2, intermediate_cache1.view(-1, N)
            )
        elif activation == "gelu":
            torch.ops._C.gelu_and_mul(
                intermediate_cache2, intermediate_cache1.view(-1, N)
            )
        elif activation == "swigluoai":
            # alpha = 1.702, limit = 7.0
            torch.ops._C.swigluoai_and_mul(
                intermediate_cache2, intermediate_cache1.view(-1, N)
            )
        # Activation function without multiplication
        elif activation == SILU_NO_MUL:
            intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
        elif activation == GELU_NO_MUL:
            intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))

        else:
            raise ValueError(f"Unsupported FusedMoe activation: {activation}.")

        qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
            A=intermediate_cache2,
            A_scale=a2_scale,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_channel_quant,
            block_shape=block_shape,
        )

        # GEMM2: writes intermediate_cache3, which aliases cache1's storage
        # (safe: cache1 was fully consumed by the activation above). Router
        # weights are applied here only if GEMM1 did not apply them. top_k is
        # 1, presumably because cache2's rows are already one per
        # (token, expert) pair.
        invoke_fused_moe_kernel(
            qintermediate_cache2,
            w2,
            intermediate_cache3,
            a2q_scale,
            w2_scale,
            w2_zp,
            curr_topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            not apply_router_weight_on_input,
            1,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            per_channel_quant=per_channel_quant,
            block_shape=block_shape,
            B_bias=w2_bias,
        )

        # Reduce the per-expert outputs into this chunk's output rows.
        ops.moe_sum(
            intermediate_cache3.view(*intermediate_cache3.size()),
            out_hidden_states[begin_chunk_idx:end_chunk_idx],
        )

    return out_hidden_states
|
|
|
|
|
|
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
    """Fused-MoE expert computation built on the Triton grouped-GEMM kernel.

    ``apply`` runs the complete expert pipeline in the ``Standard`` activation
    format, in this order:
    1. A grouped GEMM of the tokens against ``w1``.
    2. The activation, which maps ``N`` columns to ``N // 2``.
    3. Optional re-quantization of the activation output.
    4. A grouped GEMM against ``w2``, optionally multiplied by the router
       weights.
    5. A reduction of the per-(token, expert) results into ``output``.
    """

    def __init__(
        self,
        quant_config: FusedMoEQuantConfig,
    ):
        super().__init__(quant_config)

    @property
    def activation_formats(
        self,
    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
        """Return the supported activation formats: ``Standard`` for both."""
        return (
            mk.FusedMoEActivationFormat.Standard,
            mk.FusedMoEActivationFormat.Standard,
        )

    def supports_chunking(self) -> bool:
        """This implementation can be driven in chunks; always ``True``."""
        return True

    def supports_expert_map(self) -> bool:
        """``expert_map`` is honoured (passed to ``moe_align_block_size``)."""
        return True

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        """Return a no-op finalize step.

        ``apply`` already multiplies by the router weights (in the second
        GEMM, unless they were applied on the input) and reduces the top-k
        results into ``output`` itself, so nothing is left to do afterwards.
        """
        return TopKWeightAndReduceNoOP()

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """Return the shapes of ``(workspace13, workspace2, output)``.

        ``apply`` carves its intermediate buffers out of these workspaces:

        - ``workspace13`` holds the activation output (``N // 2`` wide).
          According to the note in ``apply``, it may also back ``output``
          (``K`` wide). Hence the width is ``max(N // 2, K)``.
        - ``workspace2`` holds the first GEMM output (``N`` wide) and later
          the second GEMM output (``K`` wide). Hence the width is
          ``max(N, K)``.
        - ``output`` is ``(M, K)``.
        """
        workspace1 = (M, topk, max(N // 2, K))
        workspace2 = (M, topk, max(N, K))
        output = (M, K)
        return (workspace1, workspace2, output)

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        """Run both expert GEMMs and reduce the results into ``output``.

        Args:
            output: Destination for the reduced result. It is written only
                by the final ``ops.moe_sum``.
            hidden_states: 2-D, contiguous input activations. Allowed dtypes:
                fp32, fp16, bf16 or fp8_e4m3fn.
            w1, w2: Expert weights. Their last dimension must be contiguous.
            topk_weights: Router weights, used by the second GEMM unless
                ``apply_router_weight_on_input`` is set.
            topk_ids: Expert id chosen for each (token, k) slot.
            activation: Name of the activation passed to ``self.activation``.
            global_num_experts: Total expert count. ``-1`` means "use the
                expert count derived from the weights".
            expert_map: Optional mapping forwarded to ``moe_align_block_size``.
            a1q_scale: Scale of the already-quantized ``hidden_states``.
            a2_scale: Optional scale used to quantize the activation output.
            workspace13, workspace2: Scratch buffers sized by
                ``workspace_shapes``.
            expert_tokens_meta: Unused by this implementation.
            apply_router_weight_on_input: If True, the router weights were
                applied to the input, so the second GEMM does not multiply
                by them again.

        Raises:
            ValueError: If ``hidden_states`` has an unsupported dtype.
        """
        # Check constraints.
        if self.quant_config.use_int4_w4a16:
            # int4 weights pack two values per element along the K dim.
            assert hidden_states.size(-1) // 2 == w1.size(2), (
                f"Hidden size mismatch {hidden_states.size(-1) // 2} != {w1.size(2)}"
            )
        else:
            assert hidden_states.size(-1) == w1.size(2), (
                f"Hidden size mismatch {hidden_states.size(-1)} != {w1.size(2)}"
            )

        assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
        assert hidden_states.dim() == 2
        assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
        assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
        assert hidden_states.dtype in [
            torch.float32,
            torch.float16,
            torch.bfloat16,
            torch.float8_e4m3fn,
        ]

        E, num_tokens, N, K, top_k_num = self.moe_problem_size(
            hidden_states, w1, w2, topk_ids
        )

        if global_num_experts == -1:
            global_num_experts = E

        config = try_get_optimal_moe_config(
            w1.size(),
            w2.size(),
            top_k_num,
            self.quant_config.config_name(hidden_states.dtype),
            num_tokens,
            block_shape=self.block_shape,
        )

        if hidden_states.dtype == torch.bfloat16:
            compute_type = tl.bfloat16
        elif hidden_states.dtype == torch.float16:
            compute_type = tl.float16
        elif hidden_states.dtype == torch.float32:
            compute_type = tl.float32
        elif hidden_states.dtype == torch.float8_e4m3fn:
            # fp8 activations are accumulated/emitted in bf16.
            compute_type = tl.bfloat16
        else:
            raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")

        # Note that the output tensor might be in workspace1 (workspace13), so
        # nothing may write `output` until intermediate_cache2 is consumed.
        #
        # intermediate_cache1 and intermediate_cache3 deliberately share
        # workspace2. Cache1 is fully consumed by the activation (into
        # cache2, which lives in workspace13) before the second GEMM writes
        # cache3, so the statement order below must be preserved.
        intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N))
        intermediate_cache2 = _resize_cache(
            workspace13, (num_tokens * top_k_num, N // 2)
        )
        intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K))

        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
            topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map
        )

        # First grouped GEMM: hidden_states @ w1 -> (num_tokens, top_k, N).
        invoke_fused_moe_kernel(
            hidden_states,
            w1,
            intermediate_cache1,
            a1q_scale,
            self.w1_scale,
            self.w1_zp,
            None,  # topk_weights
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            False,  # mul_routed_weights
            top_k_num,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
            use_int8_w8a8=self.quant_config.use_int8_w8a8,
            use_int8_w8a16=self.quant_config.use_int8_w8a16,
            use_int4_w4a16=self.quant_config.use_int4_w4a16,
            per_channel_quant=self.per_act_token_quant,
            block_shape=self.block_shape,
            B_bias=self.w1_bias,
        )

        # Activation: N columns in, N // 2 columns out (into workspace13).
        self.activation(
            activation, intermediate_cache2, intermediate_cache1.view(-1, N)
        )

        # Re-quantize the activation output for the second GEMM (may be a
        # pass-through depending on self.quant_dtype).
        qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
            intermediate_cache2,
            a2_scale,
            self.quant_dtype,
            self.per_act_token_quant,
            self.block_shape,
        )

        # Second grouped GEMM: -> (num_tokens, top_k, K). Multiplies by the
        # router weights unless they were already applied on the input.
        invoke_fused_moe_kernel(
            qintermediate_cache2,
            w2,
            intermediate_cache3,
            a2q_scale,
            self.w2_scale,
            self.w2_zp,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            not apply_router_weight_on_input,
            1,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
            use_int8_w8a8=self.quant_config.use_int8_w8a8,
            use_int8_w8a16=self.quant_config.use_int8_w8a16,
            use_int4_w4a16=self.quant_config.use_int4_w4a16,
            per_channel_quant=self.per_act_token_quant,
            block_shape=self.block_shape,
            B_bias=self.w2_bias,
        )

        # Reduce (num_tokens, top_k, K) into output.
        ops.moe_sum(intermediate_cache3, output)
|
|
|
|
|
|
def modular_triton_fused_moe(
    quant_config: FusedMoEQuantConfig,
) -> mk.FusedMoEModularKernel:
    """Assemble a modular fused-MoE kernel that uses the Triton experts.

    The kernel pairs the no-expert-parallelism prepare/finalize stage
    (``MoEPrepareAndFinalizeNoEP``) with ``TritonExperts`` configured by
    ``quant_config``.
    """
    prepare_finalize_stage = MoEPrepareAndFinalizeNoEP()
    experts_stage = TritonExperts(quant_config)
    return mk.FusedMoEModularKernel(prepare_finalize_stage, experts_stage)
|