[Bug] Fix "vLLM config is not set" error (#29999)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Author: Wentao Ye
Date: 2025-12-05 16:42:12 -05:00
Committed by: GitHub
Parent: 77e4472809
Commit: 7b5575fa7d
5 changed files with 47 additions and 27 deletions

View File

@@ -460,6 +460,7 @@ def cutlass_moe_fp8(
     expert_map: torch.Tensor | None = None,
     apply_router_weight_on_input: bool = False,
     global_num_experts: int = -1,
+    parallel_config=None,
 ) -> torch.Tensor:
     """
     This function computes a a8w8-quantized Mixture of Experts (MoE) layer
@@ -537,6 +538,7 @@ def cutlass_moe_fp8(
             c_strides2=c_strides2,
             quant_config=quant_config,
         ),
+        parallel_config=parallel_config,
     )

     return fn(

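For context, the "vLLM config is not set" error in the commit title arises when configuration is resolved through a process-global lookup at a point where no config context is active; threading parallel_config down from the caller, as this file now does, avoids that lookup. Below is a standalone sketch of the pattern using hypothetical names (set_current_config, Kernel), not vLLM's actual API:

    # Standalone illustration (hypothetical names): a global config that is only
    # valid inside a context, and a kernel that either accepts the config
    # explicitly or resolves it once at construction time.
    from contextlib import contextmanager

    _CURRENT_CONFIG = None

    @contextmanager
    def set_current_config(cfg):
        global _CURRENT_CONFIG
        _CURRENT_CONFIG = cfg
        try:
            yield
        finally:
            _CURRENT_CONFIG = None

    def get_current_config():
        if _CURRENT_CONFIG is None:
            raise RuntimeError("config is not set")  # analogous failure mode
        return _CURRENT_CONFIG

    class Kernel:
        def __init__(self, parallel_config=None):
            # Resolve once, while a config context may still be active.
            self.parallel_config = (
                parallel_config
                if parallel_config is not None
                else get_current_config()
            )

    with set_current_config({"data_parallel_size": 2}):
        Kernel()                                        # falls back to the ambient config
    Kernel(parallel_config={"data_parallel_size": 2})   # explicit: no global lookup needed
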
View File

@@ -44,6 +44,11 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
         prepare_finalize: FusedMoEPrepareAndFinalize,
         shared_experts: torch.nn.Module | None,
     ) -> "FusedMoEModularMethod":
+        parallel_config = getattr(
+            getattr(moe_layer, "vllm_config", None),
+            "parallel_config",
+            None,
+        )
         return FusedMoEModularMethod(
             old_quant_method,
             FusedMoEModularKernel(
@@ -51,6 +56,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
                 old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
                 shared_experts,
                 getattr(moe_layer, "shared_experts_stream", None),
+                parallel_config=parallel_config,
             ),
         )

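The nested getattr used above returns None both when the layer has no vllm_config attribute and when that attribute itself is None, so layers built without a full vLLM config (for example in isolated tests) still work and the kernel falls back internally. A minimal, self-contained check of that lookup; the SimpleNamespace objects are illustrative stand-ins, not vLLM types:

    from types import SimpleNamespace

    def extract_parallel_config(layer):
        # Same lookup as in the diff: tolerate a missing vllm_config attribute.
        return getattr(getattr(layer, "vllm_config", None), "parallel_config", None)

    layer_with_config = SimpleNamespace(
        vllm_config=SimpleNamespace(
            parallel_config=SimpleNamespace(
                data_parallel_size=2, enable_expert_parallel=True
            )
        )
    )
    bare_layer = SimpleNamespace()  # e.g. a layer built without a vLLM config

    assert extract_parallel_config(layer_with_config).data_parallel_size == 2
    assert extract_parallel_config(bare_layer) is None  # kernel will fall back
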
View File

@@ -10,7 +10,7 @@ from typing import final
 import torch

 import vllm.envs as envs
-from vllm.config import get_current_vllm_config
+from vllm.config import ParallelConfig, get_current_vllm_config
 from vllm.forward_context import get_forward_context, is_forward_context_available
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
@@ -716,6 +716,7 @@ class FusedMoEModularKernel(torch.nn.Module):
         fused_experts: FusedMoEPermuteExpertsUnpermute,
         shared_experts: torch.nn.Module | None = None,
         shared_experts_stream: torch.cuda.Stream | None = None,
+        parallel_config: ParallelConfig | None = None,
     ):
         super().__init__()
         self.prepare_finalize = prepare_finalize
@@ -723,6 +724,14 @@ class FusedMoEModularKernel(torch.nn.Module):
         self.shared_experts = shared_experts
         self.shared_experts_stream = shared_experts_stream
+
+        # cache whether this worker is using DP+EP
+        if parallel_config is None:
+            parallel_config = get_current_vllm_config().parallel_config
+        self.is_dp_ep = (
+            parallel_config.data_parallel_size > 1
+            and parallel_config.enable_expert_parallel
+        )
         self._post_init_setup()

         assert (
             prepare_finalize.activation_format == fused_experts.activation_formats[0]
@@ -811,33 +820,27 @@
             is_forward_context_available()
             and get_forward_context().attn_metadata is None
         )
-        if is_profile_run and self.fused_experts.supports_chunking():
-            parallel_config = get_current_vllm_config().parallel_config
-            is_dp_ep = (
-                parallel_config.data_parallel_size > 1
-                and parallel_config.enable_expert_parallel
-            )
-            if is_dp_ep:
-                max_workspace_13, max_workspace_2, max_fused_out_shape = (
-                    self.fused_experts.workspace_shapes(
-                        envs.VLLM_FUSED_MOE_CHUNK_SIZE,
-                        N,
-                        K,
-                        top_k,
-                        global_num_experts,
-                        local_num_experts,
-                        expert_tokens_meta,
-                    )
-                )
-                buffers.workspace13.get(
-                    max_workspace_13, device=device, dtype=workspace_dtype
-                )
-                buffers.workspace2.get(
-                    max_workspace_2, device=device, dtype=workspace_dtype
-                )
-                buffers.fused_out.get(
-                    max_fused_out_shape, device=device, dtype=workspace_dtype
-                )
+        if is_profile_run and self.fused_experts.supports_chunking() and self.is_dp_ep:
+            max_workspace_13, max_workspace_2, max_fused_out_shape = (
+                self.fused_experts.workspace_shapes(
+                    envs.VLLM_FUSED_MOE_CHUNK_SIZE,
+                    N,
+                    K,
+                    top_k,
+                    global_num_experts,
+                    local_num_experts,
+                    expert_tokens_meta,
+                )
+            )
+            buffers.workspace13.get(
+                max_workspace_13, device=device, dtype=workspace_dtype
+            )
+            buffers.workspace2.get(
+                max_workspace_2, device=device, dtype=workspace_dtype
+            )
+            buffers.fused_out.get(
+                max_fused_out_shape, device=device, dtype=workspace_dtype
+            )

         # Get intermediate workspace shapes based off the chunked M size.
         workspace13_shape, workspace2_shape, _ = self.fused_experts.workspace_shapes(

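Net effect of this file's change: the DP+EP check is computed once in __init__ (from the explicit parallel_config when provided, otherwise from get_current_vllm_config()), and the profile-run workspace pre-allocation folds the cached self.is_dp_ep into its single if, which is equivalent to the old nested "if is_dp_ep:" but no longer reads the global config on the forward path. A standalone sketch of the cached flag; ParallelConfigStub is a simplified stand-in with just the two fields referenced in the diff:

    from dataclasses import dataclass

    @dataclass
    class ParallelConfigStub:  # simplified stand-in, not vllm.config.ParallelConfig
        data_parallel_size: int = 1
        enable_expert_parallel: bool = False

    def compute_is_dp_ep(cfg: ParallelConfigStub) -> bool:
        # Mirrors the flag cached in FusedMoEModularKernel.__init__.
        return cfg.data_parallel_size > 1 and cfg.enable_expert_parallel

    assert compute_is_dp_ep(ParallelConfigStub(4, True))
    assert not compute_is_dp_ep(ParallelConfigStub(4, False))  # EP disabled
    assert not compute_is_dp_ep(ParallelConfigStub(1, True))   # no data parallelism
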
View File

@@ -1287,6 +1287,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                 ab_strides2=self.ab_strides2,
                 c_strides1=self.c_strides1,
                 c_strides2=self.ab_strides1_c_strides2,
+                parallel_config=getattr(
+                    getattr(layer, "vllm_config", None), "parallel_config", None
+                ),
             )
         else:

View File

@@ -247,6 +247,11 @@ def flashinfer_cutlass_moe_fp8(
     assert quant_config is not None

     # Construct modular kernel with block-scale support when requested.
+    parallel_config = getattr(
+        getattr(layer, "vllm_config", None),
+        "parallel_config",
+        None,
+    )
     fused_experts = mk.FusedMoEModularKernel(
         build_flashinfer_fp8_cutlass_moe_prepare_finalize(
             moe=moe, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
@@ -257,6 +262,7 @@ def flashinfer_cutlass_moe_fp8(
             out_dtype=hidden_states.dtype,
             use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
         ),
+        parallel_config=parallel_config,
     )

     return fused_experts(