[NVIDIA] Add support for cudnn fp4 gemm via flashinfer (#26107)
Signed-off-by: kaixih <kaixih@nvidia.com> Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
17
vllm/envs.py
17
vllm/envs.py
@@ -191,6 +191,7 @@ if TYPE_CHECKING:
|
||||
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
|
||||
VLLM_ENABLE_RESPONSES_API_STORE: bool = False
|
||||
VLLM_USE_TRTLLM_ATTENTION: str | None = None
|
||||
VLLM_NVFP4_GEMM_BACKEND: str | None = None
|
||||
VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION: bool = False
|
||||
VLLM_HAS_FLASHINFER_CUBIN: bool = False
|
||||
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
|
||||
@@ -1292,11 +1293,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# If set, it means we pre-downloaded cubin files and flashinfer will
|
||||
# read the cubin files directly.
|
||||
"VLLM_HAS_FLASHINFER_CUBIN": lambda: os.getenv("VLLM_HAS_FLASHINFER_CUBIN", False),
|
||||
# If set to 1, force the use of TRTLLM FP4 GEMM backend in flashinfer.
|
||||
# Otherwise, uses the first available of: flashinfer cutlass GEMM,
|
||||
# vllm cutlass GEMM, marlin GEMM.
|
||||
"VLLM_USE_TRTLLM_FP4_GEMM": lambda: bool(
|
||||
int(os.getenv("VLLM_USE_TRTLLM_FP4_GEMM", "0"))
|
||||
# Supported options:
|
||||
# - "flashinfer-cudnn": use flashinfer cudnn GEMM backend
|
||||
# - "flashinfer-trtllm": use flashinfer trtllm GEMM backend
|
||||
# - "flashinfer-cutlass": use flashinfer cutlass GEMM backend
|
||||
# - <none>: automatically pick an available backend
|
||||
"VLLM_NVFP4_GEMM_BACKEND": env_with_choices(
|
||||
"VLLM_NVFP4_GEMM_BACKEND",
|
||||
None,
|
||||
["flashinfer-cudnn", "flashinfer-trtllm", "flashinfer-cutlass"],
|
||||
),
|
||||
# Controls garbage collection during CUDA graph capture.
|
||||
# If set to 0 (default), enables GC freezing to speed up capture time.
|
||||
@@ -1492,7 +1497,6 @@ def compute_hash() -> str:
|
||||
"VLLM_DISABLED_KERNELS",
|
||||
"VLLM_USE_DEEP_GEMM",
|
||||
"VLLM_USE_DEEP_GEMM_E8M0",
|
||||
"VLLM_USE_TRTLLM_FP4_GEMM",
|
||||
"VLLM_USE_FUSED_MOE_GROUPED_TOPK",
|
||||
"VLLM_USE_FLASHINFER_MOE_FP16",
|
||||
"VLLM_USE_FLASHINFER_MOE_FP8",
|
||||
@@ -1524,6 +1528,7 @@ def compute_hash() -> str:
|
||||
"VLLM_ROCM_FP8_MFMA_PAGE_ATTN",
|
||||
"VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE",
|
||||
"VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING",
|
||||
"VLLM_NVFP4_GEMM_BACKEND",
|
||||
"VLLM_USE_FBGEMM",
|
||||
]
|
||||
for key in environment_variables_to_hash:
|
||||
|
||||
@@ -14,7 +14,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
|
||||
run_nvfp4_emulations,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import swizzle_blockscale
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
cutlass_fp4_supported,
|
||||
swizzle_blockscale,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
GroupQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
@@ -29,10 +32,12 @@ __all__ = ["CompressedTensorsW4A4Fp4"]
|
||||
|
||||
class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
def __init__(self):
|
||||
if envs.VLLM_USE_TRTLLM_FP4_GEMM:
|
||||
assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
|
||||
self.backend = "flashinfer-trtllm"
|
||||
logger.info_once("Using flashinfer-trtllm for FP4")
|
||||
self.backend = "none"
|
||||
if envs.VLLM_NVFP4_GEMM_BACKEND is None:
|
||||
if has_flashinfer():
|
||||
self.backend = "flashinfer-cutlass"
|
||||
elif cutlass_fp4_supported():
|
||||
self.backend = "cutlass"
|
||||
elif envs.VLLM_USE_FBGEMM:
|
||||
self.backend = "fbgemm"
|
||||
try:
|
||||
@@ -42,12 +47,17 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
"Backend fbgemm requires fbgemm.f4f4bf16 operator, "
|
||||
"Please install with: pip install fbgemm-gpu-genai"
|
||||
) from exc
|
||||
logger.info_once("Using FGBEMM-GPU-GENAI for FP4")
|
||||
elif has_flashinfer():
|
||||
self.backend = "flashinfer-cutlass"
|
||||
logger.info_once("Using flashinfer-cutlass for FP4")
|
||||
else:
|
||||
self.backend = "cutlass"
|
||||
elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
|
||||
self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
|
||||
assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
|
||||
|
||||
if self.backend == "none":
|
||||
raise ValueError(
|
||||
"No valid NVFP4 GEMM backend found. "
|
||||
"Please check your platform capability."
|
||||
)
|
||||
|
||||
logger.info_once(f"Using {self.backend} for NVFP4 GEMM")
|
||||
self.group_size = 16
|
||||
|
||||
@classmethod
|
||||
@@ -184,10 +194,9 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
layer.alpha,
|
||||
output_dtype,
|
||||
)
|
||||
if self.backend == "flashinfer-trtllm":
|
||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
|
||||
elif self.backend == "flashinfer-cutlass":
|
||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
|
||||
if self.backend.startswith("flashinfer-"):
|
||||
backend_name = self.backend[len("flashinfer-") :]
|
||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
|
||||
elif self.backend == "fbgemm":
|
||||
out = torch.ops.fbgemm.f4f4bf16(
|
||||
x_fp4,
|
||||
@@ -198,6 +207,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
use_mx=False,
|
||||
).to(output_dtype)
|
||||
else:
|
||||
assert self.backend == "cutlass"
|
||||
out = cutlass_scaled_fp4_mm(*mm_args)
|
||||
|
||||
if bias is not None:
|
||||
|
||||
@@ -926,22 +926,26 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
|
||||
self.quant_config = quant_config
|
||||
|
||||
if envs.VLLM_USE_TRTLLM_FP4_GEMM:
|
||||
assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
|
||||
self.backend = "flashinfer-trtllm"
|
||||
elif has_flashinfer():
|
||||
self.backend = "flashinfer-cutlass"
|
||||
elif cutlass_fp4_supported():
|
||||
self.backend = "cutlass"
|
||||
elif is_fp4_marlin_supported():
|
||||
self.backend = "marlin"
|
||||
else:
|
||||
self.backend = "none"
|
||||
if envs.VLLM_NVFP4_GEMM_BACKEND is None:
|
||||
if has_flashinfer():
|
||||
self.backend = "flashinfer-cutlass"
|
||||
elif cutlass_fp4_supported():
|
||||
self.backend = "cutlass"
|
||||
elif is_fp4_marlin_supported():
|
||||
self.backend = "marlin"
|
||||
elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
|
||||
self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
|
||||
assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
|
||||
|
||||
if self.backend == "none":
|
||||
raise ValueError(
|
||||
"Current platform does not support NVFP4"
|
||||
" quantization. Please use Blackwell and"
|
||||
" above."
|
||||
"No valid NVFP4 GEMM backend found. "
|
||||
"Please check your platform capability."
|
||||
)
|
||||
|
||||
logger.info_once(f"Using {self.backend} for NVFP4 GEMM")
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -1109,11 +1113,11 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
layer.alpha,
|
||||
output_dtype,
|
||||
)
|
||||
if self.backend == "flashinfer-trtllm":
|
||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
|
||||
elif self.backend == "flashinfer-cutlass":
|
||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
|
||||
if self.backend.startswith("flashinfer-"):
|
||||
backend_name = self.backend[len("flashinfer-") :]
|
||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
|
||||
else:
|
||||
assert self.backend == "cutlass"
|
||||
out = cutlass_scaled_fp4_mm(*mm_args)
|
||||
|
||||
if bias is not None:
|
||||
|
||||
Reference in New Issue
Block a user