[refactor] CTMoEMethods to use QuantizationArgs (#28871)

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
HDCharles
2025-12-03 06:00:56 -05:00
committed by GitHub
parent 787b84a9fc
commit b294e28db2
2 changed files with 86 additions and 75 deletions

View File

@@ -767,8 +767,10 @@ class CompressedTensorsConfig(QuantizationConfig):
targets=self.target_scheme_map.keys(),
fused_mapping=self.packed_modules_mapping,
)
return self.target_scheme_map[matched_target]
scheme_dict = self.target_scheme_map[matched_target]
if scheme_dict.get("format") is None:
scheme_dict["format"] = self.quant_format
return scheme_dict
return None

View File

@@ -7,7 +7,11 @@ from enum import Enum
import torch
from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import ActivationOrdering, QuantizationStrategy
from compressed_tensors.quantization import (
ActivationOrdering,
QuantizationArgs,
QuantizationStrategy,
)
from torch.nn.parameter import Parameter
import vllm.envs as envs
@@ -142,10 +146,26 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
# are supported + check if the layer is being ignored.
weight_quant = scheme_dict.get("weights")
input_quant = scheme_dict.get("input_activations")
format = scheme_dict.get("format")
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
# group_size=None means channelwise
group_size = weight_quant.group_size or -1
valid_format_and_bits = (
weight_quant.num_bits in WNA16_SUPPORTED_BITS
and format == CompressionFormat.pack_quantized.value
)
if not valid_format_and_bits:
raise ValueError(
"For Fused MoE layers, only format: ",
f"{CompressionFormat.pack_quantized.value} ",
f" and bits: {WNA16_SUPPORTED_BITS} is supported ",
f"but got format: {CompressionFormat.pack_quantized.value} "
f" and bits: {weight_quant.num_bits}",
)
# Prefer to use the MarlinMoE kernel when it is supported.
if (
not check_moe_marlin_supports_layer(layer, group_size)
@@ -161,12 +181,12 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
)
logger.info_once("Using CompressedTensorsWNA16MoEMethod")
return CompressedTensorsWNA16MoEMethod(
quant_config, layer.moe_config, layer_name
weight_quant, input_quant, layer.moe_config
)
else:
logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
return CompressedTensorsWNA16MarlinMoEMethod(
quant_config, layer.moe_config, layer_name
weight_quant, input_quant, layer.moe_config
)
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
return CompressedTensorsW4A4Nvfp4MoEMethod(layer.moe_config, layer_name)
@@ -176,15 +196,15 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
or quant_config._is_fp8_w8a8(weight_quant, input_quant)
):
return CompressedTensorsW8A8Fp8MoEMethod(
quant_config, layer.moe_config, layer_name
weight_quant, input_quant, layer.moe_config
)
elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8MoEMethod(
quant_config, layer.moe_config, layer_name
weight_quant, input_quant, layer.moe_config
)
elif quant_config._is_dynamic_token_w4a8_int(weight_quant, input_quant):
return CompressedTensorsW4A8Int8MoEMethod(
quant_config, layer.moe_config, layer_name
weight_quant, input_quant, layer.moe_config
)
else:
raise RuntimeError(
@@ -650,17 +670,19 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
weight_quant: QuantizationArgs,
input_quant: QuantizationArgs,
moe: FusedMoEConfig,
layer_name: str | None = None,
):
super().__init__(moe)
self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
"input_activations"
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsConfig,
)
super().__init__(moe)
self.weight_quant = weight_quant
self.input_quant = input_quant
per_tensor = (
self.weight_quant.strategy == QuantizationStrategy.TENSOR
and self.input_quant.strategy == QuantizationStrategy.TENSOR
@@ -698,11 +720,13 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
# cutlass path
self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100(
self.is_fp8_w8a8_sm100 = CompressedTensorsConfig._is_fp8_w8a8_sm100(
self.weight_quant, self.input_quant
)
self.use_cutlass = not self.block_quant and (
quant_config._is_fp8_w8a8_sm90(self.weight_quant, self.input_quant)
CompressedTensorsConfig._is_fp8_w8a8_sm90(
self.weight_quant, self.input_quant
)
or self.is_fp8_w8a8_sm100
)
self.disable_expert_map = False
@@ -1261,16 +1285,14 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
weight_quant: QuantizationArgs,
input_quant: QuantizationArgs,
moe: FusedMoEConfig,
layer_name: str | None = None,
):
super().__init__(moe)
self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
"input_activations"
)
self.weight_quant = weight_quant
self.input_quant = input_quant
per_channel = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL
@@ -1414,36 +1436,27 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
weight_quant: QuantizationArgs,
input_quant: QuantizationArgs | None,
moe: FusedMoEConfig,
layer_name: str | None = None,
):
super().__init__(moe)
self.quant_config = quant_config
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
config = self.quant_config.target_scheme_map["Linear"].get("weights")
self.num_bits = config.num_bits
self.packed_factor = 32 // config.num_bits
self.strategy = config.strategy
self.group_size = config.group_size
self.actorder = config.actorder
self.layer_name = layer_name
self.marlin_input_dtype = get_marlin_input_dtype(layer_name)
assert config.symmetric, "Only symmetric quantization is supported for MoE"
self.weight_quant = weight_quant
self.input_quant = input_quant
assert weight_quant.symmetric, (
"Only symmetric quantization is supported for MoE"
)
# Extract properties from weight_quant
self.num_bits = weight_quant.num_bits
self.packed_factor = 32 // weight_quant.num_bits
self.strategy = weight_quant.strategy
self.group_size = weight_quant.group_size
self.actorder = weight_quant.actorder
if not (
self.quant_config.quant_format == CompressionFormat.pack_quantized.value
and self.num_bits in WNA16_SUPPORTED_BITS
):
raise ValueError(
"For Fused MoE layers, only ",
f"{CompressionFormat.pack_quantized.value} ",
"is supported for the following bits: ",
f"{WNA16_SUPPORTED_BITS}",
)
self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits]
self.use_marlin = True
self.marlin_input_dtype = get_marlin_input_dtype(layer_name)
def create_weights(
self,
@@ -1812,35 +1825,26 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
weight_quant: QuantizationArgs,
input_quant: QuantizationArgs | None,
moe: FusedMoEConfig,
layer_name: str | None = None,
):
super().__init__(moe)
self.quant_config = quant_config
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
config = self.quant_config.target_scheme_map["Linear"].get("weights")
self.num_bits = config.num_bits
self.packed_factor = 32 // config.num_bits
self.strategy = config.strategy
self.weight_quant = weight_quant
self.input_quant = input_quant
# Extract properties from weight_quant
self.num_bits = weight_quant.num_bits
self.packed_factor = 32 // weight_quant.num_bits
self.strategy = weight_quant.strategy
# channelwise is not supported by this kernel
assert config.strategy == "group"
self.group_size = config.group_size
assert weight_quant.strategy == "group"
self.group_size = weight_quant.group_size
# grouped actorder isn't supported by this kernel
assert config.actorder != "group"
assert config.symmetric, "Only symmetric quantization is supported for MoE"
if not (
self.quant_config.quant_format == CompressionFormat.pack_quantized.value
and self.num_bits in WNA16_SUPPORTED_BITS
):
raise ValueError(
"For Fused MoE layers, only ",
f"{CompressionFormat.pack_quantized.value} ",
"is supported for the following bits: ",
f"{WNA16_SUPPORTED_BITS}",
)
assert weight_quant.actorder != "group"
assert weight_quant.symmetric, (
"Only symmetric quantization is supported for MoE"
)
def create_weights(
self,
@@ -2065,28 +2069,33 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
weight_quant: QuantizationArgs,
input_quant: QuantizationArgs,
moe: FusedMoEConfig,
layer_name: str | None = None,
):
super().__init__(moe)
self.has_bias = self.moe.has_bias
self.quant_config = quant_config
self.weight_quant = weight_quant
self.input_quant = input_quant
# Validate scheme: weights=W4 (channel or group),
# activations=dynamic TOKEN (A8)
wq = self.quant_config.target_scheme_map["Linear"].get("weights")
aq = self.quant_config.target_scheme_map["Linear"].get("input_activations")
# Must be dynamic per-token activations
if aq.strategy != QuantizationStrategy.TOKEN or not aq.dynamic:
if (
input_quant.strategy != QuantizationStrategy.TOKEN
or not input_quant.dynamic
):
raise ValueError(
"W4A8-int MoE needs dynamic per-token activation quantization."
)
# Weight can be channel-wise (group_size=None) or group-wise
self.group_size = wq.group_size if (wq.group_size is not None) else -1
if wq.num_bits != 4:
self.group_size = (
weight_quant.group_size if (weight_quant.group_size is not None) else -1
)
if weight_quant.num_bits != 4:
raise ValueError("This method only supports 4-bit weights (num_bits=4).")
# CPU only