[NVIDIA] [Perf] Update to leverage flashinfer trtllm FP4 MOE throughput kernel (#26714)
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
@@ -359,8 +359,8 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
 # Install FlashInfer pre-compiled kernel cache and binaries
 # https://docs.flashinfer.ai/installation.html
 RUN --mount=type=cache,target=/root/.cache/uv \
-    uv pip install --system flashinfer-cubin==0.4.0 \
-    && uv pip install --system flashinfer-jit-cache==0.4.0 \
+    uv pip install --system flashinfer-cubin==0.4.1 \
+    && uv pip install --system flashinfer-jit-cache==0.4.1 \
        --extra-index-url https://flashinfer.ai/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \
    && flashinfer show-config
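The `--extra-index-url` points at a CUDA-specific wheel index derived from `$CUDA_VERSION`. A minimal sketch of what the shell pipeline `echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.'` computes; `build_index_url` is an illustrative helper, not part of the Dockerfile:

```python
# Illustrative only: mirrors the shell pipeline above, which keeps the
# major.minor CUDA version and drops the dot (e.g. "12.8.1" -> "128").
def build_index_url(cuda_version: str) -> str:
    major, minor = cuda_version.split(".")[:2]
    return f"https://flashinfer.ai/whl/cu{major}{minor}"


assert build_index_url("12.8.1") == "https://flashinfer.ai/whl/cu128"
```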
@@ -246,7 +246,7 @@ RUN pip install setuptools==75.6.0 packaging==23.2 ninja==1.11.1.3 build==1.2.2.

 # build flashinfer for torch nightly from source around 10 mins
-# release version: v0.4.0
+# release version: v0.4.1
 # todo(elainewy): cache flashinfer build result for faster build
 ENV CCACHE_DIR=/root/.cache/ccache
 RUN --mount=type=cache,target=/root/.cache/ccache \
@@ -254,7 +254,7 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
     echo "git clone flashinfer..." \
     && git clone --recursive https://github.com/flashinfer-ai/flashinfer.git \
     && cd flashinfer \
-    && git checkout v0.4.0 \
+    && git checkout v0.4.1 \
     && git submodule update --init --recursive \
     && echo "finish git clone flashinfer..." \
     && rm -rf build \
@@ -12,4 +12,4 @@ torchvision==0.23.0 # Required for phi3v processor. See https://github.com/pytor
 # https://github.com/facebookresearch/xformers/releases/tag/v0.0.32.post1
 xformers==0.0.32.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.8
 # FlashInfer should be updated together with the Dockerfile
-flashinfer-python==0.4.0
+flashinfer-python==0.4.1
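As the comment above notes, the requirements pin and the Dockerfile install must move together. A small illustrative check (not part of this change) that the installed distribution matches the pin:

```python
# Illustrative sanity check: the distribution name comes from the
# requirements line above (flashinfer-python).
from importlib.metadata import version

assert version("flashinfer-python") == "0.4.1", version("flashinfer-python")
```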
@@ -37,7 +37,7 @@ if TRTLLM_GEN_MXFP4_AVAILABLE:
     trtllm_fp4_block_scale_moe,
 )
 from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
-from flashinfer.fused_moe.core import _maybe_get_cached_w2_permute_indices
+from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache


 @dataclass
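The only change in this hunk (and in the shuffling hunks that follow) is the helper rename from `_maybe_get_cached_w2_permute_indices` to `get_w2_permute_indices_with_cache`. A minimal compatibility-import sketch, assuming the helper was only renamed between FlashInfer 0.4.0 and 0.4.1 and keeps the same signature (not part of this PR):

```python
# Hedged sketch: fall back to the old private name on FlashInfer releases
# that predate the rename.
try:
    from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache
except ImportError:  # older FlashInfer (<= 0.4.0)
    from flashinfer.fused_moe.core import (
        _maybe_get_cached_w2_permute_indices as get_w2_permute_indices_with_cache,
    )
```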
@@ -319,7 +319,7 @@ def tg_mxfp4_moe(
     if transpose_optimized:
         for i in range(num_experts):
             # w13 weight shuffling
-            permute_indices = _maybe_get_cached_w2_permute_indices(
+            permute_indices = get_w2_permute_indices_with_cache(
                 _cache_permute_indices,
                 w13_weight[i].view(torch.uint8),
                 epilogue_tile_m,
@@ -330,7 +330,7 @@ def tg_mxfp4_moe(
                 .contiguous()
             )
             # w13 scale shuffling
-            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+            permute_sf_indices = get_w2_permute_indices_with_cache(
                 _cache_permute_indices,
                 w13_weight_scale[i].view(torch.uint8),
                 epilogue_tile_m,
@@ -344,7 +344,7 @@ def tg_mxfp4_moe(
                 )
             )
             # w13 bias shuffling
-            permute_bias_indices = _maybe_get_cached_w2_permute_indices(
+            permute_bias_indices = get_w2_permute_indices_with_cache(
                 _cache_permute_indices,
                 w13_bias[i].clone().reshape(-1, 1),
                 epilogue_tile_m,
@@ -356,7 +356,7 @@ def tg_mxfp4_moe(
                 .contiguous()
             )
             # w2 weight shuffling
-            permute_indices = _maybe_get_cached_w2_permute_indices(
+            permute_indices = get_w2_permute_indices_with_cache(
                 _cache_permute_indices,
                 w2_weight[i].view(torch.uint8),
                 epilogue_tile_m,
@@ -367,7 +367,7 @@ def tg_mxfp4_moe(
                 .contiguous()
             )
             # w2 scale shuffling
-            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+            permute_sf_indices = get_w2_permute_indices_with_cache(
                 _cache_permute_indices,
                 w2_weight_scale[i].view(torch.uint8),
                 epilogue_tile_m,
@@ -381,7 +381,7 @@ def tg_mxfp4_moe(
                 )
             )
             # w2 bias shuffling
-            permute_indices = _maybe_get_cached_w2_permute_indices(
+            permute_indices = get_w2_permute_indices_with_cache(
                 _cache_permute_indices,
                 w2_bias[i].clone().reshape(-1, 1),
                 epilogue_tile_m,
@@ -11,7 +11,6 @@ from vllm.model_executor.layers.fused_moe.config import (
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceNoOP,
 )
-from vllm.utils import next_power_of_2


 class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
@@ -65,30 +64,6 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
         output = (M, K)
         return (workspace1, workspace2, output)

-    def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int, local_num_experts: int):
-        # Number of tokens in the input tensor.
-        num_tokens = x.shape[0]
-        # Factor to account for the imbalance of the experts.
-        # factor equals to the
-        # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
-        # 1.0 means perfect expert distribution.
-        # > 1.0 means some experts have more tokens than the perfect
-        # distribution.
-        # < 1.0 does not make sense.
-        imbalance_factor = 1.3
-        # Calculate the number of tokens per expert assuming perfect
-        # distribution.
-        num_tokens_per_expert = (num_tokens * top_k) // local_num_experts
-        # Apply the imbalance factor.
-        num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
-        # And pad the number to the next power of 2.
-        tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
-        # Cap to 8-64 tokens per CTA tile as it's the range supported by the
-        # kernel.
-        tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-
-        return tile_tokens_dim
-
     def apply(
         self,
         output: torch.Tensor,
@@ -148,9 +123,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
             "local_expert_offset": local_expert_offset,
             "local_num_experts": local_num_experts,
             "routed_scaling_factor": None,
-            "tile_tokens_dim": self._get_tile_tokens_dim(
-                x_quant, topk, local_num_experts
-            ),
+            "tile_tokens_dim": None,
             "routing_method_type": 1,
             "do_finalize": True,
             "output": output,
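The call sites that previously computed `tile_tokens_dim` now pass `None`, presumably deferring the tile-size choice to the FlashInfer TRTLLM kernel itself; the same helper is also deleted from the ModelOpt and MXFP4 paths below. For reference, a minimal sketch of the retired heuristic, reconstructed from the deleted lines (the function name here is illustrative):

```python
from vllm.utils import next_power_of_2


def legacy_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    """Retired heuristic: estimate tokens per CTA tile for the TRTLLM MoE kernel."""
    # Tokens per expert under a perfect distribution, inflated by an
    # imbalance factor (1.3) because some experts receive more tokens.
    num_tokens_per_expert = int((num_tokens * top_k) // num_experts * 1.3)
    # Pad to the next power of 2 and cap to the 8-64 range the kernel supports.
    return min(max(next_power_of_2(num_tokens_per_expert), 8), 64)
```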
@@ -72,7 +72,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 )
 from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
 from vllm.scalar_type import scalar_types
-from vllm.utils import next_power_of_2
 from vllm.utils.flashinfer import (
     flashinfer_scaled_fp4_mm,
     has_flashinfer,
@@ -1125,16 +1124,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
         return out.view(*output_shape)


-def _get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
-    # Guess tokens per expert assuming perfect expert distribution first.
-    num_tokens_per_expert = (num_tokens * top_k) // num_experts
-    # And pad the number to the next power of 2.
-    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
-    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
-    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-    return tile_tokens_dim
-
-
 class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
     """
     MoE Method for FP4 Quantization.
@@ -1332,8 +1321,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         ):
             from flashinfer import nvfp4_block_scale_interleave
             from flashinfer.fused_moe.core import (
-                _maybe_get_cached_w2_permute_indices,
                 _maybe_get_cached_w3_w1_permute_indices,
+                get_w2_permute_indices_with_cache,
             )

             """Prepare quantized weights for kernel (done offline with weights)."""
@@ -1394,7 +1383,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 )
             )

-            permute_indices = _maybe_get_cached_w2_permute_indices(
+            permute_indices = get_w2_permute_indices_with_cache(
                 self._cache_permute_indices,
                 gemm2_weights_fp4[i].view(torch.uint8),
                 epilogue_tile_m,
@@ -1405,7 +1394,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 .contiguous()
             )

-            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+            permute_sf_indices = get_w2_permute_indices_with_cache(
                 self._cache_permute_indices,
                 gemm2_scales_linear_fp4[i].view(torch.uint8),
                 epilogue_tile_m,
@@ -1664,9 +1653,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 local_expert_offset=layer.ep_rank * layer.local_num_experts,
                 local_num_experts=layer.local_num_experts,
                 routed_scaling_factor=None,
-                tile_tokens_dim=_get_tile_tokens_dim(
-                    x.shape[0], top_k, layer.local_num_experts
-                ),
+                tile_tokens_dim=None,
                 routing_method_type=routing_method_type,
                 do_finalize=True,
             )[0]
@@ -50,7 +50,6 @@ from vllm.scalar_type import scalar_types
 from vllm.utils import (
     has_triton_kernels,
     is_torch_equal_or_newer,
-    next_power_of_2,
     round_up,
 )
 from vllm.utils.flashinfer import has_flashinfer
@@ -97,12 +96,6 @@ def get_mxfp4_backend():
         and has_flashinfer()
         and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
     ):
-        logger.info_once(
-            "Using FlashInfer MXFP4 MXFP8 TRTLLM backend for SM100, "
-            "for high concurrency throughput workloads consider setting "
-            "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS=1 for better "
-            "performance"
-        )
         return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
     elif current_platform.is_device_capability(100) and has_flashinfer():
         logger.info_once(
@@ -357,7 +350,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
         ):
             from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
-            from flashinfer.fused_moe.core import _maybe_get_cached_w2_permute_indices
+            from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache

             layer.gemm1_alpha = Parameter(
                 torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
@@ -449,7 +442,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
         for i in range(self.num_experts):
             # w13 weight shuffling
-            permute_indices = _maybe_get_cached_w2_permute_indices(
+            permute_indices = get_w2_permute_indices_with_cache(
                 self._cache_permute_indices,
                 w13_weight[i].view(torch.uint8),
                 epilogue_tile_m,
@@ -460,7 +453,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 .contiguous()
             )
             # w13 scale shuffling
-            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+            permute_sf_indices = get_w2_permute_indices_with_cache(
                 self._cache_permute_indices,
                 w13_weight_scale[i].view(torch.uint8),
                 epilogue_tile_m,
@@ -476,7 +469,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 )
             )
             # w13 bias shuffling
-            permute_bias_indices = _maybe_get_cached_w2_permute_indices(
+            permute_bias_indices = get_w2_permute_indices_with_cache(
                 self._cache_permute_indices,
                 w13_bias[i].clone().reshape(-1, 1),
                 epilogue_tile_m,
@@ -488,7 +481,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 .contiguous()
             )
             # w2 weight shuffling
-            permute_indices = _maybe_get_cached_w2_permute_indices(
+            permute_indices = get_w2_permute_indices_with_cache(
                 self._cache_permute_indices,
                 w2_weight[i].view(torch.uint8),
                 epilogue_tile_m,
@@ -499,7 +492,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 .contiguous()
             )
             # w2 scale shuffling
-            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+            permute_sf_indices = get_w2_permute_indices_with_cache(
                 self._cache_permute_indices,
                 w2_weight_scale[i].view(torch.uint8),
                 epilogue_tile_m,
@@ -515,7 +508,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 )
             )
             # w2 bias shuffling
-            permute_indices = _maybe_get_cached_w2_permute_indices(
+            permute_indices = get_w2_permute_indices_with_cache(
                 self._cache_permute_indices,
                 w2_bias[i].clone().reshape(-1, 1),
                 epilogue_tile_m,
@@ -735,30 +728,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         else:
             raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")

-    def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
-        # Number of tokens in the input tensor.
-        num_tokens = x.shape[0]
-        # Factor to account for the imbalance of the experts.
-        # factor equals to the
-        # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
-        # - 1.0 means perfect expert distribution.
-        # - > 1.0 means some experts have more
-        # tokens than the perfect distribution.
-        # - < 1.0 does not make sense.
-        imbalance_factor = 1.3
-        # Calculate the number of tokens per expert
-        # assuming perfect distribution.
-        num_tokens_per_expert = (num_tokens * top_k) // self.num_experts
-        # Apply the imbalance factor.
-        num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
-        # And pad the number to the next power of 2.
-        tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
-        # Cap to 8-64 tokens per CTA tile
-        # as it's the range supported by the kernel.
-        tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-
-        return tile_tokens_dim
-
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
@@ -1037,7 +1006,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 layer.ep_rank * layer.local_num_experts,  # local_expert_offset
                 self.num_experts,  # local num experts
                 None,
-                self._get_tile_tokens_dim(x, top_k),
+                None,
                 1 if renormalize else 0,  # routing_method_type, renormalize
                 True,  # do finalize
                 tune_max_num_tokens=self.max_capture_size,