diff --git a/.gitignore b/.gitignore index 50070d7898f..7cda8647866 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,9 @@ # vllm-flash-attn built from source vllm/vllm_flash_attn/* +# OpenAI triton kernels copied from source +vllm/third_party/triton_kernels/* + # triton jit .triton diff --git a/CMakeLists.txt b/CMakeLists.txt index c1c7478b9f3..ae8e6175443 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1030,6 +1030,11 @@ if(VLLM_GPU_LANG STREQUAL "HIP") WITH_SOABI) endif() +# For CUDA and HIP builds also build the triton_kernels external package. +if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") + include(cmake/external_projects/triton_kernels.cmake) +endif() + # For CUDA we also build and ship some external projects. if (VLLM_GPU_LANG STREQUAL "CUDA") include(cmake/external_projects/flashmla.cmake) diff --git a/cmake/external_projects/triton_kernels.cmake b/cmake/external_projects/triton_kernels.cmake new file mode 100644 index 00000000000..d35ad123dd9 --- /dev/null +++ b/cmake/external_projects/triton_kernels.cmake @@ -0,0 +1,53 @@ +# Install OpenAI triton_kernels from https://github.com/triton-lang/triton/tree/main/python/triton_kernels + +set(DEFAULT_TRITON_KERNELS_TAG "v3.5.0") + +# Set TRITON_KERNELS_SRC_DIR for use with local development with vLLM. We expect TRITON_KERNELS_SRC_DIR to +# be directly set to the triton_kernels python directory. +if (DEFINED ENV{TRITON_KERNELS_SRC_DIR}) + message(STATUS "[triton_kernels] Fetch from $ENV{TRITON_KERNELS_SRC_DIR}") + FetchContent_Declare( + triton_kernels + SOURCE_DIR $ENV{TRITON_KERNELS_SRC_DIR} + ) + +else() + set(TRITON_GIT "https://github.com/triton-lang/triton.git") + message (STATUS "[triton_kernels] Fetch from ${TRITON_GIT}:${DEFAULT_TRITON_KERNELS_TAG}") + FetchContent_Declare( + triton_kernels + # TODO (varun) : Fetch just the triton_kernels directory from Triton + GIT_REPOSITORY https://github.com/triton-lang/triton.git + GIT_TAG ${DEFAULT_TRITON_KERNELS_TAG} + GIT_PROGRESS TRUE + SOURCE_SUBDIR python/triton_kernels/triton_kernels + ) +endif() + +# Fetch content +FetchContent_MakeAvailable(triton_kernels) + +if (NOT triton_kernels_SOURCE_DIR) + message (FATAL_ERROR "[triton_kernels] Cannot resolve triton_kernels_SOURCE_DIR") +endif() + +if (DEFINED ENV{TRITON_KERNELS_SRC_DIR}) + set(TRITON_KERNELS_PYTHON_DIR "${triton_kernels_SOURCE_DIR}/") +else() + set(TRITON_KERNELS_PYTHON_DIR "${triton_kernels_SOURCE_DIR}/python/triton_kernels/triton_kernels/") +endif() + +message (STATUS "[triton_kernels] triton_kernels is available at ${TRITON_KERNELS_PYTHON_DIR}") + +add_custom_target(triton_kernels) + +# Ensure the vllm/third_party directory exists before installation +install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/third_party/triton_kernels\")") + +## Copy .py files to install directory. +install(DIRECTORY + ${TRITON_KERNELS_PYTHON_DIR} + DESTINATION + vllm/third_party/triton_kernels/ + COMPONENT triton_kernels + FILES_MATCHING PATTERN "*.py") diff --git a/setup.py b/setup.py index e9b36e2a2e0..5591bcb1324 100644 --- a/setup.py +++ b/setup.py @@ -299,6 +299,20 @@ class cmake_build_ext(build_ext): os.makedirs(os.path.dirname(dst_file), exist_ok=True) self.copy_file(file, dst_file) + if _is_cuda() or _is_hip(): + # copy vllm/third_party/triton_kernels/**/*.py from self.build_lib + # to current directory so that they can be included in the editable + # build + print( + f"Copying {self.build_lib}/vllm/third_party/triton_kernels " + "to vllm/third_party/triton_kernels" + ) + shutil.copytree( + f"{self.build_lib}/vllm/third_party/triton_kernels", + "vllm/third_party/triton_kernels", + dirs_exist_ok=True, + ) + class precompiled_build_ext(build_ext): """Disables extension building when using precompiled binaries.""" @@ -633,6 +647,9 @@ ext_modules = [] if _is_cuda() or _is_hip(): ext_modules.append(CMakeExtension(name="vllm._moe_C")) ext_modules.append(CMakeExtension(name="vllm.cumem_allocator")) + # Optional since this doesn't get built (produce an .so file). This is just + # copying the relevant .py files from the source repository. + ext_modules.append(CMakeExtension(name="vllm.triton_kernels", optional=True)) if _is_hip(): ext_modules.append(CMakeExtension(name="vllm._rocm_C")) diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py index 34a31bcf6a7..cbc46810a26 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -8,6 +8,7 @@ import torch from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.triton_utils import triton +from vllm.utils.import_utils import has_triton_kernels from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer logger = init_logger(__name__) @@ -15,6 +16,7 @@ logger = init_logger(__name__) def _swizzle_mxfp4(quant_tensor, scale, num_warps): """weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel""" + assert has_triton_kernels() import triton_kernels.matmul_ogs_details.opt_flags as opt_flags from triton_kernels.numerics import InFlexData from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor diff --git a/vllm/utils/import_utils.py b/vllm/utils/import_utils.py index f01d2c7a6a3..ff0f0350fd9 100644 --- a/vllm/utils/import_utils.py +++ b/vllm/utils/import_utils.py @@ -18,6 +18,10 @@ from typing import Any import regex as re from typing_extensions import Never +from vllm.logger import init_logger + +logger = init_logger(__name__) + # TODO: This function can be removed if transformer_modules classes are # serialized by value when communicating between processes @@ -62,6 +66,35 @@ def import_pynvml(): return pynvml +@cache +def import_triton_kernels(): + """ + For convenience, prioritize triton_kernels that is available in + `site-packages`. Use `vllm.third_party.triton_kernels` as a fall-back. + """ + if _has_module("triton_kernels"): + import triton_kernels + + logger.debug_once( + f"Loading module triton_kernels from {triton_kernels.__file__}.", + scope="local", + ) + elif _has_module("vllm.third_party.triton_kernels"): + import vllm.third_party.triton_kernels as triton_kernels + + logger.debug_once( + f"Loading module triton_kernels from {triton_kernels.__file__}.", + scope="local", + ) + sys.modules["triton_kernels"] = triton_kernels + else: + logger.info_once( + "triton_kernels unavailable in this build. " + "Please consider installing triton_kernels from " + "https://github.com/triton-lang/triton/tree/main/python/triton_kernels" + ) + + def import_from_path(module_name: str, file_path: str | os.PathLike): """ Import a Python file according to its file path. @@ -397,7 +430,12 @@ def has_deep_gemm() -> bool: def has_triton_kernels() -> bool: """Whether the optional `triton_kernels` package is available.""" - return _has_module("triton_kernels") + is_available = _has_module("triton_kernels") or _has_module( + "vllm.third_party.triton_kernels" + ) + if is_available: + import_triton_kernels() + return is_available def has_tilelang() -> bool: