[GPTOSS][DP/EP][Marlin] Enable GPTOSS Batched DP/EP using Marlin kernels (#25997)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
committed by GitHub
parent 2ed8b6b3d0 · commit fb0571b077
@@ -8,12 +8,77 @@
#include "../cuda_compat.h"
#include "../dispatch_utils.h"
#include "core/math.hpp"

#define CEILDIV(x, y) (((x) + (y) - 1) / (y))

namespace vllm {
namespace moe {

namespace batched_moe_align_block_size {

// Note num_threads needs to be 1024 for BlockScan Reduction in the kernel.
static constexpr int32_t num_threads = 1024;
static constexpr int32_t num_blocks = 1;
__global__ void batched_moe_align_block_size_kernel(
    int32_t const num_batches, int32_t const max_tokens_per_batch,
    int32_t const block_size, int32_t const* __restrict__ batch_num_tokens,
    int32_t* __restrict__ sorted_ids, int32_t* __restrict__ block_ids,
    int32_t* __restrict__ num_tokens_post_pad) {
  // TODO(varun): This is a naive implementation. Could be optimized.

  size_t const batch_id = threadIdx.x;
  size_t const stride = blockDim.x * gridDim.x;
  int32_t const num_blocks_per_batch =
      CEILDIV(max_tokens_per_batch, block_size);
  int32_t const sorted_ids_size =
      num_blocks_per_batch * num_batches * block_size;
  int32_t const block_ids_size = sorted_ids_size / block_size;
  int32_t const SENTINEL =
      num_batches * max_tokens_per_batch;  // To denote invalid entries.
  // Initialize sorted_ids
  for (size_t i = threadIdx.x; i < sorted_ids_size; i += stride) {
    sorted_ids[i] = SENTINEL;
  }
  // Initialize expert_ids with -1
  for (size_t i = threadIdx.x; i < block_ids_size; i += stride) {
    block_ids[i] = -1;
  }

  int32_t b_num_tokens = 0;
  if (batch_id < num_batches) {
    b_num_tokens = batch_num_tokens[batch_id];
  }
  int32_t const ceil_b_num_tokens =
      CEILDIV(b_num_tokens, block_size) * block_size;

  // Compute prefix sum over token counts per expert
  using BlockScan = cub::BlockScan<int32_t, 1024>;
  __shared__ typename BlockScan::TempStorage temp_storage;
  int cumsum_val;
  BlockScan(temp_storage).ExclusiveSum(ceil_b_num_tokens, cumsum_val);
  __syncthreads();

  bool const is_last_batch = batch_id == (num_batches - 1);
  if (is_last_batch) {
    *num_tokens_post_pad = cumsum_val + ceil_b_num_tokens;
  }

  if (batch_id < num_batches) {
    int32_t const batch_offset = batch_id * max_tokens_per_batch;
    for (size_t i = 0; i < b_num_tokens; ++i) {
      sorted_ids[cumsum_val + i] = batch_offset + i;
    }

    int32_t const block_start = cumsum_val / block_size;
    int32_t const num_blocks = ceil_b_num_tokens / block_size;
    for (size_t i = 0; i < num_blocks; ++i) {
      block_ids[block_start + i] = batch_id;
    }
  }
}
}  // namespace batched_moe_align_block_size

template <typename scalar_t>
__global__ void moe_align_block_size_kernel(
    const scalar_t* __restrict__ topk_ids,
@@ -280,6 +345,33 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
  });
}

void batched_moe_align_block_size(int64_t max_tokens_per_batch,
                                  int64_t block_size,
                                  torch::Tensor const& batch_num_tokens,
                                  torch::Tensor sorted_ids,
                                  torch::Tensor batch_ids,
                                  torch::Tensor num_tokens_post_pad) {
  namespace batched_kernel = vllm::moe::batched_moe_align_block_size;

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  int32_t const B = batch_num_tokens.size(0);
  int32_t const num_blocks_per_batch =
      round_to_next_multiple_of(max_tokens_per_batch, block_size) / block_size;
  int32_t const num_blocks = num_blocks_per_batch * B;
  int64_t const sorted_ids_size = num_blocks * block_size;

  TORCH_CHECK(sorted_ids.size(0) == sorted_ids_size);
  TORCH_CHECK(batch_ids.size(0) == sorted_ids_size / block_size);
  TORCH_CHECK(num_tokens_post_pad.size(0) == 1);
  TORCH_CHECK(B <= batched_kernel::num_threads);

  batched_kernel::batched_moe_align_block_size_kernel<<<
      batched_kernel::num_blocks, batched_kernel::num_threads, 0, stream>>>(
      B, max_tokens_per_batch, block_size, batch_num_tokens.data_ptr<int32_t>(),
      sorted_ids.data_ptr<int32_t>(), batch_ids.data_ptr<int32_t>(),
      num_tokens_post_pad.data_ptr<int32_t>());
}

void moe_sum(torch::Tensor& input,   // [num_tokens, topk, hidden_size]
             torch::Tensor& output)  // [num_tokens, hidden_size]
{
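For reference, the layout the kernel produces can be reproduced in a few lines of Python. This is only an illustrative sketch (the helper name and the tiny example at the end are not part of the change); it mirrors the reference implementation used by the tests later in this commit:

def ref_batched_moe_align_block_size(max_tokens_per_batch, block_size, batch_num_tokens):
    # Round a count up to the next multiple of block_size, like CEILDIV(x, bs) * bs.
    round_up = lambda x: -(-x // block_size) * block_size
    num_batches = len(batch_num_tokens)
    blocks_per_batch = round_up(max_tokens_per_batch) // block_size
    sentinel = num_batches * max_tokens_per_batch  # marks invalid sorted_ids entries
    sorted_ids = [sentinel] * (blocks_per_batch * num_batches * block_size)
    block_ids = [-1] * (len(sorted_ids) // block_size)
    cumsum = 0  # exclusive prefix sum of the padded per-batch token counts
    for b, n in enumerate(batch_num_tokens):
        for i in range(n):  # valid tokens of batch b, as rows of the flattened activations
            sorted_ids[cumsum + i] = b * max_tokens_per_batch + i
        for blk in range(round_up(n) // block_size):  # blocks owned by batch b
            block_ids[cumsum // block_size + blk] = b
        cumsum += round_up(n)
    return sorted_ids, block_ids, cumsum  # cumsum equals num_tokens_post_pad

# e.g. max_tokens_per_batch=4, block_size=2, batch_num_tokens=[1, 3] gives
# sorted_ids = [0, 8, 4, 5, 6, 8, 8, 8], block_ids = [0, 1, 1, -1], num_tokens_post_pad = 6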
@@ -12,6 +12,14 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                          int64_t block_size, torch::Tensor sorted_token_ids,
                          torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad);

void batched_moe_align_block_size(int64_t max_tokens_per_batch,
                                  int64_t block_size,
                                  torch::Tensor const& expert_num_tokens,
                                  torch::Tensor sorted_ids,
                                  torch::Tensor expert_ids,
                                  torch::Tensor num_tokens_post_pad);

#ifndef USE_ROCM
torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                             torch::Tensor b_qweight, torch::Tensor b_scales,
@@ -22,6 +22,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
      " Tensor! num_tokens_post_pad) -> ()");
  m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);

  // Aligning the number of tokens to be processed by each expert such
  // that it is divisible by the block size, but for the batched case.
  m.def(
      "batched_moe_align_block_size(int max_tokens_per_batch,"
      " int block_size, Tensor expert_num_tokens,"
      " Tensor! sorted_token_ids,"
      " Tensor! experts_ids,"
      " Tensor! num_tokens_post_pad) -> ()");
  m.impl("batched_moe_align_block_size", torch::kCUDA,
         &batched_moe_align_block_size);

#ifndef USE_ROCM
  m.def(
      "moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "
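The schema mirrors the host wrapper above: the caller allocates the outputs and the op fills them in. A rough usage sketch (the helper function below is illustrative and not part of vLLM):

import torch

def run_batched_moe_align_block_size(expert_num_tokens: torch.Tensor,
                                     max_tokens_per_batch: int,
                                     block_size: int):
    B = expert_num_tokens.size(0)  # one batch per local expert; the kernel requires B <= 1024
    blocks_per_batch = (max_tokens_per_batch + block_size - 1) // block_size
    sorted_ids = torch.empty((blocks_per_batch * B * block_size,),
                             dtype=torch.int32, device="cuda")
    expert_ids = torch.empty((blocks_per_batch * B,), dtype=torch.int32, device="cuda")
    num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda")
    torch.ops._moe_C.batched_moe_align_block_size(
        max_tokens_per_batch, block_size, expert_num_tokens,
        sorted_ids, expert_ids, num_tokens_post_pad)
    return sorted_ids, expert_ids, num_tokens_post_pad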
@@ -92,8 +92,8 @@ To be used with a particular `FusedMoEPrepareAndFinalize` sub-class, MoE kernels
| flashinfer | standard | nvfp4,</br>fp8 | T | <sup>5</sup> | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],</br>[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] |
| gpt oss triton | standard | N/A | N/A | <sup>5</sup> | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],</br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
| deep gemm+triton<sup>2</sup> | standard,</br>batched | all<sup>1</sup> | G(128),A,T | silu, gelu | <sup>6</sup> | Y | [`TritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe.TritonOrDeepGemmExperts],</br>[`BatchedTritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe.BatchedTritonOrDeepGemmExperts] |
| marlin | standard | <sup>3</sup> | <sup>3</sup> | silu,</br>swigluoai | Y | N | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe] |
| marlin experts | standard | N/A | N/A | silu,</br>swigluoai | Y | Y | [`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts] |
| marlin | standard | <sup>3</sup> | <sup>3</sup> | silu,</br>swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],</br>[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],</br>[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
| marlin experts | standard,</br>batched | N/A | N/A | silu,</br>swigluoai | Y | Y | [`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],</br>[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
| trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
| pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
| iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] |
@@ -115,6 +115,6 @@ The following table shows "families" of modular kernels that are intended to wor

| backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses |
|----------------------------------|------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------|
| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8`, </br>`MarlinExperts` |
| deepep_low_latency,</br>pplx | `DeepEPLLPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`BatchedTritonOrDeepGemmExperts`,</br>`CutlassBatchedExpertsFp8`|
| flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` |
| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8`, </br>`MarlinExperts` |
| deepep_low_latency,</br>pplx | `DeepEPLLPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`BatchedTritonOrDeepGemmExperts`,</br>`CutlassBatchedExpertsFp8`,</br>`BatchedMarlinExperts`|
| flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` |
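As a rough sketch of how the new batched Marlin experts slot into this family (the surrounding construction is abbreviated; `prepare_finalize`, `max_num_tokens`, `num_dispatchers` and `quant_config` are placeholders for objects built elsewhere, and the pairing below assumes the usual `FusedMoEModularKernel` composition):

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import BatchedMarlinExperts

# prepare_finalize: a batched FusedMoEPrepareAndFinalize such as
# DeepEPLLPrepareAndFinalize or PplxPrepareAndFinalize, built with its usual arguments.
experts = BatchedMarlinExperts(
    max_num_tokens=max_num_tokens,    # per-expert batch capacity
    num_dispatchers=num_dispatchers,  # number of dispatching ranks
    quant_config=quant_config,        # the Marlin experts support only mxfp4 w4a16
)
kernel = mk.FusedMoEModularKernel(prepare_finalize, experts)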
@@ -7,6 +7,8 @@ Run `pytest tests/kernels/test_moe.py`.

import functools
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

import pytest
import torch
@@ -26,7 +28,10 @@ from vllm.model_executor.layers.fused_moe.config import (
    int4_w4a16_moe_quant_config,
    int8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
    batched_fused_marlin_moe,
    fused_marlin_moe,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
    fused_topk,
    modular_triton_fused_moe,
@@ -564,6 +569,105 @@ def marlin_moe_generate_valid_test_cases():
    return cases


@dataclass
class MarlinMoEWeightData:
    w_ref: torch.Tensor
    qweight: torch.Tensor
    scales: torch.Tensor
    global_scale: torch.Tensor | None
    g_idx: torch.Tensor | None
    zeros: torch.Tensor | None
    sort_indices: torch.Tensor | None
    marlin_bias: torch.Tensor | None

    @staticmethod
    def make(
        w: torch.Tensor,
        quant_type: ScalarType,
        group_size: int,
        act_order: bool | None = None,
        bias: torch.Tensor | None = None,
    ) -> "MarlinMoEWeightData":
        assert w.ndim == 3
        has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
        k = w.shape[-1]

        w_ref_l: list[torch.Tensor] = []
        qweight_l: list[torch.Tensor] = []
        scales_l: list[torch.Tensor] = []
        global_scale_l: list[torch.Tensor] = []
        zeros_l: list[torch.Tensor] = []
        g_idx_l: list[torch.Tensor] = []
        sort_indices_l: list[torch.Tensor] = []
        bias_l: list[torch.Tensor] = []

        for i in range(w.shape[0]):
            if quant_type == scalar_types.float4_e2m1f:
                if group_size == 16:
                    w_ref, qweight, scales, global_scale = (
                        rand_marlin_weight_nvfp4_like(w[i], group_size)
                    )
                else:
                    w_ref, qweight, scales = rand_marlin_weight_mxfp4_like(
                        w[i], group_size
                    )
                    global_scale = None

                w_ref_l.append(w_ref.T)
                qweight_l.append(qweight)
                scales_l.append(scales)
                if global_scale is not None:
                    global_scale_l.append(global_scale)
            elif quant_type == scalar_types.float8_e4m3fn:
                w_ref, qweight, scales = marlin_quant_fp8_torch(w[i], group_size)
                w_ref_l.append(w_ref.T)
                qweight_l.append(qweight)
                scales_l.append(scales)
            elif has_zp:
                w_ref, qweight, scales, zeros = awq_marlin_quantize(
                    w[i].transpose(1, 0), quant_type, group_size
                )

                w_ref_l.append(w_ref.T)
                qweight_l.append(qweight)
                scales_l.append(scales)
                zeros_l.append(zeros)
            else:
                test_perm = torch.randperm(k)
                w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
                    w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
                )

                w_ref_l.append(w_ref.T)
                qweight_l.append(qweight)
                scales_l.append(scales)
                g_idx_l.append(g_idx)
                sort_indices_l.append(sort_indices)

            if bias is not None:
                bias_l.append(marlin_permute_bias(bias[i]))

        w_ref = stack_and_dev(w_ref_l)
        qweight = stack_and_dev(qweight_l).contiguous()
        scales = stack_and_dev(scales_l)
        global_scale = stack_and_dev(global_scale_l) if global_scale_l else None
        g_idx = stack_and_dev(g_idx_l) if g_idx_l else None
        zeros = stack_and_dev(zeros_l) if zeros_l else None
        sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None
        marlin_bias = stack_and_dev(bias_l) if bias_l else None

        return MarlinMoEWeightData(
            w_ref=w_ref,
            qweight=qweight,
            scales=scales,
            global_scale=global_scale,
            g_idx=g_idx,
            zeros=zeros,
            sort_indices=sort_indices,
            marlin_bias=marlin_bias,
        )


@pytest.mark.flaky(reruns=2)
@pytest.mark.parametrize(
    ("m, n, k, e, topk, ep_size, dtype, group_size,act_order, quant_type, is_k_full"),
@@ -584,7 +688,6 @@ def test_fused_marlin_moe(
|
||||
is_k_full: bool,
|
||||
):
|
||||
torch.cuda.manual_seed(0)
|
||||
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
|
||||
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
|
||||
@@ -600,152 +703,44 @@ def test_fused_marlin_moe(
|
||||
else:
|
||||
e_map = None
|
||||
|
||||
w_ref1_l = []
|
||||
qweight1_l = []
|
||||
scales1_l = []
|
||||
global_scale1_l = []
|
||||
zeros1_l = []
|
||||
g_idx1_l = []
|
||||
sort_indices1_l = []
|
||||
w1_data = MarlinMoEWeightData.make(
|
||||
w=w1, quant_type=quant_type, group_size=group_size, act_order=act_order
|
||||
)
|
||||
|
||||
for i in range(w1.shape[0]):
|
||||
if quant_type == scalar_types.float4_e2m1f:
|
||||
if group_size == 16:
|
||||
w_ref1, qweight1, scales1, global_scale1 = (
|
||||
rand_marlin_weight_nvfp4_like(w1[i], group_size)
|
||||
)
|
||||
else:
|
||||
w_ref1, qweight1, scales1 = rand_marlin_weight_mxfp4_like(
|
||||
w1[i], group_size
|
||||
)
|
||||
global_scale1 = None
|
||||
|
||||
w_ref1_l.append(w_ref1.T)
|
||||
qweight1_l.append(qweight1)
|
||||
scales1_l.append(scales1)
|
||||
if global_scale1 is not None:
|
||||
global_scale1_l.append(global_scale1)
|
||||
elif quant_type == scalar_types.float8_e4m3fn:
|
||||
w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(w1[i], group_size)
|
||||
w_ref1_l.append(w_ref1.T)
|
||||
qweight1_l.append(qweight1)
|
||||
scales1_l.append(scales1)
|
||||
elif has_zp:
|
||||
w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize(
|
||||
w1[i].transpose(1, 0), quant_type, group_size
|
||||
)
|
||||
|
||||
w_ref1_l.append(w_ref1.T)
|
||||
qweight1_l.append(qweight1)
|
||||
scales1_l.append(scales1)
|
||||
zeros1_l.append(zeros1)
|
||||
else:
|
||||
test_perm = torch.randperm(k)
|
||||
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
|
||||
w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
|
||||
)
|
||||
|
||||
w_ref1_l.append(w_ref1.T)
|
||||
qweight1_l.append(qweight1)
|
||||
scales1_l.append(scales1)
|
||||
g_idx1_l.append(g_idx1)
|
||||
sort_indices1_l.append(sort_indices1)
|
||||
|
||||
w_ref1 = stack_and_dev(w_ref1_l)
|
||||
qweight1 = stack_and_dev(qweight1_l).contiguous()
|
||||
scales1 = stack_and_dev(scales1_l)
|
||||
global_scale1 = stack_and_dev(global_scale1_l) if global_scale1_l else None
|
||||
g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
|
||||
zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None
|
||||
sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
|
||||
|
||||
w_ref2_l = []
|
||||
qweight2_l = []
|
||||
scales2_l = []
|
||||
global_scale2_l = []
|
||||
zeros2_l = []
|
||||
g_idx2_l = []
|
||||
sort_indices2_l = []
|
||||
|
||||
for i in range(w2.shape[0]):
|
||||
if quant_type == scalar_types.float4_e2m1f:
|
||||
if group_size == 16:
|
||||
w_ref2, qweight2, scales2, global_scale2 = (
|
||||
rand_marlin_weight_nvfp4_like(w2[i], group_size)
|
||||
)
|
||||
else:
|
||||
w_ref2, qweight2, scales2 = rand_marlin_weight_mxfp4_like(
|
||||
w2[i], group_size
|
||||
)
|
||||
global_scale2 = None
|
||||
|
||||
w_ref2_l.append(w_ref2.T)
|
||||
qweight2_l.append(qweight2)
|
||||
scales2_l.append(scales2)
|
||||
if global_scale2 is not None:
|
||||
global_scale2_l.append(global_scale2)
|
||||
elif quant_type == scalar_types.float8_e4m3fn:
|
||||
w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(w2[i], group_size)
|
||||
w_ref2_l.append(w_ref2.T)
|
||||
qweight2_l.append(qweight2)
|
||||
scales2_l.append(scales2)
|
||||
elif has_zp:
|
||||
w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize(
|
||||
w2[i].transpose(1, 0), quant_type, group_size
|
||||
)
|
||||
|
||||
w_ref2_l.append(w_ref2.T)
|
||||
qweight2_l.append(qweight2)
|
||||
scales2_l.append(scales2)
|
||||
zeros2_l.append(zeros2)
|
||||
else:
|
||||
test_perm = torch.randperm(n)
|
||||
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
|
||||
w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
|
||||
)
|
||||
|
||||
w_ref2_l.append(w_ref2.T)
|
||||
qweight2_l.append(qweight2)
|
||||
scales2_l.append(scales2)
|
||||
g_idx2_l.append(g_idx2)
|
||||
sort_indices2_l.append(sort_indices2)
|
||||
|
||||
w_ref2 = stack_and_dev(w_ref2_l)
|
||||
qweight2 = stack_and_dev(qweight2_l).contiguous()
|
||||
scales2 = stack_and_dev(scales2_l)
|
||||
global_scale2 = stack_and_dev(global_scale2_l) if global_scale2_l else None
|
||||
g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
|
||||
zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None
|
||||
sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
|
||||
w2_data = MarlinMoEWeightData.make(
|
||||
w=w2, quant_type=quant_type, group_size=group_size, act_order=act_order
|
||||
)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map)
|
||||
torch_output = torch_moe(
|
||||
a, w1_data.w_ref, w2_data.w_ref, score, topk, expert_map=e_map
|
||||
)
|
||||
|
||||
marlin_output = fused_marlin_moe(
|
||||
a,
|
||||
qweight1,
|
||||
qweight2,
|
||||
w1_data.qweight,
|
||||
w2_data.qweight,
|
||||
None,
|
||||
None,
|
||||
scales1,
|
||||
scales2,
|
||||
w1_data.scales,
|
||||
w2_data.scales,
|
||||
score,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts=e,
|
||||
expert_map=e_map,
|
||||
global_scale1=global_scale1,
|
||||
global_scale2=global_scale2,
|
||||
g_idx1=g_idx1,
|
||||
g_idx2=g_idx2,
|
||||
sort_indices1=sort_indices1,
|
||||
sort_indices2=sort_indices2,
|
||||
w1_zeros=zeros1,
|
||||
w2_zeros=zeros2,
|
||||
global_scale1=w1_data.global_scale,
|
||||
global_scale2=w2_data.global_scale,
|
||||
g_idx1=w1_data.g_idx,
|
||||
g_idx2=w2_data.g_idx,
|
||||
sort_indices1=w1_data.sort_indices,
|
||||
sort_indices2=w2_data.sort_indices,
|
||||
w1_zeros=w1_data.zeros,
|
||||
w2_zeros=w2_data.zeros,
|
||||
quant_type_id=quant_type.id,
|
||||
is_k_full=is_k_full,
|
||||
)
|
||||
@@ -773,92 +768,52 @@ def test_fused_marlin_moe_with_bias(m):
|
||||
b_bias1 = torch.randn((e, 2 * n), device="cuda", dtype=dtype) / 10
|
||||
b_bias2 = torch.randn((e, k), device="cuda", dtype=dtype) / 10
|
||||
|
||||
b_bias1_l = []
|
||||
w_ref1_l = []
|
||||
qweight1_l = []
|
||||
scales1_l = []
|
||||
g_idx1_l = []
|
||||
sort_indices1_l = []
|
||||
w1_data = MarlinMoEWeightData.make(
|
||||
w=w1,
|
||||
quant_type=quant_type,
|
||||
group_size=group_size,
|
||||
act_order=act_order,
|
||||
bias=b_bias1,
|
||||
)
|
||||
|
||||
for i in range(w1.shape[0]):
|
||||
test_perm = torch.randperm(k)
|
||||
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
|
||||
w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
|
||||
)
|
||||
|
||||
w_ref1_l.append(w_ref1.T)
|
||||
qweight1_l.append(qweight1)
|
||||
scales1_l.append(scales1)
|
||||
g_idx1_l.append(g_idx1)
|
||||
sort_indices1_l.append(sort_indices1)
|
||||
b_bias1_l.append(marlin_permute_bias(b_bias1[i]))
|
||||
|
||||
w_ref1 = stack_and_dev(w_ref1_l)
|
||||
qweight1 = stack_and_dev(qweight1_l).contiguous()
|
||||
scales1 = stack_and_dev(scales1_l)
|
||||
global_scale1 = None
|
||||
g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
|
||||
zeros1 = None
|
||||
sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
|
||||
marlin_bias1 = stack_and_dev(b_bias1_l) if b_bias1_l else None
|
||||
|
||||
b_bias2_l = []
|
||||
w_ref2_l = []
|
||||
qweight2_l = []
|
||||
scales2_l = []
|
||||
g_idx2_l = []
|
||||
sort_indices2_l = []
|
||||
|
||||
for i in range(w2.shape[0]):
|
||||
test_perm = torch.randperm(n)
|
||||
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
|
||||
w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
|
||||
)
|
||||
|
||||
w_ref2_l.append(w_ref2.T)
|
||||
qweight2_l.append(qweight2)
|
||||
scales2_l.append(scales2)
|
||||
g_idx2_l.append(g_idx2)
|
||||
sort_indices2_l.append(sort_indices2)
|
||||
b_bias2_l.append(marlin_permute_bias(b_bias2[i]))
|
||||
|
||||
w_ref2 = stack_and_dev(w_ref2_l)
|
||||
qweight2 = stack_and_dev(qweight2_l).contiguous()
|
||||
scales2 = stack_and_dev(scales2_l)
|
||||
global_scale2 = None
|
||||
g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
|
||||
zeros2 = None
|
||||
sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
|
||||
marlin_bias2 = stack_and_dev(b_bias2_l) if b_bias2_l else None
|
||||
w2_data = MarlinMoEWeightData.make(
|
||||
w=w2,
|
||||
quant_type=quant_type,
|
||||
group_size=group_size,
|
||||
act_order=act_order,
|
||||
bias=b_bias2,
|
||||
)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1, b_bias2)
|
||||
torch_output = torch_moe(
|
||||
a, w1_data.w_ref, w2_data.w_ref, score, topk, b_bias1, b_bias2
|
||||
)
|
||||
|
||||
marlin_output = fused_marlin_moe(
|
||||
a,
|
||||
qweight1,
|
||||
qweight2,
|
||||
marlin_bias1,
|
||||
marlin_bias2,
|
||||
scales1,
|
||||
scales2,
|
||||
w1_data.qweight,
|
||||
w2_data.qweight,
|
||||
w1_data.marlin_bias,
|
||||
w2_data.marlin_bias,
|
||||
w1_data.scales,
|
||||
w2_data.scales,
|
||||
score,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts=e,
|
||||
expert_map=None,
|
||||
global_scale1=global_scale1,
|
||||
global_scale2=global_scale2,
|
||||
g_idx1=g_idx1,
|
||||
g_idx2=g_idx2,
|
||||
sort_indices1=sort_indices1,
|
||||
sort_indices2=sort_indices2,
|
||||
w1_zeros=zeros1,
|
||||
w2_zeros=zeros2,
|
||||
global_scale1=w1_data.global_scale,
|
||||
global_scale2=w2_data.global_scale,
|
||||
g_idx1=w1_data.g_idx,
|
||||
g_idx2=w2_data.g_idx,
|
||||
sort_indices1=w1_data.sort_indices,
|
||||
sort_indices2=w2_data.sort_indices,
|
||||
w1_zeros=w1_data.zeros,
|
||||
w2_zeros=w2_data.zeros,
|
||||
quant_type_id=quant_type.id,
|
||||
is_k_full=is_k_full,
|
||||
)
|
||||
@@ -895,6 +850,41 @@ def test_moe_align_block_size_opcheck():
    )


def test_batched_moe_align_block_size_opcheck():
    max_tokens_per_batch = 512
    num_experts = 4
    block_size = 16

    expert_num_tokens = torch.randint(
        low=0,
        high=max_tokens_per_batch,
        size=(num_experts,),
        dtype=torch.int32,
        device="cuda",
    )

    max_num_tokens_padded = num_experts * max(max_tokens_per_batch, block_size)
    sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda")

    assert max_num_tokens_padded % block_size == 0
    max_num_m_blocks = max_num_tokens_padded // block_size
    expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda")

    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device="cuda")

    opcheck(
        torch.ops._moe_C.batched_moe_align_block_size,
        (
            max_tokens_per_batch,
            block_size,
            expert_num_tokens,
            sorted_ids,
            expert_ids,
            num_tokens_post_pad,
        ),
    )


@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("k", [128, 511, 1024])
@@ -979,3 +969,171 @@ def test_cpu_fused_moe_basic(m, n, k, e, topk, dtype, with_bias, activation):
|
||||
else:
|
||||
atol = 5e-2
|
||||
torch.testing.assert_close(out, ref, atol=atol, rtol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [16, 32, 64])
|
||||
@pytest.mark.parametrize("n", [128])
|
||||
@pytest.mark.parametrize("k", [128])
|
||||
@pytest.mark.parametrize("e", [8, 12, 16, 32])
|
||||
@pytest.mark.parametrize("topk", [2, 4])
|
||||
@pytest.mark.parametrize("max_tokens_per_batch", [16, 32, 64])
|
||||
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
|
||||
def test_batched_fused_marlin_moe(
|
||||
m: int, n: int, k: int, e: int, topk: int, max_tokens_per_batch: int
|
||||
):
|
||||
print(
|
||||
f"testing m={m}, n={n}, k={k}, e={e}, "
|
||||
f"topk={topk}, "
|
||||
f"max_tokens_per_batch={max_tokens_per_batch}"
|
||||
)
|
||||
torch.cuda.manual_seed(0)
|
||||
|
||||
dtype = torch.bfloat16
|
||||
quant_dtype = scalar_types.float4_e2m1f
|
||||
group_size = 32
|
||||
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20
|
||||
|
||||
w1_data = MarlinMoEWeightData.make(
|
||||
w=w1, quant_type=quant_dtype, group_size=group_size, act_order=None
|
||||
)
|
||||
w2_data = MarlinMoEWeightData.make(
|
||||
w=w2, quant_type=quant_dtype, group_size=group_size, act_order=None
|
||||
)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
|
||||
class BatchedRun:
|
||||
@staticmethod
|
||||
def _make_expert_num_tokens_cpu(
|
||||
e: int, # num_experts
|
||||
topk_ids_cpu: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
expert_num_tokens_cpu = torch.zeros((e,), dtype=torch.int32, device="cpu")
|
||||
for topk_id in torch.flatten(topk_ids_cpu):
|
||||
expert_num_tokens_cpu[topk_id] += 1
|
||||
return expert_num_tokens_cpu
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_tokens_per_batch: int,
|
||||
num_experts: int,
|
||||
_topk_ids: torch.Tensor,
|
||||
_topk_weights: torch.Tensor,
|
||||
):
|
||||
self.max_tokens_per_batch = max_tokens_per_batch
|
||||
self.e = num_experts
|
||||
self.topk_ids_cpu = _topk_ids.to("cpu")
|
||||
self.topk_weights_cpu = _topk_weights.to("cpu")
|
||||
self.expert_num_tokens_cpu = self._make_expert_num_tokens_cpu(
|
||||
self.e, self.topk_ids_cpu
|
||||
)
|
||||
|
||||
def is_valid(self):
|
||||
"""
|
||||
Return True only if the input can be represented in a Batched
|
||||
format.
|
||||
"""
|
||||
return torch.all(self.expert_num_tokens_cpu <= self.max_tokens_per_batch)
|
||||
|
||||
def _scatter(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states_cpu = hidden_states.to("cpu")
|
||||
K = hidden_states_cpu.size(1)
|
||||
batched_hidden_states_cpu = torch.empty(
|
||||
(e, max_tokens_per_batch, K),
|
||||
dtype=hidden_states_cpu.dtype,
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
counter_cpu = torch.zeros_like(self.expert_num_tokens_cpu)
|
||||
for t_idx, token in enumerate(hidden_states_cpu):
|
||||
for topk_id in self.topk_ids_cpu[t_idx]:
|
||||
pos_in_batch = counter_cpu[topk_id]
|
||||
batched_hidden_states_cpu[topk_id, pos_in_batch] = token
|
||||
counter_cpu[topk_id] += 1
|
||||
assert torch.allclose(counter_cpu, self.expert_num_tokens_cpu)
|
||||
return batched_hidden_states_cpu.to("cuda")
|
||||
|
||||
def _gather(
|
||||
self, batched_outputs: torch.Tensor, gather_outputs: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
batched_outputs_cpu = batched_outputs.to("cpu")
|
||||
gather_outputs_cpu = torch.zeros_like(gather_outputs)
|
||||
|
||||
counter_cpu = torch.zeros((e,), device="cpu", dtype=torch.int32)
|
||||
md = gather_outputs_cpu.size(0)
|
||||
for t_idx in range(md):
|
||||
token = None
|
||||
for topk_id, topk_weight in zip(
|
||||
self.topk_ids_cpu[t_idx], self.topk_weights_cpu[t_idx]
|
||||
):
|
||||
pos_in_batch = counter_cpu[topk_id]
|
||||
t = batched_outputs_cpu[topk_id, pos_in_batch] * topk_weight
|
||||
if token is None:
|
||||
token = t
|
||||
else:
|
||||
token += t
|
||||
counter_cpu[topk_id] += 1
|
||||
assert token is not None
|
||||
gather_outputs_cpu[t_idx] = token
|
||||
gather_outputs.copy_(gather_outputs_cpu)
|
||||
return gather_outputs
|
||||
|
||||
def run(
|
||||
self, hidden_states: torch.Tensor, fused_marlin_moe_kwargs: dict[Any, Any]
|
||||
) -> torch.Tensor:
|
||||
assert hidden_states.ndim == 2
|
||||
assert self.is_valid()
|
||||
|
||||
batched_hidden_states = self._scatter(hidden_states)
|
||||
|
||||
kwargs = fused_marlin_moe_kwargs | {
|
||||
"hidden_states": batched_hidden_states,
|
||||
"expert_num_tokens": self.expert_num_tokens_cpu.to("cuda"),
|
||||
}
|
||||
batched_outputs = batched_fused_marlin_moe(**kwargs)
|
||||
|
||||
output = torch.zeros_like(hidden_states)
|
||||
output = self._gather(batched_outputs, output)
|
||||
return output
|
||||
|
||||
kwargs = {
|
||||
"w1": w1_data.qweight,
|
||||
"w2": w2_data.qweight,
|
||||
"bias1": None,
|
||||
"bias2": None,
|
||||
"w1_scale": w1_data.scales,
|
||||
"w2_scale": w2_data.scales,
|
||||
"gating_output": score,
|
||||
"global_num_experts": e,
|
||||
"expert_map": None,
|
||||
"global_scale1": w1_data.global_scale,
|
||||
"global_scale2": w2_data.global_scale,
|
||||
"g_idx1": w1_data.g_idx,
|
||||
"g_idx2": w2_data.g_idx,
|
||||
"sort_indices1": w1_data.sort_indices,
|
||||
"sort_indices2": w2_data.sort_indices,
|
||||
"w1_zeros": w1_data.zeros,
|
||||
"w2_zeros": w2_data.zeros,
|
||||
"quant_type_id": quant_dtype.id,
|
||||
"is_k_full": True,
|
||||
}
|
||||
|
||||
# Reference
|
||||
fused_marlin_moe_kwargs = kwargs | {
|
||||
"hidden_states": a,
|
||||
"topk_ids": topk_ids,
|
||||
"topk_weights": topk_weights,
|
||||
}
|
||||
ref_marlin_output = fused_marlin_moe(**fused_marlin_moe_kwargs)
|
||||
|
||||
# Batched
|
||||
br = BatchedRun(max_tokens_per_batch, e, topk_ids, topk_weights)
|
||||
if not br.is_valid():
|
||||
pytest.skip("Cannot represent data in Batched Format.")
|
||||
marlin_output = br.run(a, kwargs)
|
||||
|
||||
torch.testing.assert_close(marlin_output, ref_marlin_output, atol=1e-3, rtol=0)
|
||||
|
||||
@@ -9,6 +9,7 @@ import pytest
import torch

from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
    batched_moe_align_block_size,
    moe_align_block_size,
)
from vllm.platforms import current_platform
@@ -300,3 +301,96 @@ def test_moe_align_block_size_deterministic():
|
||||
assert torch.equal(results[0][2], results[i][2]), (
|
||||
"num_tokens should be deterministic"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("max_tokens_per_batch", [13, 16, 512])
|
||||
@pytest.mark.parametrize("num_experts", [8, 16, 32, 64])
|
||||
@pytest.mark.parametrize("block_size", [8, 16, 32, 64])
|
||||
@pytest.mark.parametrize("simulate_empty_batches", [False, True])
|
||||
def test_batched_moe_align_block_size(
|
||||
max_tokens_per_batch: int,
|
||||
num_experts: int,
|
||||
block_size: int,
|
||||
simulate_empty_batches: bool,
|
||||
):
|
||||
def ref_outputs(
|
||||
expert_num_tokens: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
E = expert_num_tokens.size(0)
|
||||
|
||||
# Round up so each batch can be split to blocks evenly.
|
||||
Msum = round_up(max_tokens_per_batch, block_size) * E
|
||||
ref_sorted_ids = torch.empty((Msum,), dtype=torch.int32)
|
||||
ref_expert_ids = torch.empty((Msum // block_size,), dtype=torch.int32)
|
||||
ref_num_tokens_post_pad = torch.empty((1,), dtype=torch.int32)
|
||||
|
||||
# Initialize
|
||||
sentinel = E * max_tokens_per_batch
|
||||
ref_sorted_ids.fill_(sentinel)
|
||||
ref_expert_ids.fill_(-1)
|
||||
|
||||
# Fill ref_sorted_ids
|
||||
i = 0
|
||||
for expert_id, expert_nt in enumerate(expert_num_tokens):
|
||||
token_offset = expert_id * max_tokens_per_batch
|
||||
for j in range(expert_nt):
|
||||
ref_sorted_ids[i] = token_offset + j
|
||||
i += 1
|
||||
# round up i to the next block_size
|
||||
i = round_up(i, block_size)
|
||||
|
||||
ref_num_tokens_post_pad[0] = i
|
||||
|
||||
# Fill expert_ids
|
||||
nt_ceil_sum = 0
|
||||
for expert_id, expert_nt in enumerate(expert_num_tokens):
|
||||
expert_ids_offset = nt_ceil_sum // block_size
|
||||
ceil_expert_nt = round_up(int(expert_nt.item()), block_size)
|
||||
num_blocks = ceil_expert_nt // block_size
|
||||
for x in range(num_blocks):
|
||||
ref_expert_ids[expert_ids_offset + x] = expert_id
|
||||
nt_ceil_sum += ceil_expert_nt
|
||||
|
||||
return (
|
||||
ref_sorted_ids.to("cuda"),
|
||||
ref_expert_ids.to("cuda"),
|
||||
ref_num_tokens_post_pad.to("cuda"),
|
||||
)
|
||||
|
||||
# Compute expert_num_tokens
|
||||
expert_num_tokens = torch.randint(
|
||||
low=0,
|
||||
high=max_tokens_per_batch,
|
||||
size=(num_experts,),
|
||||
device="cpu",
|
||||
dtype=torch.int32,
|
||||
)
|
||||
if simulate_empty_batches:
|
||||
# mark half the batches to have 0 tokens
|
||||
zero_batches = torch.randperm(num_experts)[: num_experts // 2]
|
||||
expert_num_tokens[zero_batches] = 0
|
||||
|
||||
# ref outputs
|
||||
ref_sorted_ids, ref_expert_ids, ref_num_tokens_post_pad = ref_outputs(
|
||||
expert_num_tokens
|
||||
)
|
||||
|
||||
# outputs
|
||||
sorted_ids, expert_ids, num_tokens_post_pad = batched_moe_align_block_size(
|
||||
max_tokens_per_batch, block_size, expert_num_tokens.to("cuda")
|
||||
)
|
||||
|
||||
assert ref_sorted_ids.size() == sorted_ids.size(), (
|
||||
f"{ref_sorted_ids.size()} vs {sorted_ids.size()}"
|
||||
)
|
||||
assert ref_expert_ids.size() == expert_ids.size(), (
|
||||
f"{ref_expert_ids.size()} vs {expert_ids.size()}"
|
||||
)
|
||||
assert ref_num_tokens_post_pad.size() == num_tokens_post_pad.size(), (
|
||||
f"{ref_num_tokens_post_pad.size()} vs {num_tokens_post_pad.size()}"
|
||||
)
|
||||
torch.testing.assert_close(ref_sorted_ids, sorted_ids, atol=0, rtol=0)
|
||||
torch.testing.assert_close(ref_expert_ids, expert_ids, atol=0, rtol=0)
|
||||
torch.testing.assert_close(
|
||||
ref_num_tokens_post_pad, num_tokens_post_pad, atol=0, rtol=0
|
||||
)
|
||||
|
||||
@@ -1789,6 +1789,24 @@ def moe_align_block_size(
    )


def batched_moe_align_block_size(
    max_tokens_per_batch: int,
    block_size: int,
    expert_num_tokens: torch.Tensor,
    sorted_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    torch.ops._moe_C.batched_moe_align_block_size(
        max_tokens_per_batch,
        block_size,
        expert_num_tokens,
        sorted_ids,
        expert_ids,
        num_tokens_post_pad,
    )


def moe_wna16_gemm(
    input: torch.Tensor,
    output: torch.Tensor,
@@ -50,7 +50,31 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):

    # DeepEP low-latency kernels are compiled only for certain
    # specific hidden sizes.
    SUPPORTED_HIDDEN_SIZES = [2048, 2560, 4096, 5120, 6144, 7168]
    # NOTE: Keep this list sorted, maybe_roundup_layer_hidden_size depends
    # on it.
    SUPPORTED_HIDDEN_SIZES = [2048, 2560, 4096, 5120, 6144, 7168, 8192]

    @staticmethod
    def maybe_roundup_layer_hidden_size(hidden_size: int) -> int:
        # Round up hidden size to the closest supported hidden size.
        _supported_hs = DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES
        # Check sorted
        num_supported_hs = len(_supported_hs)
        assert all(
            [
                _supported_hs[i] < _supported_hs[i + 1]
                for i in range(num_supported_hs - 1)
            ]
        )

        for x in _supported_hs:
            if x >= hidden_size:
                return x

        raise ValueError(
            f"Hidden Size {hidden_size} is greater than the "
            f"maximum supported hidden size {_supported_hs[-1]}"
        )

    def __init__(
        self,
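For example, with the list above the round-up picks the smallest supported size that is at least the layer's hidden size (GPT-OSS layers have a hidden size of 2880, which the DeepEP low-latency kernels are not compiled for directly); assuming DeepEPLLPrepareAndFinalize is imported from its module:

DeepEPLLPrepareAndFinalize.maybe_roundup_layer_hidden_size(2880)  # -> 4096
DeepEPLLPrepareAndFinalize.maybe_roundup_layer_hidden_size(8192)  # -> 8192
DeepEPLLPrepareAndFinalize.maybe_roundup_layer_hidden_size(9000)  # raises ValueError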
@@ -3,13 +3,16 @@
"""Fused MoE utilities for GPTQ."""

import torch
from typing_extensions import override

import vllm._custom_ops as ops
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
    batched_moe_align_block_size,
    moe_align_block_size,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
    TopKWeightAndReduceDelegate,
    TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache, disable_inplace
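The hunk below introduces `_fused_marlin_moe`, a helper shared by the standard and batched entry points. In terms of shapes it runs the usual two-GEMM MoE pipeline; this is a sketch only, with the shapes taken from the code that follows:

# hidden_states:                                           [M, K]
# GEMM1 (w1) -> intermediate_cache1:                       [M * num_topk, 2 * N]
# silu_and_mul / swigluoai_and_mul -> intermediate_cache2: [M * num_topk, N]
# GEMM2 (w2) -> output:                                    [M * num_topk, K]
# Callers then reduce over the top-k dimension (torch.sum in fused_marlin_moe).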
@@ -21,6 +24,153 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
|
||||
def _fused_marlin_moe(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
bias1: torch.Tensor | None,
|
||||
bias2: torch.Tensor | None,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
num_topk: int,
|
||||
quant_type: ScalarType,
|
||||
apply_router_weight_on_input: bool,
|
||||
activation: str,
|
||||
expert_map: torch.Tensor | None,
|
||||
block_size_m: int,
|
||||
sorted_token_ids: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_padded: torch.Tensor,
|
||||
global_scale1: torch.Tensor | None = None,
|
||||
global_scale2: torch.Tensor | None = None,
|
||||
g_idx1: torch.Tensor | None = None,
|
||||
g_idx2: torch.Tensor | None = None,
|
||||
sort_indices1: torch.Tensor | None = None,
|
||||
sort_indices2: torch.Tensor | None = None,
|
||||
w1_zeros: torch.Tensor | None = None,
|
||||
w2_zeros: torch.Tensor | None = None,
|
||||
workspace: torch.Tensor | None = None,
|
||||
intermediate_cache13: torch.Tensor | None = None,
|
||||
intermediate_cache2: torch.Tensor | None = None,
|
||||
output: torch.Tensor | None = None,
|
||||
is_k_full: bool = True,
|
||||
) -> torch.Tensor:
|
||||
assert hidden_states.ndim == 2
|
||||
M, K = hidden_states.size()
|
||||
N = marlin_moe_intermediate_size(w1, w2)
|
||||
|
||||
if workspace is None:
|
||||
workspace = marlin_make_workspace_new(hidden_states.device, 4)
|
||||
|
||||
if intermediate_cache13 is None:
|
||||
intermediate_cache13 = torch.empty(
|
||||
(M * num_topk * max(2 * N, K),),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
|
||||
if intermediate_cache2 is None:
|
||||
intermediate_cache2 = torch.empty(
|
||||
(M * num_topk, N),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
|
||||
intermediate_cache1 = _resize_cache(intermediate_cache13, (M * num_topk, 2 * N))
|
||||
|
||||
intermediate_cache3 = _resize_cache(intermediate_cache13, (M * num_topk, K))
|
||||
|
||||
intermediate_cache2 = _resize_cache(intermediate_cache2, (M * num_topk, N))
|
||||
|
||||
maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype)
|
||||
use_atomic_add = (
|
||||
hidden_states.dtype == torch.half
|
||||
or torch.cuda.get_device_capability(hidden_states.device)[0] >= 9
|
||||
)
|
||||
|
||||
intermediate_cache1 = ops.moe_wna16_marlin_gemm(
|
||||
hidden_states,
|
||||
intermediate_cache1,
|
||||
w1,
|
||||
bias1,
|
||||
w1_scale,
|
||||
global_scale1,
|
||||
w1_zeros,
|
||||
g_idx1,
|
||||
sort_indices1,
|
||||
workspace,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
topk_weights,
|
||||
moe_block_size=block_size_m,
|
||||
top_k=num_topk,
|
||||
mul_topk_weights=apply_router_weight_on_input,
|
||||
is_ep=expert_map is not None,
|
||||
b_q_type=quant_type,
|
||||
size_m=M,
|
||||
size_n=2 * N,
|
||||
size_k=K,
|
||||
is_k_full=is_k_full,
|
||||
use_atomic_add=use_atomic_add,
|
||||
use_fp32_reduce=True,
|
||||
is_zp_float=False,
|
||||
)
|
||||
|
||||
if activation == "silu":
|
||||
torch.ops._C.silu_and_mul(
|
||||
intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
|
||||
)
|
||||
elif activation == "swigluoai":
|
||||
# alpha = 1.702, limit = 7.0
|
||||
torch.ops._C.swigluoai_and_mul(
|
||||
intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported activation: {activation}. "
|
||||
"Only silu and swigluoai activations are supported."
|
||||
)
|
||||
|
||||
if output is None:
|
||||
output = intermediate_cache3
|
||||
|
||||
if expert_map is not None:
|
||||
output.zero_()
|
||||
|
||||
output = ops.moe_wna16_marlin_gemm(
|
||||
intermediate_cache2,
|
||||
output,
|
||||
w2,
|
||||
bias2,
|
||||
w2_scale,
|
||||
global_scale2,
|
||||
w2_zeros,
|
||||
g_idx2,
|
||||
sort_indices2,
|
||||
workspace,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
topk_weights,
|
||||
moe_block_size=block_size_m,
|
||||
top_k=1,
|
||||
mul_topk_weights=not apply_router_weight_on_input,
|
||||
is_ep=expert_map is not None,
|
||||
b_q_type=quant_type,
|
||||
size_m=M * num_topk,
|
||||
size_n=K,
|
||||
size_k=N,
|
||||
is_k_full=is_k_full,
|
||||
use_atomic_add=use_atomic_add,
|
||||
use_fp32_reduce=True,
|
||||
is_zp_float=False,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def fused_marlin_moe(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
@@ -62,23 +212,27 @@ def fused_marlin_moe(
|
||||
- w2 (torch.Tensor): The second set of expert weights.
|
||||
- w1_scale (torch.Tensor): Scale to be used for w1.
|
||||
- w2_scale (torch.Tensor): Scale to be used for w2.
|
||||
- gating_output (Optional[torch.Tensor]): The output of the gating
|
||||
- gating_output (torch.Tensor|None): The output of the gating
|
||||
operation (before softmax).
|
||||
- g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
|
||||
- g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
|
||||
- sort_indices1 (Optional[torch.Tensor]): The first act_order input
|
||||
- g_idx1 (torch.Tensor|None): The first set of act_order indices.
|
||||
- g_idx2 (torch.Tensor|None): The second set of act_order indices.
|
||||
- sort_indices1 (torch.Tensor|None): The first act_order input
|
||||
permutation.
|
||||
- sort_indices2 (Optional[torch.Tensor]): The second act_order input
|
||||
- sort_indices2 (torch.Tensor|None): The second act_order input
|
||||
permutation.
|
||||
- topk_weights (torch.Tensor): Top-k weights.
|
||||
- topk_ids (torch.Tensor): Indices of top-k elements.
|
||||
- w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
|
||||
- w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
|
||||
- w1_zeros (torch.Tensor|None): Optional zero points to be used for w1.
|
||||
- w2_zeros (torch.Tensor|None): Optional zero points to be used for w2.
|
||||
- num_bits (int): The number of bits in expert weights quantization.
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The output tensor after applying the MoE layer.
|
||||
"""
|
||||
|
||||
if inplace:
|
||||
assert output is None, "Conflicting request"
|
||||
|
||||
quant_type = ScalarType.from_id(quant_type_id)
|
||||
assert quant_type in [
|
||||
scalar_types.uint4,
|
||||
@@ -95,15 +249,15 @@ def fused_marlin_moe(
|
||||
]
|
||||
num_bits = 4 if quant_type in bit4_scalar_types else 8
|
||||
|
||||
M, K = hidden_states.size()
|
||||
E = w1.size(0)
|
||||
topk = topk_ids.size(1)
|
||||
|
||||
# Check constraints.
|
||||
if gating_output is not None:
|
||||
assert hidden_states.shape[0] == gating_output.shape[0], (
|
||||
"Number of tokens mismatch"
|
||||
)
|
||||
assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1"
|
||||
assert hidden_states.shape[1] == w2.shape[2] // (num_bits // 2), (
|
||||
"Hidden size mismatch w2"
|
||||
)
|
||||
assert gating_output.size(0) == M, "Number of tokens mismatch"
|
||||
assert w1.size(1) * 16 == K, "Hidden size mismatch w1"
|
||||
assert w2.size(2) // (num_bits // 2) == K, "Hidden size mismatch w2"
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
||||
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
|
||||
@@ -111,11 +265,6 @@ def fused_marlin_moe(
|
||||
assert num_bits in [4, 8]
|
||||
assert topk_weights.dtype == torch.float32
|
||||
|
||||
M, K = hidden_states.shape
|
||||
E = w1.shape[0]
|
||||
N = marlin_moe_intermediate_size(w1, w2)
|
||||
topk = topk_ids.shape[1]
|
||||
|
||||
# M block size selection logic
|
||||
# TODO: tune this further for specific models
|
||||
for block_size_m in [8, 16, 32, 48, 64]:
|
||||
@@ -128,107 +277,38 @@ def fused_marlin_moe(
|
||||
topk_ids, block_size_m, global_num_experts, expert_map
|
||||
)
|
||||
|
||||
if workspace is None:
|
||||
workspace = marlin_make_workspace_new(hidden_states.device, 4)
|
||||
|
||||
if intermediate_cache2 is None:
|
||||
intermediate_cache2 = torch.empty(
|
||||
(M * topk, N),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
|
||||
if intermediate_cache13 is None:
|
||||
intermediate_cache13 = torch.empty(
|
||||
(M * topk * max(2 * N, K),),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
|
||||
intermediate_cache1 = _resize_cache(intermediate_cache13, (M * topk, 2 * N))
|
||||
intermediate_cache3 = _resize_cache(intermediate_cache13, (M * topk, K))
|
||||
intermediate_cache2 = _resize_cache(intermediate_cache2, (M * topk, N))
|
||||
|
||||
maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype)
|
||||
use_atomic_add = (
|
||||
hidden_states.dtype == torch.half
|
||||
or torch.cuda.get_device_capability(hidden_states.device)[0] >= 9
|
||||
)
|
||||
|
||||
intermediate_cache1 = ops.moe_wna16_marlin_gemm(
|
||||
hidden_states,
|
||||
intermediate_cache1,
|
||||
w1,
|
||||
bias1,
|
||||
w1_scale,
|
||||
global_scale1,
|
||||
w1_zeros,
|
||||
g_idx1,
|
||||
sort_indices1,
|
||||
workspace,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
topk_weights,
|
||||
moe_block_size=block_size_m,
|
||||
top_k=topk,
|
||||
mul_topk_weights=apply_router_weight_on_input,
|
||||
is_ep=expert_map is not None,
|
||||
b_q_type=quant_type,
|
||||
size_m=M,
|
||||
size_n=2 * N,
|
||||
size_k=K,
|
||||
assert activation is not None
|
||||
moe_output = _fused_marlin_moe(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
bias1=bias1,
|
||||
bias2=bias2,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
topk_weights=topk_weights,
|
||||
num_topk=topk,
|
||||
quant_type=quant_type,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
activation=activation,
|
||||
expert_map=expert_map,
|
||||
block_size_m=block_size_m,
|
||||
sorted_token_ids=sorted_token_ids,
|
||||
expert_ids=expert_ids,
|
||||
num_tokens_post_padded=num_tokens_post_padded,
|
||||
global_scale1=global_scale1,
|
||||
global_scale2=global_scale2,
|
||||
g_idx1=g_idx1,
|
||||
g_idx2=g_idx2,
|
||||
sort_indices1=sort_indices1,
|
||||
sort_indices2=sort_indices2,
|
||||
w1_zeros=w1_zeros,
|
||||
w2_zeros=w2_zeros,
|
||||
workspace=workspace,
|
||||
intermediate_cache13=intermediate_cache13,
|
||||
intermediate_cache2=intermediate_cache2,
|
||||
output=None,
|
||||
is_k_full=is_k_full,
|
||||
use_atomic_add=use_atomic_add,
|
||||
use_fp32_reduce=True,
|
||||
is_zp_float=False,
|
||||
)
|
||||
|
||||
if activation == "silu":
|
||||
torch.ops._C.silu_and_mul(
|
||||
intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
|
||||
)
|
||||
elif activation == "swigluoai":
|
||||
# alpha = 1.702, limit = 7.0
|
||||
torch.ops._C.swigluoai_and_mul(
|
||||
intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported activation: {activation}. "
|
||||
"Only silu and swigluoai activations are supported."
|
||||
)
|
||||
|
||||
if expert_map is not None:
|
||||
intermediate_cache3.zero_()
|
||||
|
||||
intermediate_cache3 = ops.moe_wna16_marlin_gemm(
|
||||
intermediate_cache2,
|
||||
intermediate_cache3,
|
||||
w2,
|
||||
bias2,
|
||||
w2_scale,
|
||||
global_scale2,
|
||||
w2_zeros,
|
||||
g_idx2,
|
||||
sort_indices2,
|
||||
workspace,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
topk_weights,
|
||||
moe_block_size=block_size_m,
|
||||
top_k=1,
|
||||
mul_topk_weights=not apply_router_weight_on_input,
|
||||
is_ep=expert_map is not None,
|
||||
b_q_type=quant_type,
|
||||
size_m=M * topk,
|
||||
size_n=K,
|
||||
size_k=N,
|
||||
is_k_full=is_k_full,
|
||||
use_atomic_add=use_atomic_add,
|
||||
use_fp32_reduce=True,
|
||||
is_zp_float=False,
|
||||
).view(-1, topk, K)
|
||||
|
||||
if output is None:
|
||||
@@ -237,16 +317,173 @@ def fused_marlin_moe(
|
||||
else:
|
||||
output = torch.empty_like(hidden_states)
|
||||
|
||||
return torch.sum(intermediate_cache3.view(-1, topk, K), dim=1, out=output)
|
||||
return torch.sum(moe_output.view(-1, topk, K), dim=1, out=output)
|
||||
|
||||
|
||||
class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
def batched_fused_marlin_moe(
|
||||
hidden_states: torch.Tensor,
|
||||
expert_num_tokens: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
bias1: torch.Tensor | None,
|
||||
bias2: torch.Tensor | None,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
gating_output: torch.Tensor | None,
|
||||
quant_type_id: int,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
activation: str | None = "silu",
|
||||
expert_map: torch.Tensor | None = None,
|
||||
global_scale1: torch.Tensor | None = None,
|
||||
global_scale2: torch.Tensor | None = None,
|
||||
g_idx1: torch.Tensor | None = None,
|
||||
g_idx2: torch.Tensor | None = None,
|
||||
sort_indices1: torch.Tensor | None = None,
|
||||
sort_indices2: torch.Tensor | None = None,
|
||||
w1_zeros: torch.Tensor | None = None,
|
||||
w2_zeros: torch.Tensor | None = None,
|
||||
workspace: torch.Tensor | None = None,
|
||||
intermediate_cache13: torch.Tensor | None = None,
|
||||
intermediate_cache2: torch.Tensor | None = None,
|
||||
is_k_full: bool = True,
|
||||
output: torch.Tensor | None = None,
|
||||
inplace: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function massages the inputs so the batched hidden_states can be
|
||||
presented as a 2D contiguous tensor that could be used with
|
||||
_fused_marlin_moe.
|
||||
|
||||
Note that both batched_fused_marlin_moe and fused_marlin_moe ultimately
|
||||
use `ops.moe_wna16_marlin_gemm` for the gemm operation and
|
||||
`ops.moe_wna16_marlin_gemm` supports only 2D contiguous hidden_states.
|
||||
Note that the moe_align_block_size function indicates,
|
||||
- What rows of the A matrix (hidden_states) to access during the
|
||||
matmul, via sorted_ids output.
|
||||
- What expert_id to use for each block matmul, via the expert_ids output.
|
||||
|
||||
In the batched version, the tokens are already grouped/batched by experts
|
||||
they subscribe to. Due to this, we can represent the batched hidden_states
|
||||
tensor of shape [B, MAX_TOKENS_PER_BATCH, K] as a 2D tensor of shape,
|
||||
[B * MAX_TOKENS_PER_BATCH, K]. We may treat this a 2D contiguous tensor
|
||||
with topk=1 as each token (row in the tensor) subscribes to exactly one
|
||||
expert_id (which is the batch_id). With the expert_num_tokens tensor, that
|
||||
indicates how many tokens are actually valid in each batch, the
|
||||
batched_moe_align_block_size function constructs the sorted_ids and
|
||||
expert_ids tensors, so only relevant/valid rows of A (hidden_states)
|
||||
are accessed and are processed with the correct expert_ids.
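    For example (illustrative numbers only): with B = E = 2,
    MAX_TOKENS_PER_BATCH = 4 and expert_num_tokens = [1, 3], the flattened view
    has 8 rows, of which row 0 and rows 4..6 hold valid tokens. With a
    (hypothetically small) block size of 4, batched_moe_align_block_size yields
    sorted_ids = [0, S, S, S, 4, 5, 6, S] (S being the invalid-entry sentinel),
    expert_ids = [0, 1] and num_tokens_post_pad = 8, so the first GEMM block is
    computed with expert 0 and the second with expert 1.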
|
||||
"""
|
||||
|
||||
assert hidden_states.ndim == 3, (
|
||||
f"hidden states must be batched. e.g. [B, MAX_TOKENS, K]."
|
||||
f"But got {hidden_states.size()}"
|
||||
)
|
||||
if inplace:
|
||||
assert output is None, "Conflicting request."
|
||||
|
||||
quant_type = ScalarType.from_id(quant_type_id)
|
||||
assert quant_type in [
|
||||
scalar_types.uint4,
|
||||
scalar_types.uint8b128,
|
||||
scalar_types.uint4b8,
|
||||
scalar_types.float8_e4m3fn,
|
||||
scalar_types.float4_e2m1f,
|
||||
]
|
||||
|
||||
bit4_scalar_types = [
|
||||
scalar_types.uint4,
|
||||
scalar_types.uint4b8,
|
||||
scalar_types.float4_e2m1f,
|
||||
]
|
||||
num_bits = 4 if quant_type in bit4_scalar_types else 8
|
||||
|
||||
B, BATCH_TOKENS_MAX, K = hidden_states.size()
|
||||
M = hidden_states.view(-1, K).size(0)
|
||||
E = w1.size(0)
|
||||
|
||||
# Check constraints.
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
assert hidden_states.dtype in [torch.float16, torch.bfloat16]
|
||||
assert expert_num_tokens.size(0) == E
|
||||
assert B == E, (
|
||||
"Batch must be as big as number of experts as the tokens"
|
||||
"are sorted into the batch/expert they belong to"
|
||||
)
|
||||
assert w1.size(1) * 16 == K, "Hidden size mismatch w1"
|
||||
assert w2.size(2) // (num_bits // 2) == K, "Hidden size mismatch w2"
|
||||
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
||||
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
|
||||
assert num_bits in [4, 8]
|
||||
|
||||
    # Technically, the tokens are already separated by their expert ids.
    # Hidden-States can just be squeezed to have just 2 dimensions,
    # [B * MAX_TOKENS, K] and top_k can be interpreted as just 1.
    topk = 1

    # TODO(varun) : Choose a decent block size like in fused_marlin_moe
    block_size_m = 64

    sorted_token_ids, expert_ids, num_tokens_post_padded = batched_moe_align_block_size(
        max_tokens_per_batch=BATCH_TOKENS_MAX,
        block_size=block_size_m,
        expert_num_tokens=expert_num_tokens,
    )

    if output is None and inplace:
        output = hidden_states

    # TODO (varun): This can be avoided by plumbing the marlin kernel to
    # ignore topk_weights when topk_weights_ptr is a nullptr.
    topk_weights = torch.ones(
        (M, topk), device=hidden_states.device, dtype=torch.float32
    )

    assert activation is not None
    output = _fused_marlin_moe(
        hidden_states=hidden_states.view(-1, K),
        w1=w1,
        w2=w2,
        bias1=bias1,
        bias2=bias2,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        topk_weights=topk_weights,
        num_topk=topk,
        quant_type=quant_type,
        apply_router_weight_on_input=apply_router_weight_on_input,
        activation=activation,
        expert_map=expert_map,
        block_size_m=block_size_m,
        sorted_token_ids=sorted_token_ids,
        expert_ids=expert_ids,
        num_tokens_post_padded=num_tokens_post_padded,
        global_scale1=global_scale1,
        global_scale2=global_scale2,
        g_idx1=g_idx1,
        g_idx2=g_idx2,
        sort_indices1=sort_indices1,
        sort_indices2=sort_indices2,
        w1_zeros=w1_zeros,
        w2_zeros=w2_zeros,
        workspace=workspace,
        intermediate_cache13=intermediate_cache13,
        intermediate_cache2=intermediate_cache2,
        output=output.view(-1, K) if output is not None else output,
        is_k_full=is_k_full,
    )

    output = output.view(B, BATCH_TOKENS_MAX, K)

    return output


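# Editor's illustration (not part of this diff): the docstring above relies on
# viewing the batched activations [B, MAX_TOKENS_PER_BATCH, K] as a 2D
# [B * MAX_TOKENS_PER_BATCH, K] tensor with an implicit topk of 1, where each
# row's expert id is simply its batch id. A minimal, hypothetical sketch of
# that reinterpretation, assuming only torch:
def _batched_view_sketch(
    hidden_states: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    B, MAX_TOKENS_PER_BATCH, K = hidden_states.size()
    # Flatten to 2D; row r of the flattened tensor belongs to batch (expert)
    # r // MAX_TOKENS_PER_BATCH.
    hidden_states_2d = hidden_states.view(-1, K)
    implicit_topk_ids = (
        torch.arange(B * MAX_TOKENS_PER_BATCH, device=hidden_states.device)
        // MAX_TOKENS_PER_BATCH
    ).unsqueeze(-1)  # shape [B * MAX_TOKENS_PER_BATCH, topk=1]
    return hidden_states_2d, implicit_topk_ids

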
class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
    def __init__(self, quant_config: FusedMoEQuantConfig):
        # TODO (varun) : Enable activation quantization
        assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
        super().__init__(quant_config)

    @override
    def moe_problem_size(
        self,
        a1: torch.Tensor,
@@ -274,6 +511,11 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):

        return E, M, N, K, topk


class MarlinExperts(MarlinExpertsBase):
    def __init__(self, quant_config: FusedMoEQuantConfig):
        super().__init__(quant_config)

    def supports_expert_map(self) -> bool:
        return True

@@ -365,3 +607,90 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
            intermediate_cache13=workspace2,
            intermediate_cache2=workspace13,
        )


class BatchedMarlinExperts(MarlinExpertsBase):
    def __init__(
        self,
        max_num_tokens: int,
        num_dispatchers: int,
        quant_config: FusedMoEQuantConfig,
    ):
        super().__init__(quant_config)
        self.max_num_tokens = max_num_tokens
        self.num_dispatchers = num_dispatchers

    def supports_expert_map(self) -> bool:
        return True

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        return TopKWeightAndReduceDelegate()

    @property
    def activation_formats(
        self,
    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
        return (
            mk.FusedMoEActivationFormat.BatchedExperts,
            mk.FusedMoEActivationFormat.BatchedExperts,
        )

    def supports_chunking(self) -> bool:
        return False

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        num_dispatchers = self.num_dispatchers
        num_experts = local_num_experts
        max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens
        workspace13 = (num_experts * max_num_tokens * num_dispatchers, max(K, N * 2))
        workspace2 = (num_experts * max_num_tokens * num_dispatchers, N)
        output = (num_experts, max_num_tokens * num_dispatchers, K)
        return (workspace13, workspace2, output)

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        assert expert_tokens_meta is not None, "Num valid tokens per batch is required"
        return batched_fused_marlin_moe(
            hidden_states=hidden_states,
            expert_num_tokens=expert_tokens_meta.expert_num_tokens,
            w1=w1,
            w2=w2,
            bias1=self.w1_bias,
            bias2=self.w2_bias,
            w1_scale=self.w1_scale,
            w2_scale=self.w2_scale,
            gating_output=None,
            quant_type_id=scalar_types.float4_e2m1f.id,  # works only for w4a16
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            activation=activation,
            expert_map=expert_map,
            output=output,
            intermediate_cache13=workspace13,
            intermediate_cache2=workspace2,
        )

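# Editor's note (not part of this diff): BatchedMarlinExperts is wired up by
# Mxfp4MoEMethod further down in this commit. A minimal, hypothetical
# construction sketch, assuming a prepare/finalize object that exposes
# max_num_tokens_per_rank() and num_dispatchers(), and an already-built
# moe_quant_config:
def _construct_batched_marlin_experts_sketch(prepare_finalize, moe_quant_config):
    return BatchedMarlinExperts(
        max_num_tokens=prepare_finalize.max_num_tokens_per_rank(),
        num_dispatchers=prepare_finalize.num_dispatchers(),
        quant_config=moe_quant_config,
    )
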
@@ -994,6 +994,11 @@ def maybe_roundup_hidden_size(
            hidden_size, act_dtype
        )

    if moe_parallel_config.use_deepep_ll_kernels:
        hidden_size = DeepEPLLPrepareAndFinalize.maybe_roundup_layer_hidden_size(
            hidden_size
        )

    # we are padding globally so EP buffer allocation works
    if quant_config and quant_config.get_name() == "mxfp4":
        from vllm.model_executor.layers.quantization.mxfp4 import (

@@ -83,3 +83,92 @@ def moe_align_block_size(
        expert_ids = expert_map[expert_ids]

    return sorted_ids, expert_ids, num_tokens_post_pad


def batched_moe_align_block_size(
    max_tokens_per_batch: int, block_size: int, expert_num_tokens: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Given num_batches, max_tokens_per_batch, block_size and the number of
    valid tokens in each batch, prepare sorted_token_ids, expert_ids and
    num_tokens_post_pad. sorted_token_ids, expert_ids and num_tokens_post_pad
    have the same semantics as in moe_align_block_size.

    This function is intended to be a drop-in replacement for
    moe_align_block_size for the batched case.

    Parameters:
    - max_tokens_per_batch (int): Number of token slots in each batch (both
      valid and invalid).
    - block_size (int): block_size to align the data to.
    - expert_num_tokens (torch.Tensor): expert_num_tokens[i] indicates
      the number of valid tokens in batch i.

    Returns:
    - sorted_token_ids (torch.Tensor): Torch tensor of size
      num_batches * round_up(max_tokens_per_batch, block_size) indicating,
      for each block slot, the token index to process.
    - expert_ids (torch.Tensor): Torch tensor of size
      num_batches * round_up(max_tokens_per_batch, block_size) / block_size
      indicating what expert to use for each block.
    - num_tokens_post_pad (torch.Tensor): Torch tensor of size 1
      indicating the number of valid blocks with actual data to
      process. This is represented in terms of num tokens.

    Example:
    Let num_batches=5, max_tokens_per_batch=8, block_size=4, and
    expert_num_tokens=[2, 3, 0, 6, 8]. This expert_num_tokens tensor
    indicates that:
    - The first 2 tokens in the 0th batch are valid and the remaining 6 are
      invalid (i.e. in the 2D hidden_states tensor of shape
      [num_batches * max_tokens_per_batch, K], indices 0, 1 are valid)
    - The first 3 tokens in the 1st batch are valid, i.e. indices 8, 9, 10
    - 0 tokens in the 2nd batch are valid
    - The first 6 tokens in the 3rd batch are valid, i.e. indices
      24, 25, 26, 27, 28, 29
    - and so on ...

    In this case,
    sorted_token_ids will be [0, 1, 40, 40,
                              8, 9, 10, 40,
                              24, 25, 26, 27,
                              28, 29, 40, 40,
                              32, 33, 34, 35,
                              36, 37, 38, 39,
                              40, 40, 40, 40,
                              (rest all 40, 40, 40, 40)
                              ...]
    Here, 40 represents an invalid index, as there is no token index 40.
    The gemm kernel using this sorted_token_ids is expected to skip the
    gemm computation when it encounters this invalid index.

    expert_ids will be [0, 1, 3, 3, 4, 4, -1, -1, (rest all -1) ...]
    Here, -1 represents an invalid expert. The gemm kernel using this
    expert_ids is expected to skip the gemm computation when it encounters
    an expert of id -1.

    num_tokens_post_pad will be 24 as sorted_token_ids has valid entries
    only within its first 24 slots.
    """

    B = expert_num_tokens.size(0)
    device = expert_num_tokens.device

    # Round up so each batch can be split into blocks evenly.
    max_num_tokens_padded = B * round_up(max_tokens_per_batch, block_size)

    sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device=device)
    assert max_num_tokens_padded % block_size == 0
    max_num_m_blocks = max_num_tokens_padded // block_size
    expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device=device)
    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=device)

    ops.batched_moe_align_block_size(
        max_tokens_per_batch,
        block_size,
        expert_num_tokens,
        sorted_ids,
        expert_ids,
        num_tokens_post_pad,
    )

    return sorted_ids, expert_ids, num_tokens_post_pad

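# Editor's illustration (not part of this diff): a pure-Python reference of the
# alignment described in the docstring above, handy for checking the worked
# example. The helper name is hypothetical; ops.batched_moe_align_block_size
# remains the real (CUDA) implementation.
def _reference_batched_moe_align_block_size(
    max_tokens_per_batch: int, block_size: int, expert_num_tokens: list[int]
) -> tuple[list[int], list[int], int]:
    num_batches = len(expert_num_tokens)
    blocks_per_batch = -(-max_tokens_per_batch // block_size)  # ceil-div
    sentinel = num_batches * max_tokens_per_batch  # marks invalid token slots
    sorted_ids = [sentinel] * (num_batches * blocks_per_batch * block_size)
    expert_ids = [-1] * (len(sorted_ids) // block_size)

    cursor = 0  # running, block-aligned offset into sorted_ids
    for batch_id, n_tokens in enumerate(expert_num_tokens):
        batch_offset = batch_id * max_tokens_per_batch
        for i in range(n_tokens):
            sorted_ids[cursor + i] = batch_offset + i
        padded = -(-n_tokens // block_size) * block_size
        for blk in range(padded // block_size):
            expert_ids[cursor // block_size + blk] = batch_id
        cursor += padded
    return sorted_ids, expert_ids, cursor


# For the docstring example (max_tokens_per_batch=8, block_size=4,
# expert_num_tokens=[2, 3, 0, 6, 8]) this reference yields
# num_tokens_post_pad == 24 and expert_ids == [0, 1, 3, 3, 4, 4, -1, -1, -1, -1].
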
@@ -22,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.config import (
    ocp_mx_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
    BatchedMarlinExperts,
    MarlinExperts,
    fused_marlin_moe,
)
@@ -797,9 +798,19 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
            prepare_finalize.activation_format
            == mk.FusedMoEActivationFormat.BatchedExperts
        ):
            raise NotImplementedError(
                "Mxfp4 does not support batched experts format for EP"
            )
            if self.mxfp4_backend == Mxfp4Backend.MARLIN:
                max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
                assert max_num_tokens_per_rank is not None
                assert self.moe_quant_config is not None
                return BatchedMarlinExperts(
                    max_num_tokens=max_num_tokens_per_rank,
                    num_dispatchers=prepare_finalize.num_dispatchers(),
                    quant_config=self.moe_quant_config,
                )
            else:
                raise NotImplementedError(
                    "Incompatible Mxfp4 backend for EP batched experts format"
                )
        else:
            assert self.moe_quant_config is not None
            if (