[GPTOSS][DP/EP][Marlin] Enable GPTOSS Batched DP/EP using Marlin kernels (#25997)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
Varun Sundar Rabindranath
2025-10-16 15:53:11 -04:00
committed by GitHub
parent 2ed8b6b3d0
commit fb0571b077
12 changed files with 1174 additions and 335 deletions

View File

@@ -8,12 +8,77 @@
#include "../cuda_compat.h"
#include "../dispatch_utils.h"
#include "core/math.hpp"
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
namespace vllm {
namespace moe {
namespace batched_moe_align_block_size {
// Note num_threads needs to be 1024 for BlockScan Reduction in the kernel.
static constexpr int32_t num_threads = 1024;
static constexpr int32_t num_blocks = 1;
__global__ void batched_moe_align_block_size_kernel(
int32_t const num_batches, int32_t const max_tokens_per_batch,
int32_t const block_size, int32_t const* __restrict__ batch_num_tokens,
int32_t* __restrict__ sorted_ids, int32_t* __restrict__ block_ids,
int32_t* __restrict__ num_tokens_post_pad) {
// TODO(varun): This is a naive implementation. Could be optimized.
size_t const batch_id = threadIdx.x;
size_t const stride = blockDim.x * gridDim.x;
int32_t const num_blocks_per_batch =
CEILDIV(max_tokens_per_batch, block_size);
int32_t const sorted_ids_size =
num_blocks_per_batch * num_batches * block_size;
int32_t const block_ids_size = sorted_ids_size / block_size;
int32_t const SENTINEL =
num_batches * max_tokens_per_batch; // To denote invalid entries.
// Intialize sorted_ids
for (size_t i = threadIdx.x; i < sorted_ids_size; i += stride) {
sorted_ids[i] = SENTINEL;
}
// Intialize expert_ids with -1
for (size_t i = threadIdx.x; i < block_ids_size; i += stride) {
block_ids[i] = -1;
}
int32_t b_num_tokens = 0;
if (batch_id < num_batches) {
b_num_tokens = batch_num_tokens[batch_id];
}
int32_t const ceil_b_num_tokens =
CEILDIV(b_num_tokens, block_size) * block_size;
// Compute prefix sum over token counts per expert
using BlockScan = cub::BlockScan<int32_t, 1024>;
__shared__ typename BlockScan::TempStorage temp_storage;
int cumsum_val;
BlockScan(temp_storage).ExclusiveSum(ceil_b_num_tokens, cumsum_val);
__syncthreads();
bool const is_last_batch = batch_id == (num_batches - 1);
if (is_last_batch) {
*num_tokens_post_pad = cumsum_val + ceil_b_num_tokens;
}
if (batch_id < num_batches) {
int32_t const batch_offset = batch_id * max_tokens_per_batch;
for (size_t i = 0; i < b_num_tokens; ++i) {
sorted_ids[cumsum_val + i] = batch_offset + i;
}
int32_t const block_start = cumsum_val / block_size;
int32_t const num_blocks = ceil_b_num_tokens / block_size;
for (size_t i = 0; i < num_blocks; ++i) {
block_ids[block_start + i] = batch_id;
}
}
}
} // namespace batched_moe_align_block_size
template <typename scalar_t>
__global__ void moe_align_block_size_kernel(
const scalar_t* __restrict__ topk_ids,
@@ -280,6 +345,33 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
});
}
void batched_moe_align_block_size(int64_t max_tokens_per_batch,
int64_t block_size,
torch::Tensor const& batch_num_tokens,
torch::Tensor sorted_ids,
torch::Tensor batch_ids,
torch::Tensor num_tokens_post_pad) {
namespace batched_kernel = vllm::moe::batched_moe_align_block_size;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int32_t const B = batch_num_tokens.size(0);
int32_t const num_blocks_per_batch =
round_to_next_multiple_of(max_tokens_per_batch, block_size) / block_size;
int32_t const num_blocks = num_blocks_per_batch * B;
int64_t const sorted_ids_size = num_blocks * block_size;
TORCH_CHECK(sorted_ids.size(0) == sorted_ids_size);
TORCH_CHECK(batch_ids.size(0) == sorted_ids_size / block_size);
TORCH_CHECK(num_tokens_post_pad.size(0) == 1);
TORCH_CHECK(B <= batched_kernel::num_threads);
batched_kernel::batched_moe_align_block_size_kernel<<<
batched_kernel::num_blocks, batched_kernel::num_threads, 0, stream>>>(
B, max_tokens_per_batch, block_size, batch_num_tokens.data_ptr<int32_t>(),
sorted_ids.data_ptr<int32_t>(), batch_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>());
}
void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
torch::Tensor& output) // [num_tokens, hidden_size]
{

View File

@@ -12,6 +12,14 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad);
void batched_moe_align_block_size(int64_t max_tokens_per_batch,
int64_t block_size,
torch::Tensor const& expert_num_tokens,
torch::Tensor sorted_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad);
#ifndef USE_ROCM
torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
torch::Tensor b_qweight, torch::Tensor b_scales,

View File

@@ -22,6 +22,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
" Tensor! num_tokens_post_pad) -> ()");
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
// Aligning the number of tokens to be processed by each expert such
// that it is divisible by the block size, but for the batched case.
m.def(
"batched_moe_align_block_size(int max_tokens_per_batch,"
" int block_size, Tensor expert_num_tokens,"
" Tensor! sorted_token_ids,"
" Tensor! experts_ids,"
" Tensor! num_tokens_post_pad) -> ()");
m.impl("batched_moe_align_block_size", torch::kCUDA,
&batched_moe_align_block_size);
#ifndef USE_ROCM
m.def(
"moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "

View File

@@ -92,8 +92,8 @@ To be used with a particular `FusedMoEPrepareAndFinalize` sub-class, MoE kernels
| flashinfer | standard | nvfp4,</br>fp8 | T | <sup>5</sup> | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],</br>[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] |
| gpt oss triton | standard | N/A | N/A | <sup>5</sup> | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],</br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
| deep gemm+triton<sup>2</sup> | standard,</br>batched | all<sup>1</sup> | G(128),A,T | silu, gelu | <sup>6</sup> | Y | [`TritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe.TritonOrDeepGemmExperts],</br>[`BatchedTritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe.BatchedTritonOrDeepGemmExperts] |
| marlin | standard | <sup>3</sup> | <sup>3</sup> | silu,</br>swigluoai | Y | N | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe] |
| marlin experts | standard | N/A | N/A | silu,</br>swigluoai | Y | Y | [`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts] |
| marlin | standard | <sup>3</sup> | <sup>3</sup> | silu,</br>swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],</br>[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],</br>[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
| marlin experts | standard,</br>batched | N/A | N/A | silu,</br>swigluoai | Y | Y | [`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],</br>[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
| trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
| pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
| iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] |
@@ -115,6 +115,6 @@ The following table shows "families" of modular kernels that are intended to wor
| backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses |
|----------------------------------|------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------|
| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8`, </br>`MarlinExperts` |
| deepep_low_latency,</br>pplx | `DeepEPLLPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`BatchedTritonOrDeepGemmExperts`,</br>`CutlassBatchedExpertsFp8`|
| flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` |
| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8`, </br>`MarlinExperts` |
| deepep_low_latency,</br>pplx | `DeepEPLLPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`BatchedTritonOrDeepGemmExperts`,</br>`CutlassBatchedExpertsFp8`,</br>`BatchedMarlinExperts`|
| flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` |

View File

@@ -7,6 +7,8 @@ Run `pytest tests/kernels/test_moe.py`.
import functools
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
import pytest
import torch
@@ -26,7 +28,10 @@ from vllm.model_executor.layers.fused_moe.config import (
int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
batched_fused_marlin_moe,
fused_marlin_moe,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk,
modular_triton_fused_moe,
@@ -564,6 +569,105 @@ def marlin_moe_generate_valid_test_cases():
return cases
@dataclass
class MarlinMoEWeightData:
w_ref: torch.Tensor
qweight: torch.Tensor
scales: torch.Tensor
global_scale: torch.Tensor | None
g_idx: torch.Tensor | None
zeros: torch.Tensor | None
sort_indices: torch.Tensor | None
marlin_bias: torch.Tensor | None
@staticmethod
def make(
w: torch.Tensor,
quant_type: ScalarType,
group_size: int,
act_order: bool | None = None,
bias: torch.Tensor | None = None,
) -> "MarlinMoEWeightData":
assert w.ndim == 3
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
k = w.shape[-1]
w_ref_l: list[torch.Tensor] = []
qweight_l: list[torch.Tensor] = []
scales_l: list[torch.Tensor] = []
global_scale_l: list[torch.Tensor] = []
zeros_l: list[torch.Tensor] = []
g_idx_l: list[torch.Tensor] = []
sort_indices_l: list[torch.Tensor] = []
bias_l: list[torch.Tensor] = []
for i in range(w.shape[0]):
if quant_type == scalar_types.float4_e2m1f:
if group_size == 16:
w_ref, qweight, scales, global_scale = (
rand_marlin_weight_nvfp4_like(w[i], group_size)
)
else:
w_ref, qweight, scales = rand_marlin_weight_mxfp4_like(
w[i], group_size
)
global_scale = None
w_ref_l.append(w_ref.T)
qweight_l.append(qweight)
scales_l.append(scales)
if global_scale is not None:
global_scale_l.append(global_scale)
elif quant_type == scalar_types.float8_e4m3fn:
w_ref, qweight, scales = marlin_quant_fp8_torch(w[i], group_size)
w_ref_l.append(w_ref.T)
qweight_l.append(qweight)
scales_l.append(scales)
elif has_zp:
w_ref, qweight, scales, zeros = awq_marlin_quantize(
w[i].transpose(1, 0), quant_type, group_size
)
w_ref_l.append(w_ref.T)
qweight_l.append(qweight)
scales_l.append(scales)
zeros_l.append(zeros)
else:
test_perm = torch.randperm(k)
w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
)
w_ref_l.append(w_ref.T)
qweight_l.append(qweight)
scales_l.append(scales)
g_idx_l.append(g_idx)
sort_indices_l.append(sort_indices)
if bias is not None:
bias_l.append(marlin_permute_bias(bias[i]))
w_ref = stack_and_dev(w_ref_l)
qweight = stack_and_dev(qweight_l).contiguous()
scales = stack_and_dev(scales_l)
global_scale = stack_and_dev(global_scale_l) if global_scale_l else None
g_idx = stack_and_dev(g_idx_l) if g_idx_l else None
zeros = stack_and_dev(zeros_l) if zeros_l else None
sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None
marlin_bias = stack_and_dev(bias_l) if bias_l else None
return MarlinMoEWeightData(
w_ref=w_ref,
qweight=qweight,
scales=scales,
global_scale=global_scale,
g_idx=g_idx,
zeros=zeros,
sort_indices=sort_indices,
marlin_bias=marlin_bias,
)
@pytest.mark.flaky(reruns=2)
@pytest.mark.parametrize(
("m, n, k, e, topk, ep_size, dtype, group_size,act_order, quant_type, is_k_full"),
@@ -584,7 +688,6 @@ def test_fused_marlin_moe(
is_k_full: bool,
):
torch.cuda.manual_seed(0)
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
@@ -600,152 +703,44 @@ def test_fused_marlin_moe(
else:
e_map = None
w_ref1_l = []
qweight1_l = []
scales1_l = []
global_scale1_l = []
zeros1_l = []
g_idx1_l = []
sort_indices1_l = []
w1_data = MarlinMoEWeightData.make(
w=w1, quant_type=quant_type, group_size=group_size, act_order=act_order
)
for i in range(w1.shape[0]):
if quant_type == scalar_types.float4_e2m1f:
if group_size == 16:
w_ref1, qweight1, scales1, global_scale1 = (
rand_marlin_weight_nvfp4_like(w1[i], group_size)
)
else:
w_ref1, qweight1, scales1 = rand_marlin_weight_mxfp4_like(
w1[i], group_size
)
global_scale1 = None
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
if global_scale1 is not None:
global_scale1_l.append(global_scale1)
elif quant_type == scalar_types.float8_e4m3fn:
w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(w1[i], group_size)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
elif has_zp:
w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size
)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
zeros1_l.append(zeros1)
else:
test_perm = torch.randperm(k)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
g_idx1_l.append(g_idx1)
sort_indices1_l.append(sort_indices1)
w_ref1 = stack_and_dev(w_ref1_l)
qweight1 = stack_and_dev(qweight1_l).contiguous()
scales1 = stack_and_dev(scales1_l)
global_scale1 = stack_and_dev(global_scale1_l) if global_scale1_l else None
g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None
sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
w_ref2_l = []
qweight2_l = []
scales2_l = []
global_scale2_l = []
zeros2_l = []
g_idx2_l = []
sort_indices2_l = []
for i in range(w2.shape[0]):
if quant_type == scalar_types.float4_e2m1f:
if group_size == 16:
w_ref2, qweight2, scales2, global_scale2 = (
rand_marlin_weight_nvfp4_like(w2[i], group_size)
)
else:
w_ref2, qweight2, scales2 = rand_marlin_weight_mxfp4_like(
w2[i], group_size
)
global_scale2 = None
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
if global_scale2 is not None:
global_scale2_l.append(global_scale2)
elif quant_type == scalar_types.float8_e4m3fn:
w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(w2[i], group_size)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
elif has_zp:
w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size
)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
zeros2_l.append(zeros2)
else:
test_perm = torch.randperm(n)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
g_idx2_l.append(g_idx2)
sort_indices2_l.append(sort_indices2)
w_ref2 = stack_and_dev(w_ref2_l)
qweight2 = stack_and_dev(qweight2_l).contiguous()
scales2 = stack_and_dev(scales2_l)
global_scale2 = stack_and_dev(global_scale2_l) if global_scale2_l else None
g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None
sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
w2_data = MarlinMoEWeightData.make(
w=w2, quant_type=quant_type, group_size=group_size, act_order=act_order
)
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
with set_current_vllm_config(vllm_config):
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map)
torch_output = torch_moe(
a, w1_data.w_ref, w2_data.w_ref, score, topk, expert_map=e_map
)
marlin_output = fused_marlin_moe(
a,
qweight1,
qweight2,
w1_data.qweight,
w2_data.qweight,
None,
None,
scales1,
scales2,
w1_data.scales,
w2_data.scales,
score,
topk_weights,
topk_ids,
global_num_experts=e,
expert_map=e_map,
global_scale1=global_scale1,
global_scale2=global_scale2,
g_idx1=g_idx1,
g_idx2=g_idx2,
sort_indices1=sort_indices1,
sort_indices2=sort_indices2,
w1_zeros=zeros1,
w2_zeros=zeros2,
global_scale1=w1_data.global_scale,
global_scale2=w2_data.global_scale,
g_idx1=w1_data.g_idx,
g_idx2=w2_data.g_idx,
sort_indices1=w1_data.sort_indices,
sort_indices2=w2_data.sort_indices,
w1_zeros=w1_data.zeros,
w2_zeros=w2_data.zeros,
quant_type_id=quant_type.id,
is_k_full=is_k_full,
)
@@ -773,92 +768,52 @@ def test_fused_marlin_moe_with_bias(m):
b_bias1 = torch.randn((e, 2 * n), device="cuda", dtype=dtype) / 10
b_bias2 = torch.randn((e, k), device="cuda", dtype=dtype) / 10
b_bias1_l = []
w_ref1_l = []
qweight1_l = []
scales1_l = []
g_idx1_l = []
sort_indices1_l = []
w1_data = MarlinMoEWeightData.make(
w=w1,
quant_type=quant_type,
group_size=group_size,
act_order=act_order,
bias=b_bias1,
)
for i in range(w1.shape[0]):
test_perm = torch.randperm(k)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
g_idx1_l.append(g_idx1)
sort_indices1_l.append(sort_indices1)
b_bias1_l.append(marlin_permute_bias(b_bias1[i]))
w_ref1 = stack_and_dev(w_ref1_l)
qweight1 = stack_and_dev(qweight1_l).contiguous()
scales1 = stack_and_dev(scales1_l)
global_scale1 = None
g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
zeros1 = None
sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
marlin_bias1 = stack_and_dev(b_bias1_l) if b_bias1_l else None
b_bias2_l = []
w_ref2_l = []
qweight2_l = []
scales2_l = []
g_idx2_l = []
sort_indices2_l = []
for i in range(w2.shape[0]):
test_perm = torch.randperm(n)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
g_idx2_l.append(g_idx2)
sort_indices2_l.append(sort_indices2)
b_bias2_l.append(marlin_permute_bias(b_bias2[i]))
w_ref2 = stack_and_dev(w_ref2_l)
qweight2 = stack_and_dev(qweight2_l).contiguous()
scales2 = stack_and_dev(scales2_l)
global_scale2 = None
g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
zeros2 = None
sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
marlin_bias2 = stack_and_dev(b_bias2_l) if b_bias2_l else None
w2_data = MarlinMoEWeightData.make(
w=w2,
quant_type=quant_type,
group_size=group_size,
act_order=act_order,
bias=b_bias2,
)
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
with set_current_vllm_config(vllm_config):
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1, b_bias2)
torch_output = torch_moe(
a, w1_data.w_ref, w2_data.w_ref, score, topk, b_bias1, b_bias2
)
marlin_output = fused_marlin_moe(
a,
qweight1,
qweight2,
marlin_bias1,
marlin_bias2,
scales1,
scales2,
w1_data.qweight,
w2_data.qweight,
w1_data.marlin_bias,
w2_data.marlin_bias,
w1_data.scales,
w2_data.scales,
score,
topk_weights,
topk_ids,
global_num_experts=e,
expert_map=None,
global_scale1=global_scale1,
global_scale2=global_scale2,
g_idx1=g_idx1,
g_idx2=g_idx2,
sort_indices1=sort_indices1,
sort_indices2=sort_indices2,
w1_zeros=zeros1,
w2_zeros=zeros2,
global_scale1=w1_data.global_scale,
global_scale2=w2_data.global_scale,
g_idx1=w1_data.g_idx,
g_idx2=w2_data.g_idx,
sort_indices1=w1_data.sort_indices,
sort_indices2=w2_data.sort_indices,
w1_zeros=w1_data.zeros,
w2_zeros=w2_data.zeros,
quant_type_id=quant_type.id,
is_k_full=is_k_full,
)
@@ -895,6 +850,41 @@ def test_moe_align_block_size_opcheck():
)
def test_batched_moe_align_block_size_opcheck():
max_tokens_per_batch = 512
num_experts = 4
block_size = 16
expert_num_tokens = torch.randint(
low=0,
high=max_tokens_per_batch,
size=(num_experts,),
dtype=torch.int32,
device="cuda",
)
max_num_tokens_padded = num_experts * max(max_tokens_per_batch, block_size)
sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda")
assert max_num_tokens_padded % block_size == 0
max_num_m_blocks = max_num_tokens_padded // block_size
expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda")
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device="cuda")
opcheck(
torch.ops._moe_C.batched_moe_align_block_size,
(
max_tokens_per_batch,
block_size,
expert_num_tokens,
sorted_ids,
expert_ids,
num_tokens_post_pad,
),
)
@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("k", [128, 511, 1024])
@@ -979,3 +969,171 @@ def test_cpu_fused_moe_basic(m, n, k, e, topk, dtype, with_bias, activation):
else:
atol = 5e-2
torch.testing.assert_close(out, ref, atol=atol, rtol=0)
@pytest.mark.parametrize("m", [16, 32, 64])
@pytest.mark.parametrize("n", [128])
@pytest.mark.parametrize("k", [128])
@pytest.mark.parametrize("e", [8, 12, 16, 32])
@pytest.mark.parametrize("topk", [2, 4])
@pytest.mark.parametrize("max_tokens_per_batch", [16, 32, 64])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_batched_fused_marlin_moe(
m: int, n: int, k: int, e: int, topk: int, max_tokens_per_batch: int
):
print(
f"testing m={m}, n={n}, k={k}, e={e}, "
f"topk={topk}, "
f"max_tokens_per_batch={max_tokens_per_batch}"
)
torch.cuda.manual_seed(0)
dtype = torch.bfloat16
quant_dtype = scalar_types.float4_e2m1f
group_size = 32
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20
w1_data = MarlinMoEWeightData.make(
w=w1, quant_type=quant_dtype, group_size=group_size, act_order=None
)
w2_data = MarlinMoEWeightData.make(
w=w2, quant_type=quant_dtype, group_size=group_size, act_order=None
)
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
class BatchedRun:
@staticmethod
def _make_expert_num_tokens_cpu(
e: int, # num_experts
topk_ids_cpu: torch.Tensor,
) -> torch.Tensor:
expert_num_tokens_cpu = torch.zeros((e,), dtype=torch.int32, device="cpu")
for topk_id in torch.flatten(topk_ids_cpu):
expert_num_tokens_cpu[topk_id] += 1
return expert_num_tokens_cpu
def __init__(
self,
max_tokens_per_batch: int,
num_experts: int,
_topk_ids: torch.Tensor,
_topk_weights: torch.Tensor,
):
self.max_tokens_per_batch = max_tokens_per_batch
self.e = num_experts
self.topk_ids_cpu = _topk_ids.to("cpu")
self.topk_weights_cpu = _topk_weights.to("cpu")
self.expert_num_tokens_cpu = self._make_expert_num_tokens_cpu(
self.e, self.topk_ids_cpu
)
def is_valid(self):
"""
Return True only if the input can be represented in a Batched
format.
"""
return torch.all(self.expert_num_tokens_cpu <= self.max_tokens_per_batch)
def _scatter(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states_cpu = hidden_states.to("cpu")
K = hidden_states_cpu.size(1)
batched_hidden_states_cpu = torch.empty(
(e, max_tokens_per_batch, K),
dtype=hidden_states_cpu.dtype,
device="cpu",
)
counter_cpu = torch.zeros_like(self.expert_num_tokens_cpu)
for t_idx, token in enumerate(hidden_states_cpu):
for topk_id in self.topk_ids_cpu[t_idx]:
pos_in_batch = counter_cpu[topk_id]
batched_hidden_states_cpu[topk_id, pos_in_batch] = token
counter_cpu[topk_id] += 1
assert torch.allclose(counter_cpu, self.expert_num_tokens_cpu)
return batched_hidden_states_cpu.to("cuda")
def _gather(
self, batched_outputs: torch.Tensor, gather_outputs: torch.Tensor
) -> torch.Tensor:
batched_outputs_cpu = batched_outputs.to("cpu")
gather_outputs_cpu = torch.zeros_like(gather_outputs)
counter_cpu = torch.zeros((e,), device="cpu", dtype=torch.int32)
md = gather_outputs_cpu.size(0)
for t_idx in range(md):
token = None
for topk_id, topk_weight in zip(
self.topk_ids_cpu[t_idx], self.topk_weights_cpu[t_idx]
):
pos_in_batch = counter_cpu[topk_id]
t = batched_outputs_cpu[topk_id, pos_in_batch] * topk_weight
if token is None:
token = t
else:
token += t
counter_cpu[topk_id] += 1
assert token is not None
gather_outputs_cpu[t_idx] = token
gather_outputs.copy_(gather_outputs_cpu)
return gather_outputs
def run(
self, hidden_states: torch.Tensor, fused_marlin_moe_kwargs: dict[Any, Any]
) -> torch.Tensor:
assert hidden_states.ndim == 2
assert self.is_valid()
batched_hidden_states = self._scatter(hidden_states)
kwargs = fused_marlin_moe_kwargs | {
"hidden_states": batched_hidden_states,
"expert_num_tokens": self.expert_num_tokens_cpu.to("cuda"),
}
batched_outputs = batched_fused_marlin_moe(**kwargs)
output = torch.zeros_like(hidden_states)
output = self._gather(batched_outputs, output)
return output
kwargs = {
"w1": w1_data.qweight,
"w2": w2_data.qweight,
"bias1": None,
"bias2": None,
"w1_scale": w1_data.scales,
"w2_scale": w2_data.scales,
"gating_output": score,
"global_num_experts": e,
"expert_map": None,
"global_scale1": w1_data.global_scale,
"global_scale2": w2_data.global_scale,
"g_idx1": w1_data.g_idx,
"g_idx2": w2_data.g_idx,
"sort_indices1": w1_data.sort_indices,
"sort_indices2": w2_data.sort_indices,
"w1_zeros": w1_data.zeros,
"w2_zeros": w2_data.zeros,
"quant_type_id": quant_dtype.id,
"is_k_full": True,
}
# Reference
fused_marlin_moe_kwargs = kwargs | {
"hidden_states": a,
"topk_ids": topk_ids,
"topk_weights": topk_weights,
}
ref_marlin_output = fused_marlin_moe(**fused_marlin_moe_kwargs)
# Batched
br = BatchedRun(max_tokens_per_batch, e, topk_ids, topk_weights)
if not br.is_valid():
pytest.skip("Cannot represent data in Batched Format.")
marlin_output = br.run(a, kwargs)
torch.testing.assert_close(marlin_output, ref_marlin_output, atol=1e-3, rtol=0)

View File

@@ -9,6 +9,7 @@ import pytest
import torch
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
batched_moe_align_block_size,
moe_align_block_size,
)
from vllm.platforms import current_platform
@@ -300,3 +301,96 @@ def test_moe_align_block_size_deterministic():
assert torch.equal(results[0][2], results[i][2]), (
"num_tokens should be deterministic"
)
@pytest.mark.parametrize("max_tokens_per_batch", [13, 16, 512])
@pytest.mark.parametrize("num_experts", [8, 16, 32, 64])
@pytest.mark.parametrize("block_size", [8, 16, 32, 64])
@pytest.mark.parametrize("simulate_empty_batches", [False, True])
def test_batched_moe_align_block_size(
max_tokens_per_batch: int,
num_experts: int,
block_size: int,
simulate_empty_batches: bool,
):
def ref_outputs(
expert_num_tokens: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
E = expert_num_tokens.size(0)
# Round up so each batch can be split to blocks evenly.
Msum = round_up(max_tokens_per_batch, block_size) * E
ref_sorted_ids = torch.empty((Msum,), dtype=torch.int32)
ref_expert_ids = torch.empty((Msum // block_size,), dtype=torch.int32)
ref_num_tokens_post_pad = torch.empty((1,), dtype=torch.int32)
# Intialize
sentinel = E * max_tokens_per_batch
ref_sorted_ids.fill_(sentinel)
ref_expert_ids.fill_(-1)
# Fill ref_sorted_ids
i = 0
for expert_id, expert_nt in enumerate(expert_num_tokens):
token_offset = expert_id * max_tokens_per_batch
for j in range(expert_nt):
ref_sorted_ids[i] = token_offset + j
i += 1
# round up i to the next block_size
i = round_up(i, block_size)
ref_num_tokens_post_pad[0] = i
# Fill expert_ids
nt_ceil_sum = 0
for expert_id, expert_nt in enumerate(expert_num_tokens):
expert_ids_offset = nt_ceil_sum // block_size
ceil_expert_nt = round_up(int(expert_nt.item()), block_size)
num_blocks = ceil_expert_nt // block_size
for x in range(num_blocks):
ref_expert_ids[expert_ids_offset + x] = expert_id
nt_ceil_sum += ceil_expert_nt
return (
ref_sorted_ids.to("cuda"),
ref_expert_ids.to("cuda"),
ref_num_tokens_post_pad.to("cuda"),
)
# Compute expert_num_tokens
expert_num_tokens = torch.randint(
low=0,
high=max_tokens_per_batch,
size=(num_experts,),
device="cpu",
dtype=torch.int32,
)
if simulate_empty_batches:
# mark half the batches to have 0 tokens
zero_batches = torch.randperm(num_experts)[: num_experts // 2]
expert_num_tokens[zero_batches] = 0
# ref outputs
ref_sorted_ids, ref_expert_ids, ref_num_tokens_post_pad = ref_outputs(
expert_num_tokens
)
# outputs
sorted_ids, expert_ids, num_tokens_post_pad = batched_moe_align_block_size(
max_tokens_per_batch, block_size, expert_num_tokens.to("cuda")
)
assert ref_sorted_ids.size() == sorted_ids.size(), (
f"{ref_sorted_ids.size()} vs {sorted_ids.size()}"
)
assert ref_expert_ids.size() == expert_ids.size(), (
f"{ref_expert_ids.size()} vs {expert_ids.size()}"
)
assert ref_num_tokens_post_pad.size() == num_tokens_post_pad.size(), (
f"{ref_num_tokens_post_pad.size()} vs {num_tokens_post_pad.size()}"
)
torch.testing.assert_close(ref_sorted_ids, sorted_ids, atol=0, rtol=0)
torch.testing.assert_close(ref_expert_ids, expert_ids, atol=0, rtol=0)
torch.testing.assert_close(
ref_num_tokens_post_pad, num_tokens_post_pad, atol=0, rtol=0
)

View File

@@ -1789,6 +1789,24 @@ def moe_align_block_size(
)
def batched_moe_align_block_size(
max_tokens_per_batch: int,
block_size: int,
expert_num_tokens: torch.Tensor,
sorted_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor,
) -> None:
torch.ops._moe_C.batched_moe_align_block_size(
max_tokens_per_batch,
block_size,
expert_num_tokens,
sorted_ids,
expert_ids,
num_tokens_post_pad,
)
def moe_wna16_gemm(
input: torch.Tensor,
output: torch.Tensor,

View File

@@ -50,7 +50,31 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# DeepEP low-latency kernels are compiled only for certain
# specific hidden sizes.
SUPPORTED_HIDDEN_SIZES = [2048, 2560, 4096, 5120, 6144, 7168]
# NOTE: Keep this list sorted, maybe_roundup_layer_hidden_size depends
# on it.
SUPPORTED_HIDDEN_SIZES = [2048, 2560, 4096, 5120, 6144, 7168, 8192]
@staticmethod
def maybe_roundup_layer_hidden_size(hidden_size: int) -> int:
# Round up hidden size to the closest supported hidden size.
_supported_hs = DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES
# Check sorted
num_supported_hs = len(_supported_hs)
assert all(
[
_supported_hs[i] < _supported_hs[i + 1]
for i in range(num_supported_hs - 1)
]
)
for x in _supported_hs:
if x >= hidden_size:
return x
raise ValueError(
f"Hidden Size {hidden_size} is greater than the "
f"maximum supported hidden size {_supported_hs[-1]}"
)
def __init__(
self,

View File

@@ -3,13 +3,16 @@
"""Fused MoE utilities for GPTQ."""
import torch
from typing_extensions import override
import vllm._custom_ops as ops
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
batched_moe_align_block_size,
moe_align_block_size,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache, disable_inplace
@@ -21,6 +24,153 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
from vllm.scalar_type import ScalarType, scalar_types
def _fused_marlin_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
bias1: torch.Tensor | None,
bias2: torch.Tensor | None,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
num_topk: int,
quant_type: ScalarType,
apply_router_weight_on_input: bool,
activation: str,
expert_map: torch.Tensor | None,
block_size_m: int,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
global_scale1: torch.Tensor | None = None,
global_scale2: torch.Tensor | None = None,
g_idx1: torch.Tensor | None = None,
g_idx2: torch.Tensor | None = None,
sort_indices1: torch.Tensor | None = None,
sort_indices2: torch.Tensor | None = None,
w1_zeros: torch.Tensor | None = None,
w2_zeros: torch.Tensor | None = None,
workspace: torch.Tensor | None = None,
intermediate_cache13: torch.Tensor | None = None,
intermediate_cache2: torch.Tensor | None = None,
output: torch.Tensor | None = None,
is_k_full: bool = True,
) -> torch.Tensor:
assert hidden_states.ndim == 2
M, K = hidden_states.size()
N = marlin_moe_intermediate_size(w1, w2)
if workspace is None:
workspace = marlin_make_workspace_new(hidden_states.device, 4)
if intermediate_cache13 is None:
intermediate_cache13 = torch.empty(
(M * num_topk * max(2 * N, K),),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
if intermediate_cache2 is None:
intermediate_cache2 = torch.empty(
(M * num_topk, N),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache1 = _resize_cache(intermediate_cache13, (M * num_topk, 2 * N))
intermediate_cache3 = _resize_cache(intermediate_cache13, (M * num_topk, K))
intermediate_cache2 = _resize_cache(intermediate_cache2, (M * num_topk, N))
maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype)
use_atomic_add = (
hidden_states.dtype == torch.half
or torch.cuda.get_device_capability(hidden_states.device)[0] >= 9
)
intermediate_cache1 = ops.moe_wna16_marlin_gemm(
hidden_states,
intermediate_cache1,
w1,
bias1,
w1_scale,
global_scale1,
w1_zeros,
g_idx1,
sort_indices1,
workspace,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
topk_weights,
moe_block_size=block_size_m,
top_k=num_topk,
mul_topk_weights=apply_router_weight_on_input,
is_ep=expert_map is not None,
b_q_type=quant_type,
size_m=M,
size_n=2 * N,
size_k=K,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_fp32_reduce=True,
is_zp_float=False,
)
if activation == "silu":
torch.ops._C.silu_and_mul(
intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
)
elif activation == "swigluoai":
# alpha = 1.702, limit = 7.0
torch.ops._C.swigluoai_and_mul(
intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
)
else:
raise ValueError(
f"Unsupported activation: {activation}. "
"Only silu and swigluoai activations are supported."
)
if output is None:
output = intermediate_cache3
if expert_map is not None:
output.zero_()
output = ops.moe_wna16_marlin_gemm(
intermediate_cache2,
output,
w2,
bias2,
w2_scale,
global_scale2,
w2_zeros,
g_idx2,
sort_indices2,
workspace,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
topk_weights,
moe_block_size=block_size_m,
top_k=1,
mul_topk_weights=not apply_router_weight_on_input,
is_ep=expert_map is not None,
b_q_type=quant_type,
size_m=M * num_topk,
size_n=K,
size_k=N,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_fp32_reduce=True,
is_zp_float=False,
)
return output
def fused_marlin_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
@@ -62,23 +212,27 @@ def fused_marlin_moe(
- w2 (torch.Tensor): The second set of expert weights.
- w1_scale (torch.Tensor): Scale to be used for w1.
- w2_scale (torch.Tensor): Scale to be used for w2.
- gating_output (Optional[torch.Tensor]): The output of the gating
- gating_output (torch.Tensor|None): The output of the gating
operation (before softmax).
- g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
- g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
- sort_indices1 (Optional[torch.Tensor]): The first act_order input
- g_idx1 (torch.Tensor|None): The first set of act_order indices.
- g_idx2 (torch.Tensor|None): The second set of act_order indices.
- sort_indices1 (torch.Tensor|None): The first act_order input
permutation.
- sort_indices2 (Optional[torch.Tensor]): The second act_order input
- sort_indices2 (torch.Tensor|None): The second act_order input
permutation.
- topk_weights (torch.Tensor): Top-k weights.
- topk_ids (torch.Tensor): Indices of topk-k elements.
- w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
- w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
- w1_zeros (torch.Tensor|None): Optional zero points to be used for w1.
- w2_zeros (torch.Tensor|None): Optional zero points to be used for w2.
- num_bits (bool): The number of bits in expert weights quantization.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
if inplace:
assert output is None, "Conflicting request"
quant_type = ScalarType.from_id(quant_type_id)
assert quant_type in [
scalar_types.uint4,
@@ -95,15 +249,15 @@ def fused_marlin_moe(
]
num_bits = 4 if quant_type in bit4_scalar_types else 8
M, K = hidden_states.size()
E = w1.size(0)
topk = topk_ids.size(1)
# Check constraints.
if gating_output is not None:
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch"
)
assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1"
assert hidden_states.shape[1] == w2.shape[2] // (num_bits // 2), (
"Hidden size mismatch w2"
)
assert gating_output.size(0) == M, "Number of tokens mismatch"
assert w1.size(1) * 16 == K, "Hidden size mismatch w1"
assert w2.size(2) // (num_bits // 2) == K, "Hidden size mismatch w2"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
@@ -111,11 +265,6 @@ def fused_marlin_moe(
assert num_bits in [4, 8]
assert topk_weights.dtype == torch.float32
M, K = hidden_states.shape
E = w1.shape[0]
N = marlin_moe_intermediate_size(w1, w2)
topk = topk_ids.shape[1]
# M block size selection logic
# TODO: tune this further for specific models
for block_size_m in [8, 16, 32, 48, 64]:
@@ -128,107 +277,38 @@ def fused_marlin_moe(
topk_ids, block_size_m, global_num_experts, expert_map
)
if workspace is None:
workspace = marlin_make_workspace_new(hidden_states.device, 4)
if intermediate_cache2 is None:
intermediate_cache2 = torch.empty(
(M * topk, N),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
if intermediate_cache13 is None:
intermediate_cache13 = torch.empty(
(M * topk * max(2 * N, K),),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache1 = _resize_cache(intermediate_cache13, (M * topk, 2 * N))
intermediate_cache3 = _resize_cache(intermediate_cache13, (M * topk, K))
intermediate_cache2 = _resize_cache(intermediate_cache2, (M * topk, N))
maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype)
use_atomic_add = (
hidden_states.dtype == torch.half
or torch.cuda.get_device_capability(hidden_states.device)[0] >= 9
)
intermediate_cache1 = ops.moe_wna16_marlin_gemm(
hidden_states,
intermediate_cache1,
w1,
bias1,
w1_scale,
global_scale1,
w1_zeros,
g_idx1,
sort_indices1,
workspace,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
topk_weights,
moe_block_size=block_size_m,
top_k=topk,
mul_topk_weights=apply_router_weight_on_input,
is_ep=expert_map is not None,
b_q_type=quant_type,
size_m=M,
size_n=2 * N,
size_k=K,
assert activation is not None
moe_output = _fused_marlin_moe(
hidden_states=hidden_states,
w1=w1,
w2=w2,
bias1=bias1,
bias2=bias2,
w1_scale=w1_scale,
w2_scale=w2_scale,
topk_weights=topk_weights,
num_topk=topk,
quant_type=quant_type,
apply_router_weight_on_input=apply_router_weight_on_input,
activation=activation,
expert_map=expert_map,
block_size_m=block_size_m,
sorted_token_ids=sorted_token_ids,
expert_ids=expert_ids,
num_tokens_post_padded=num_tokens_post_padded,
global_scale1=global_scale1,
global_scale2=global_scale2,
g_idx1=g_idx1,
g_idx2=g_idx2,
sort_indices1=sort_indices1,
sort_indices2=sort_indices2,
w1_zeros=w1_zeros,
w2_zeros=w2_zeros,
workspace=workspace,
intermediate_cache13=intermediate_cache13,
intermediate_cache2=intermediate_cache2,
output=None,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_fp32_reduce=True,
is_zp_float=False,
)
if activation == "silu":
torch.ops._C.silu_and_mul(
intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
)
elif activation == "swigluoai":
# alpha = 1.702, limit = 7.0
torch.ops._C.swigluoai_and_mul(
intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
)
else:
raise ValueError(
f"Unsupported activation: {activation}. "
"Only silu and swigluoai activations are supported."
)
if expert_map is not None:
intermediate_cache3.zero_()
intermediate_cache3 = ops.moe_wna16_marlin_gemm(
intermediate_cache2,
intermediate_cache3,
w2,
bias2,
w2_scale,
global_scale2,
w2_zeros,
g_idx2,
sort_indices2,
workspace,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
topk_weights,
moe_block_size=block_size_m,
top_k=1,
mul_topk_weights=not apply_router_weight_on_input,
is_ep=expert_map is not None,
b_q_type=quant_type,
size_m=M * topk,
size_n=K,
size_k=N,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_fp32_reduce=True,
is_zp_float=False,
).view(-1, topk, K)
if output is None:
@@ -237,16 +317,173 @@ def fused_marlin_moe(
else:
output = torch.empty_like(hidden_states)
return torch.sum(intermediate_cache3.view(-1, topk, K), dim=1, out=output)
return torch.sum(moe_output.view(-1, topk, K), dim=1, out=output)
class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
def batched_fused_marlin_moe(
hidden_states: torch.Tensor,
expert_num_tokens: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
bias1: torch.Tensor | None,
bias2: torch.Tensor | None,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
gating_output: torch.Tensor | None,
quant_type_id: int,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
activation: str | None = "silu",
expert_map: torch.Tensor | None = None,
global_scale1: torch.Tensor | None = None,
global_scale2: torch.Tensor | None = None,
g_idx1: torch.Tensor | None = None,
g_idx2: torch.Tensor | None = None,
sort_indices1: torch.Tensor | None = None,
sort_indices2: torch.Tensor | None = None,
w1_zeros: torch.Tensor | None = None,
w2_zeros: torch.Tensor | None = None,
workspace: torch.Tensor | None = None,
intermediate_cache13: torch.Tensor | None = None,
intermediate_cache2: torch.Tensor | None = None,
is_k_full: bool = True,
output: torch.Tensor | None = None,
inplace: bool = False,
) -> torch.Tensor:
"""
This function massages the inputs so the batched hidden_states can be
presented as a 2D contiguous tensor that could be used with
_fused_marlin_moe.
Note that both batched_fused_marlin_moe and fused_marlin_moe ultimately
use `ops.moe_wna16_marlin_gemm` for the gemm operation and
`ops.moe_mna16_marlin_gemm` supports only 2D contiguous hidden_states.
Note that the moe_align_block_size function indicates,
- What rows of the A matrix (hidden_states) to access during the
matmul, via sorted_ids output.
- What expert_id to use for each block matmul, via expert_ids ouptut.
In the batched version, the tokens are already grouped/batched by experts
they subscribe to. Due to this, we can represent the batched hidden_states
tensor of shape [B, MAX_TOKENS_PER_BATCH, K] as a 2D tensor of shape,
[B * MAX_TOKENS_PER_BATCH, K]. We may treat this a 2D contiguous tensor
with topk=1 as each token (row in the tensor) subscribes to exactly one
expert_id (which is the batch_id). With the expert_num_tokens tensor, that
indicates how many tokens are actually valid in each batch, the
batched_moe_align_block_size function constructs the sorted_ids and
expert_ids tensors, so only relevant/valid rows of A (hidden_states)
are accessed and are processed with the correct expert_ids.
"""
assert hidden_states.ndim == 3, (
f"hidden states must be batched. e.g. [B, MAX_TOKENS, K]."
f"But got {hidden_states.size()}"
)
if inplace:
assert output is None, "Conflicting request."
quant_type = ScalarType.from_id(quant_type_id)
assert quant_type in [
scalar_types.uint4,
scalar_types.uint8b128,
scalar_types.uint4b8,
scalar_types.float8_e4m3fn,
scalar_types.float4_e2m1f,
]
bit4_scalar_types = [
scalar_types.uint4,
scalar_types.uint4b8,
scalar_types.float4_e2m1f,
]
num_bits = 4 if quant_type in bit4_scalar_types else 8
B, BATCH_TOKENS_MAX, K = hidden_states.size()
M = hidden_states.view(-1, K).size(0)
E = w1.size(0)
# Check constraints.
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert hidden_states.dtype in [torch.float16, torch.bfloat16]
assert expert_num_tokens.size(0) == E
assert B == E, (
"Batch must be as big as number of experts as the tokens"
"are sorted into the batch/expert they belong to"
)
assert w1.size(1) * 16 == K, "Hidden size mismatch w1"
assert w2.size(2) // (num_bits // 2) == K, "Hidden size mismatch w2"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert num_bits in [4, 8]
# Technically, the tokens are already separated by their expert ids.
# Hidden-States can just be squeezed to have just 2 dimensions,
# [B * MAX_TOKENS, K] and top_k can be interpreted as just 1.
topk = 1
# TODO(varun) : Choose a decent block size like in fused_marlin_moe
block_size_m = 64
sorted_token_ids, expert_ids, num_tokens_post_padded = batched_moe_align_block_size(
max_tokens_per_batch=BATCH_TOKENS_MAX,
block_size=block_size_m,
expert_num_tokens=expert_num_tokens,
)
if output is None and inplace:
output = hidden_states
# TODO (varun): This can be avoided by plumbing the marlin kernel to
# ignore topk_weights when topk_weights_ptr is a nullptr.
topk_weights = torch.ones(
(M, topk), device=hidden_states.device, dtype=torch.float32
)
assert activation is not None
output = _fused_marlin_moe(
hidden_states=hidden_states.view(-1, K),
w1=w1,
w2=w2,
bias1=bias1,
bias2=bias2,
w1_scale=w1_scale,
w2_scale=w2_scale,
topk_weights=topk_weights,
num_topk=topk,
quant_type=quant_type,
apply_router_weight_on_input=apply_router_weight_on_input,
activation=activation,
expert_map=expert_map,
block_size_m=block_size_m,
sorted_token_ids=sorted_token_ids,
expert_ids=expert_ids,
num_tokens_post_padded=num_tokens_post_padded,
global_scale1=global_scale1,
global_scale2=global_scale2,
g_idx1=g_idx1,
g_idx2=g_idx2,
sort_indices1=sort_indices1,
sort_indices2=sort_indices2,
w1_zeros=w1_zeros,
w2_zeros=w2_zeros,
workspace=workspace,
intermediate_cache13=intermediate_cache13,
intermediate_cache2=intermediate_cache2,
output=output.view(-1, K) if output is not None else output,
is_k_full=is_k_full,
)
output = output.view(B, BATCH_TOKENS_MAX, K)
return output
class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self, quant_config: FusedMoEQuantConfig):
# TODO (varun) : Enable activation quantization
assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
super().__init__(quant_config)
@override
def moe_problem_size(
self,
a1: torch.Tensor,
@@ -274,6 +511,11 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
return E, M, N, K, topk
class MarlinExperts(MarlinExpertsBase):
def __init__(self, quant_config: FusedMoEQuantConfig):
super().__init__(quant_config)
def supports_expert_map(self) -> bool:
return True
@@ -365,3 +607,90 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
intermediate_cache13=workspace2,
intermediate_cache2=workspace13,
)
class BatchedMarlinExperts(MarlinExpertsBase):
def __init__(
self,
max_num_tokens: int,
num_dispatchers: int,
quant_config: FusedMoEQuantConfig,
):
super().__init__(quant_config)
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
def supports_expert_map(self) -> bool:
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceDelegate()
@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (
mk.FusedMoEActivationFormat.BatchedExperts,
mk.FusedMoEActivationFormat.BatchedExperts,
)
def supports_chunking(self) -> bool:
return False
def workspace_shapes(
self,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
num_dispatchers = self.num_dispatchers
num_experts = local_num_experts
max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens
workspace13 = (num_experts * max_num_tokens * num_dispatchers, max(K, N * 2))
workspace2 = (num_experts * max_num_tokens * num_dispatchers, N)
output = (num_experts, max_num_tokens * num_dispatchers, K)
return (workspace13, workspace2, output)
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
assert expert_tokens_meta is not None, "Num valid tokens per batch is required"
return batched_fused_marlin_moe(
hidden_states=hidden_states,
expert_num_tokens=expert_tokens_meta.expert_num_tokens,
w1=w1,
w2=w2,
bias1=self.w1_bias,
bias2=self.w2_bias,
w1_scale=self.w1_scale,
w2_scale=self.w2_scale,
gating_output=None,
quant_type_id=scalar_types.float4_e2m1f.id, # works only for w4a16
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
activation=activation,
expert_map=expert_map,
output=output,
intermediate_cache13=workspace13,
intermediate_cache2=workspace2,
)

View File

@@ -994,6 +994,11 @@ def maybe_roundup_hidden_size(
hidden_size, act_dtype
)
if moe_parallel_config.use_deepep_ll_kernels:
hidden_size = DeepEPLLPrepareAndFinalize.maybe_roundup_layer_hidden_size(
hidden_size
)
# we are padding globally so EP buffer allocation works
if quant_config and quant_config.get_name() == "mxfp4":
from vllm.model_executor.layers.quantization.mxfp4 import (

View File

@@ -83,3 +83,92 @@ def moe_align_block_size(
expert_ids = expert_map[expert_ids]
return sorted_ids, expert_ids, num_tokens_post_pad
def batched_moe_align_block_size(
max_tokens_per_batch: int, block_size: int, expert_num_tokens: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Given num_batches, max_tokens_per_batch, block_size and the number of
valid-tokens in each batch, prepare sorted_token_ids, expert_ids and
num_tokens_post_pad. sorted_token_ids, expert_ids and num_tokens_post_pad
have the same semantics as in moe_align_block_size.
This function is intended to be a drop in replacement for
moe_align_batch_size for the batched case.
Parameters:
- max_tokens_per_batch (int): Number of tokens in each batch (both
valid and invalid).
- block_size (int): block_size to align the data to.
- expert_num_tokens (torch.Tensor): expert_num_tokens[i], indicates
the number of valid tokens in batch i.
Returns:
- sorted_token_ids (torch.Tensor): Torch tensor of size
(num_batches * max_tokens_per_batch) indicating the token indices for
that block.
- expert_ids (torch.Tensor): Torch tensor of size
ceil((num_batches * max_tokens_per_batch) / block_size) indicating
what expert to use for each block.
- num_tokens_post_pad (torch.Tensor): Torch tensor of size 1
indicating the number of valid blocks with actual data to
process. This is represented in terms of num tokens.
Example:
Let num_batches=5, max_tokens_per_batch=8, block_size=4, and
expert_num_tokens=[2, 3, 0, 6, 8]. This expert_num_tokens tensor
indicates that,
- The first 2 tokens in the 0th batch are valid and the rest 6 are
invalid (i.e. in the 2D hidden_states tensor of shape,
[num_batches * max_tokens_per_batch, K], indices 0, 1 are valid)
- The first 3 tokens in the 1st batch are valid. i.e. indices 8, 9, 10
- 0 tokens in the 2nd batch are valid
- first 6 tokens in the 3rd batch are valid. i.e. indices,
24, 25, 26, 27, 28, 29
- so on ...
In this case,
sorted_token_ids will be [0, 1, 40, 40,
8, 9, 10, 40,
24, 25, 26, 27,
28, 29, 40, 40,
32, 33, 34, 35,
36, 37, 38, 39,
40, 40, 40, 40,
(rest all 40, 40, 40, 40)
...]
Here, 40 represents an invalid index. as there is no token index 40.
The gemm kernel using this sorted_token_ids is expected to skip the
gemm computation when it encounters this invalid index.
expert_ids will be [0, 1, 3, 3, 4, 5, 5, -1, -1, (rest all -1) ...]
Here, -1 represents an invalid expert. The gemm kernel using this
expert_ids is expected to skip the gemm computation when it encounters
an expert of id -1.
num_tokens_post_pad will be 24 as sorted_token_ids has valid entries
until 24.
"""
B = expert_num_tokens.size(0)
device = expert_num_tokens.device
# Round up so each batch can be split to blocks evenly.
max_num_tokens_padded = B * round_up(max_tokens_per_batch, block_size)
sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device=device)
assert max_num_tokens_padded % block_size == 0
max_num_m_blocks = max_num_tokens_padded // block_size
expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device=device)
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=device)
ops.batched_moe_align_block_size(
max_tokens_per_batch,
block_size,
expert_num_tokens,
sorted_ids,
expert_ids,
num_tokens_post_pad,
)
return sorted_ids, expert_ids, num_tokens_post_pad

View File

@@ -22,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.config import (
ocp_mx_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
BatchedMarlinExperts,
MarlinExperts,
fused_marlin_moe,
)
@@ -797,9 +798,19 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
prepare_finalize.activation_format
== mk.FusedMoEActivationFormat.BatchedExperts
):
raise NotImplementedError(
"Mxfp4 does not support batched experts format for EP"
)
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None
assert self.moe_quant_config is not None
return BatchedMarlinExperts(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config,
)
else:
raise NotImplementedError(
"Incompatible Mxfp4 backend for EP batched experts format"
)
else:
assert self.moe_quant_config is not None
if (