Mirror of https://github.com/vllm-project/vllm.git (synced 2025-12-06 06:53:12 +08:00)
[Bug] Fix vLLM config is not set error (#29999)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
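
The change threads a ParallelConfig handle into the MoE modular-kernel construction path and caches the derived DP+EP flag in the constructor, so the forward pass no longer has to call get_current_vllm_config() at a point where no config context is active. A minimal standalone sketch of that pattern, using hypothetical stand-in classes rather than vLLM's actual API:

from dataclasses import dataclass


@dataclass
class ParallelConfigSketch:
    # Stand-in for vllm.config.ParallelConfig; only the two fields the fix
    # consults are modeled here.
    data_parallel_size: int = 1
    enable_expert_parallel: bool = False


class ModularKernelSketch:
    def __init__(self, parallel_config: ParallelConfigSketch | None = None) -> None:
        # Resolve the config once, at construction time, and cache the
        # derived flag instead of re-reading a global config later.
        if parallel_config is None:
            parallel_config = ParallelConfigSketch()
        self.is_dp_ep = (
            parallel_config.data_parallel_size > 1
            and parallel_config.enable_expert_parallel
        )

    def forward(self, is_profile_run: bool) -> bool:
        # The profile-run branch only consults the cached flag; no global
        # config lookup happens on the hot path.
        return is_profile_run and self.is_dp_ep


kernel = ModularKernelSketch(
    ParallelConfigSketch(data_parallel_size=2, enable_expert_parallel=True)
)
assert kernel.forward(is_profile_run=True)

The hunks below apply the same idea to FusedMoEModularKernel and its call sites.
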
@@ -460,6 +460,7 @@ def cutlass_moe_fp8(
     expert_map: torch.Tensor | None = None,
     apply_router_weight_on_input: bool = False,
     global_num_experts: int = -1,
+    parallel_config=None,
 ) -> torch.Tensor:
     """
     This function computes a a8w8-quantized Mixture of Experts (MoE) layer
@@ -537,6 +538,7 @@ def cutlass_moe_fp8(
             c_strides2=c_strides2,
             quant_config=quant_config,
         ),
+        parallel_config=parallel_config,
     )

     return fn(

@@ -44,6 +44,11 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
         prepare_finalize: FusedMoEPrepareAndFinalize,
         shared_experts: torch.nn.Module | None,
     ) -> "FusedMoEModularMethod":
+        parallel_config = getattr(
+            getattr(moe_layer, "vllm_config", None),
+            "parallel_config",
+            None,
+        )
         return FusedMoEModularMethod(
             old_quant_method,
             FusedMoEModularKernel(
@@ -51,6 +56,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
                 old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
                 shared_experts,
                 getattr(moe_layer, "shared_experts_stream", None),
+                parallel_config=parallel_config,
             ),
         )

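The nested getattr in the hunk above is a best-effort fallback: when the layer has a vllm_config attached, its parallel_config is forwarded; otherwise the expression collapses to None and the kernel constructor falls back to get_current_vllm_config(). A small sketch of that lookup, with hypothetical stub objects:

class VllmConfigStub:
    # Hypothetical stand-in holding only the attribute the lookup needs.
    def __init__(self, parallel_config: str) -> None:
        self.parallel_config = parallel_config


class LayerStub:
    # Hypothetical MoE layer; vllm_config may or may not be attached.
    def __init__(self, vllm_config: VllmConfigStub | None = None) -> None:
        if vllm_config is not None:
            self.vllm_config = vllm_config


def resolve_parallel_config(layer: object):
    # Same getattr chain as in the diff: never raises, returns None when
    # either attribute is missing.
    return getattr(getattr(layer, "vllm_config", None), "parallel_config", None)


assert resolve_parallel_config(LayerStub()) is None
assert resolve_parallel_config(LayerStub(VllmConfigStub("pc"))) == "pc"

Because the lookup never raises and the new argument defaults to None, call sites that do not carry a vllm_config keep working unchanged.
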
@@ -10,7 +10,7 @@ from typing import final
 import torch

 import vllm.envs as envs
-from vllm.config import get_current_vllm_config
+from vllm.config import ParallelConfig, get_current_vllm_config
 from vllm.forward_context import get_forward_context, is_forward_context_available
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
@@ -716,6 +716,7 @@ class FusedMoEModularKernel(torch.nn.Module):
         fused_experts: FusedMoEPermuteExpertsUnpermute,
         shared_experts: torch.nn.Module | None = None,
         shared_experts_stream: torch.cuda.Stream | None = None,
+        parallel_config: ParallelConfig | None = None,
     ):
         super().__init__()
         self.prepare_finalize = prepare_finalize
@@ -723,6 +724,14 @@ class FusedMoEModularKernel(torch.nn.Module):
         self.shared_experts = shared_experts
         self.shared_experts_stream = shared_experts_stream

+        # cache whether this worker is using DP+EP
+        if parallel_config is None:
+            parallel_config = get_current_vllm_config().parallel_config
+        self.is_dp_ep = (
+            parallel_config.data_parallel_size > 1
+            and parallel_config.enable_expert_parallel
+        )
+
         self._post_init_setup()
         assert (
             prepare_finalize.activation_format == fused_experts.activation_formats[0]
@@ -811,13 +820,7 @@ class FusedMoEModularKernel(torch.nn.Module):
             is_forward_context_available()
             and get_forward_context().attn_metadata is None
         )
-        if is_profile_run and self.fused_experts.supports_chunking():
-            parallel_config = get_current_vllm_config().parallel_config
-            is_dp_ep = (
-                parallel_config.data_parallel_size > 1
-                and parallel_config.enable_expert_parallel
-            )
-            if is_dp_ep:
+        if is_profile_run and self.fused_experts.supports_chunking() and self.is_dp_ep:
             max_workspace_13, max_workspace_2, max_fused_out_shape = (
                 self.fused_experts.workspace_shapes(
                     envs.VLLM_FUSED_MOE_CHUNK_SIZE,

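The block removed above recomputed the DP+EP check on every profile run by reading the global config, which appears to be the source of the "vLLM config is not set" error when no config context is active during the forward pass; the replacement reuses the flag cached in __init__. A toy reproduction of the difference, with a hypothetical global-config holder standing in for vLLM's real config context:

from contextlib import contextmanager

_CURRENT_CONFIG: dict | None = None  # toy stand-in for the global vLLM config


@contextmanager
def set_current_config(config: dict):
    # The config is only available inside this context (e.g. during init).
    global _CURRENT_CONFIG
    _CURRENT_CONFIG = config
    try:
        yield
    finally:
        _CURRENT_CONFIG = None


def get_current_config() -> dict:
    if _CURRENT_CONFIG is None:
        raise RuntimeError("vLLM config is not set")
    return _CURRENT_CONFIG


class KernelSketch:
    def __init__(self) -> None:
        # New behavior: read the config while the context is still active
        # and cache the derived flag.
        cfg = get_current_config()
        self.is_dp_ep = cfg["dp_size"] > 1 and cfg["enable_ep"]

    def forward_old(self) -> bool:
        # Old behavior: lazy lookup during forward, which fails once the
        # config context has been torn down.
        cfg = get_current_config()
        return cfg["dp_size"] > 1 and cfg["enable_ep"]

    def forward_new(self) -> bool:
        return self.is_dp_ep


with set_current_config({"dp_size": 2, "enable_ep": True}):
    kernel = KernelSketch()

assert kernel.forward_new()  # the cached flag still works outside the context
try:
    kernel.forward_old()
except RuntimeError as err:
    assert "config is not set" in str(err)
else:
    raise AssertionError("expected the lazy lookup to fail outside the context")
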
@@ -1287,6 +1287,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                 ab_strides2=self.ab_strides2,
                 c_strides1=self.c_strides1,
                 c_strides2=self.ab_strides1_c_strides2,
+                parallel_config=getattr(
+                    getattr(layer, "vllm_config", None), "parallel_config", None
+                ),
             )

         else:

@@ -247,6 +247,11 @@ def flashinfer_cutlass_moe_fp8(
     assert quant_config is not None

     # Construct modular kernel with block-scale support when requested.
+    parallel_config = getattr(
+        getattr(layer, "vllm_config", None),
+        "parallel_config",
+        None,
+    )
     fused_experts = mk.FusedMoEModularKernel(
         build_flashinfer_fp8_cutlass_moe_prepare_finalize(
             moe=moe, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
@@ -257,6 +262,7 @@ def flashinfer_cutlass_moe_fp8(
             out_dtype=hidden_states.dtype,
             use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
         ),
+        parallel_config=parallel_config,
     )

     return fused_experts(