From 7b5575fa7dcf76ac86ab8d18501b9cc04f74f6bb Mon Sep 17 00:00:00 2001
From: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date: Fri, 5 Dec 2025 16:42:12 -0500
Subject: [PATCH] [Bug] Fix vLLM config is not set error (#29999)

Signed-off-by: yewentao256
---
 .../layers/fused_moe/cutlass_moe.py            |  2 +
 .../fused_moe/fused_moe_modular_method.py      |  6 ++
 .../layers/fused_moe/modular_kernel.py         | 57 ++++++++++---------
 .../compressed_tensors_moe.py                  |  3 +
 .../quantization/utils/flashinfer_utils.py     |  6 ++
 5 files changed, 47 insertions(+), 27 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
index 6753a19250b..30144ca5452 100644
--- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -460,6 +460,7 @@ def cutlass_moe_fp8(
     expert_map: torch.Tensor | None = None,
     apply_router_weight_on_input: bool = False,
     global_num_experts: int = -1,
+    parallel_config=None,
 ) -> torch.Tensor:
     """
     This function computes a a8w8-quantized Mixture of Experts (MoE) layer
@@ -537,6 +538,7 @@
             c_strides2=c_strides2,
             quant_config=quant_config,
         ),
+        parallel_config=parallel_config,
     )

     return fn(
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
index c23c41df226..b33e7fd8a02 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
@@ -44,6 +44,11 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
         prepare_finalize: FusedMoEPrepareAndFinalize,
         shared_experts: torch.nn.Module | None,
     ) -> "FusedMoEModularMethod":
+        parallel_config = getattr(
+            getattr(moe_layer, "vllm_config", None),
+            "parallel_config",
+            None,
+        )
         return FusedMoEModularMethod(
             old_quant_method,
             FusedMoEModularKernel(
@@ -51,6 +56,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
                 old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
                 shared_experts,
                 getattr(moe_layer, "shared_experts_stream", None),
+                parallel_config=parallel_config,
             ),
         )

diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py
index b2af58cdca8..51d3299e7dd 100644
--- a/vllm/model_executor/layers/fused_moe/modular_kernel.py
+++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -10,7 +10,7 @@ from typing import final
 import torch

 import vllm.envs as envs
-from vllm.config import get_current_vllm_config
+from vllm.config import ParallelConfig, get_current_vllm_config
 from vllm.forward_context import get_forward_context, is_forward_context_available
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
@@ -716,6 +716,7 @@ class FusedMoEModularKernel(torch.nn.Module):
         fused_experts: FusedMoEPermuteExpertsUnpermute,
         shared_experts: torch.nn.Module | None = None,
         shared_experts_stream: torch.cuda.Stream | None = None,
+        parallel_config: ParallelConfig | None = None,
     ):
         super().__init__()
         self.prepare_finalize = prepare_finalize
@@ -723,6 +724,14 @@ class FusedMoEModularKernel(torch.nn.Module):
         self.shared_experts = shared_experts
         self.shared_experts_stream = shared_experts_stream

+        # cache whether this worker is using DP+EP
+        if parallel_config is None:
+            parallel_config = get_current_vllm_config().parallel_config
+        self.is_dp_ep = (
+            parallel_config.data_parallel_size > 1
+            and parallel_config.enable_expert_parallel
+        )
+
         self._post_init_setup()
         assert (
             prepare_finalize.activation_format == fused_experts.activation_formats[0]
@@ -811,33 +820,27 @@ class FusedMoEModularKernel(torch.nn.Module):
             is_forward_context_available()
             and get_forward_context().attn_metadata is None
         )
-        if is_profile_run and self.fused_experts.supports_chunking():
-            parallel_config = get_current_vllm_config().parallel_config
-            is_dp_ep = (
-                parallel_config.data_parallel_size > 1
-                and parallel_config.enable_expert_parallel
+        if is_profile_run and self.fused_experts.supports_chunking() and self.is_dp_ep:
+            max_workspace_13, max_workspace_2, max_fused_out_shape = (
+                self.fused_experts.workspace_shapes(
+                    envs.VLLM_FUSED_MOE_CHUNK_SIZE,
+                    N,
+                    K,
+                    top_k,
+                    global_num_experts,
+                    local_num_experts,
+                    expert_tokens_meta,
+                )
+            )
+            buffers.workspace13.get(
+                max_workspace_13, device=device, dtype=workspace_dtype
+            )
+            buffers.workspace2.get(
+                max_workspace_2, device=device, dtype=workspace_dtype
+            )
+            buffers.fused_out.get(
+                max_fused_out_shape, device=device, dtype=workspace_dtype
             )
-            if is_dp_ep:
-                max_workspace_13, max_workspace_2, max_fused_out_shape = (
-                    self.fused_experts.workspace_shapes(
-                        envs.VLLM_FUSED_MOE_CHUNK_SIZE,
-                        N,
-                        K,
-                        top_k,
-                        global_num_experts,
-                        local_num_experts,
-                        expert_tokens_meta,
-                    )
-                )
-                buffers.workspace13.get(
-                    max_workspace_13, device=device, dtype=workspace_dtype
-                )
-                buffers.workspace2.get(
-                    max_workspace_2, device=device, dtype=workspace_dtype
-                )
-                buffers.fused_out.get(
-                    max_fused_out_shape, device=device, dtype=workspace_dtype
-                )

         # Get intermediate workspace shapes based off the chunked M size.
         workspace13_shape, workspace2_shape, _ = self.fused_experts.workspace_shapes(
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index d7fb6d2ca36..8013b29f733 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -1287,6 +1287,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                 ab_strides2=self.ab_strides2,
                 c_strides1=self.c_strides1,
                 c_strides2=self.ab_strides1_c_strides2,
+                parallel_config=getattr(
+                    getattr(layer, "vllm_config", None), "parallel_config", None
+                ),
             )
         else:
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
index eef7a0896c3..00c2720a348 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
@@ -247,6 +247,11 @@ def flashinfer_cutlass_moe_fp8(
     assert quant_config is not None
     # Construct modular kernel with block-scale support when requested.
+    parallel_config = getattr(
+        getattr(layer, "vllm_config", None),
+        "parallel_config",
+        None,
+    )
     fused_experts = mk.FusedMoEModularKernel(
         build_flashinfer_fp8_cutlass_moe_prepare_finalize(
             moe=moe, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
         ),
@@ -257,6 +262,7 @@
             out_dtype=hidden_states.dtype,
             use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
         ),
+        parallel_config=parallel_config,
     )

     return fused_experts(
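
Note: the pattern this patch applies everywhere is "resolve the config while the config
context is still active, and cache what the forward path needs." What follows is a
minimal, self-contained sketch of that pattern in plain Python. It is not vLLM code:
ParallelConfig, set_current_config, get_current_config, and Kernel below are simplified
stand-ins invented for this sketch, mirroring vLLM's set_current_vllm_config() /
get_current_vllm_config() machinery and FusedMoEModularKernel's new constructor argument.

# Minimal sketch (not vLLM code): inject the config at construction time, fall
# back to the ambient config only when nothing was injected, and cache the
# derived DP+EP flag so forward-time code never reads the global config.
from __future__ import annotations

from contextlib import contextmanager
from dataclasses import dataclass


@dataclass
class ParallelConfig:
    # Hypothetical, simplified stand-in for vLLM's ParallelConfig.
    data_parallel_size: int = 1
    enable_expert_parallel: bool = False


_current_config: ParallelConfig | None = None


@contextmanager
def set_current_config(config: ParallelConfig):
    """Stand-in for vLLM's set_current_vllm_config() context manager."""
    global _current_config
    _current_config = config
    try:
        yield
    finally:
        _current_config = None


def get_current_config() -> ParallelConfig:
    """Stand-in for get_current_vllm_config(): raises outside the context."""
    if _current_config is None:
        raise RuntimeError("config is not set")  # the failure mode the patch fixes
    return _current_config


class Kernel:
    """Stand-in for FusedMoEModularKernel's config handling."""

    def __init__(self, parallel_config: ParallelConfig | None = None):
        # Prefer the explicitly injected config; fall back to the ambient one,
        # mirroring the `if parallel_config is None` branch in the patch.
        if parallel_config is None:
            parallel_config = get_current_config()
        # Cache whether this worker uses DP+EP so forward() stays config-free.
        self.is_dp_ep = (
            parallel_config.data_parallel_size > 1
            and parallel_config.enable_expert_parallel
        )

    def forward(self) -> bool:
        return self.is_dp_ep  # no global config lookup at forward time


# Construction happens inside the config context; forward() may run after the
# context has exited, which is exactly the situation that used to raise.
cfg = ParallelConfig(data_parallel_size=2, enable_expert_parallel=True)
with set_current_config(cfg):
    kernel = Kernel()
assert kernel.forward() is True

The design point: any code path that can execute outside the config context, such as a
worker's forward or profile run, must receive its config explicitly instead of reading
the ambient global. That is why the patch threads parallel_config through every
FusedMoEModularKernel construction site and caches self.is_dp_ep in __init__.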