# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ctypes
import importlib.util
import json
import logging
import os
import re
import shutil
import subprocess
import sys
import sysconfig
from pathlib import Path
from shutil import which

import torch
from packaging.version import Version, parse
from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext
from setuptools_scm import get_version
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME


def load_module_from_path(module_name, path):
    spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module


ROOT_DIR = Path(__file__).parent
logger = logging.getLogger(__name__)

# cannot import envs directly because it depends on vllm,
# which is not installed yet
envs = load_module_from_path("envs", os.path.join(ROOT_DIR, "vllm", "envs.py"))

VLLM_TARGET_DEVICE = envs.VLLM_TARGET_DEVICE

if sys.platform.startswith("darwin") and VLLM_TARGET_DEVICE != "cpu":
    logger.warning("VLLM_TARGET_DEVICE automatically set to `cpu` due to macOS")
    VLLM_TARGET_DEVICE = "cpu"
elif not (sys.platform.startswith("linux") or sys.platform.startswith("darwin")):
    logger.warning(
        "vLLM only supports Linux (including WSL) and macOS. "
        "Building on %s, so vLLM may not be able to run correctly",
        sys.platform,
    )
    VLLM_TARGET_DEVICE = "empty"
elif (
    sys.platform.startswith("linux")
    and torch.version.cuda is None
    and os.getenv("VLLM_TARGET_DEVICE") is None
    and torch.version.hip is None
):
    # If neither CUDA nor HIP is available and VLLM_TARGET_DEVICE is not set,
    # fall back to cpu.
    VLLM_TARGET_DEVICE = "cpu"


def is_sccache_available() -> bool:
    return which("sccache") is not None and not bool(
        int(os.getenv("VLLM_DISABLE_SCCACHE", "0"))
    )


def is_ccache_available() -> bool:
    return which("ccache") is not None


def is_ninja_available() -> bool:
    return which("ninja") is not None


def is_freethreaded():
    return bool(sysconfig.get_config_var("Py_GIL_DISABLED"))


class CMakeExtension(Extension):
    def __init__(self, name: str, cmake_lists_dir: str = ".", **kwa) -> None:
        super().__init__(name, sources=[], py_limited_api=not is_freethreaded(), **kwa)
        self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)


class cmake_build_ext(build_ext):
    # A dict of extension directories that have been configured.
    did_config: dict[str, bool] = {}

    #
    # Determine number of compilation jobs and optionally nvcc compile threads.
    #
    def compute_num_jobs(self):
        # `num_jobs` is either the value of the MAX_JOBS environment variable
        # (if defined) or the number of CPUs available.
        num_jobs = envs.MAX_JOBS
        if num_jobs is not None:
            num_jobs = int(num_jobs)
            logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs)
        else:
            try:
                # os.sched_getaffinity() isn't universally available, so fall
                # back to os.cpu_count() if we get an error here.
                num_jobs = len(os.sched_getaffinity(0))
            except AttributeError:
                num_jobs = os.cpu_count()

        nvcc_threads = None
        if _is_cuda() and get_nvcc_cuda_version() >= Version("11.2"):
            # `nvcc_threads` is either the value of the NVCC_THREADS
            # environment variable (if defined) or 1.
            # When it is set, we reduce `num_jobs` to avoid
            # overloading the system.
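            # Example (hypothetical numbers): with MAX_JOBS unset on a 32-CPU
            # machine and NVCC_THREADS=4, num_jobs becomes max(1, 32 // 4) = 8
            # parallel jobs, each nvcc invocation using 4 compile threads.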
            nvcc_threads = envs.NVCC_THREADS
            if nvcc_threads is not None:
                nvcc_threads = int(nvcc_threads)
                logger.info(
                    "Using NVCC_THREADS=%d as the number of nvcc threads.",
                    nvcc_threads,
                )
            else:
                nvcc_threads = 1
            num_jobs = max(1, num_jobs // nvcc_threads)

        return num_jobs, nvcc_threads

    #
    # Perform cmake configuration for a single extension.
    #
    def configure(self, ext: CMakeExtension) -> None:
        # If we've already configured using the CMakeLists.txt for
        # this extension, exit early.
        if ext.cmake_lists_dir in cmake_build_ext.did_config:
            return

        cmake_build_ext.did_config[ext.cmake_lists_dir] = True

        # Select the build type.
        # Note: optimization level + debug info are set by the build type
        default_cfg = "Debug" if self.debug else "RelWithDebInfo"
        cfg = envs.CMAKE_BUILD_TYPE or default_cfg

        cmake_args = [
            "-DCMAKE_BUILD_TYPE={}".format(cfg),
            "-DVLLM_TARGET_DEVICE={}".format(VLLM_TARGET_DEVICE),
        ]

        verbose = envs.VERBOSE
        if verbose:
            cmake_args += ["-DCMAKE_VERBOSE_MAKEFILE=ON"]

        if is_sccache_available():
            cmake_args += [
                "-DCMAKE_C_COMPILER_LAUNCHER=sccache",
                "-DCMAKE_CXX_COMPILER_LAUNCHER=sccache",
                "-DCMAKE_CUDA_COMPILER_LAUNCHER=sccache",
                "-DCMAKE_HIP_COMPILER_LAUNCHER=sccache",
            ]
        elif is_ccache_available():
            cmake_args += [
                "-DCMAKE_C_COMPILER_LAUNCHER=ccache",
                "-DCMAKE_CXX_COMPILER_LAUNCHER=ccache",
                "-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache",
                "-DCMAKE_HIP_COMPILER_LAUNCHER=ccache",
            ]

        # Pass the python executable to cmake so it can find an exact
        # match.
        cmake_args += ["-DVLLM_PYTHON_EXECUTABLE={}".format(sys.executable)]

        # Pass the python path to cmake so it can reuse the build dependencies
        # on subsequent calls to python.
        cmake_args += ["-DVLLM_PYTHON_PATH={}".format(":".join(sys.path))]

        # Override the base directory for FetchContent downloads to $ROOT/.deps
        # This allows sharing dependencies between profiles,
        # and plays more nicely with sccache.
        # To override this, set the FETCHCONTENT_BASE_DIR environment variable.
        fc_base_dir = os.path.join(ROOT_DIR, ".deps")
        fc_base_dir = os.environ.get("FETCHCONTENT_BASE_DIR", fc_base_dir)
        cmake_args += ["-DFETCHCONTENT_BASE_DIR={}".format(fc_base_dir)]

        #
        # Setup parallelism and build tool
        #
        num_jobs, nvcc_threads = self.compute_num_jobs()
        if nvcc_threads:
            cmake_args += ["-DNVCC_THREADS={}".format(nvcc_threads)]

        if is_ninja_available():
            build_tool = ["-G", "Ninja"]
            cmake_args += [
                "-DCMAKE_JOB_POOL_COMPILE:STRING=compile",
                "-DCMAKE_JOB_POOLS:STRING=compile={}".format(num_jobs),
            ]
        else:
            # Default build tool to whatever cmake picks.
            build_tool = []

        # Make sure we use the nvcc from CUDA_HOME
        if _is_cuda():
            cmake_args += [f"-DCMAKE_CUDA_COMPILER={CUDA_HOME}/bin/nvcc"]
        elif _is_hip():
            cmake_args += [f"-DROCM_PATH={ROCM_HOME}"]

        other_cmake_args = os.environ.get("CMAKE_ARGS")
        if other_cmake_args:
            cmake_args += other_cmake_args.split()

        subprocess.check_call(
            ["cmake", ext.cmake_lists_dir, *build_tool, *cmake_args],
            cwd=self.build_temp,
        )

    def build_extensions(self) -> None:
        # Ensure that CMake is present and working
        try:
            subprocess.check_output(["cmake", "--version"])
        except OSError as e:
            raise RuntimeError("Cannot find CMake executable") from e

        # Create build directory if it does not exist.
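        # Note: configure() below runs cmake with cwd=self.build_temp, so the
        # directory has to exist before any extension is configured.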
        if not os.path.exists(self.build_temp):
            os.makedirs(self.build_temp)

        targets = []

        def target_name(s: str) -> str:
            return s.removeprefix("vllm.").removeprefix("vllm_flash_attn.")

        # Build all the extensions
        for ext in self.extensions:
            self.configure(ext)
            targets.append(target_name(ext.name))

        num_jobs, _ = self.compute_num_jobs()

        build_args = [
            "--build",
            ".",
            f"-j={num_jobs}",
            *[f"--target={name}" for name in targets],
        ]

        subprocess.check_call(["cmake", *build_args], cwd=self.build_temp)

        # Install the libraries
        for ext in self.extensions:
            # Install the extension into the proper location
            outdir = Path(self.get_ext_fullpath(ext.name)).parent.absolute()

            # Skip if the install directory is the same as the build directory
            if outdir == self.build_temp:
                continue

            # CMake appends the extension prefix to the install path,
            # and outdir already contains that prefix, so we need to remove it.
            prefix = outdir
            for _ in range(ext.name.count(".")):
                prefix = prefix.parent

            # prefix here should actually be the same for all components
            install_args = [
                "cmake",
                "--install",
                ".",
                "--prefix",
                prefix,
                "--component",
                target_name(ext.name),
            ]
            subprocess.check_call(install_args, cwd=self.build_temp)

    def run(self):
        # First, run the standard build_ext command to compile the extensions
        super().run()

        # copy vllm/vllm_flash_attn/**/*.py from self.build_lib to current
        # directory so that they can be included in the editable build
        import glob

        files = glob.glob(
            os.path.join(self.build_lib, "vllm", "vllm_flash_attn", "**", "*.py"),
            recursive=True,
        )
        for file in files:
            dst_file = os.path.join(
                "vllm/vllm_flash_attn", file.split("vllm/vllm_flash_attn/")[-1]
            )
            print(f"Copying {file} to {dst_file}")
            os.makedirs(os.path.dirname(dst_file), exist_ok=True)
            self.copy_file(file, dst_file)

        if _is_cuda() or _is_hip():
            # copy vllm/third_party/triton_kernels/**/*.py from self.build_lib
            # to current directory so that they can be included in the editable
            # build
            print(
                f"Copying {self.build_lib}/vllm/third_party/triton_kernels "
                "to vllm/third_party/triton_kernels"
            )
            shutil.copytree(
                f"{self.build_lib}/vllm/third_party/triton_kernels",
                "vllm/third_party/triton_kernels",
                dirs_exist_ok=True,
            )


class precompiled_build_ext(build_ext):
    """Disables extension building when using precompiled binaries."""

    def run(self) -> None:
        return

    def build_extensions(self) -> None:
        print("Skipping build_ext: using precompiled extensions.")
        return


class precompiled_wheel_utils:
    """Extracts libraries and other files from an existing wheel."""

    @staticmethod
    def fetch_metadata_for_variant(
        commit: str, variant: str | None
    ) -> tuple[list[dict], str]:
        """
        Fetches metadata for a specific variant of the precompiled wheel.
        """
        variant_dir = f"{variant}/" if variant is not None else ""
        repo_url = f"https://wheels.vllm.ai/{commit}/{variant_dir}vllm/"
        meta_url = repo_url + "metadata.json"
        print(f"Trying to fetch nightly build metadata from {meta_url}")

        from urllib.request import urlopen

        with urlopen(meta_url) as resp:
            # urlopen raises HTTPError on unexpected status code
            wheels = json.loads(resp.read().decode("utf-8"))
        return wheels, repo_url

    @staticmethod
    def determine_wheel_url() -> tuple[str, str | None]:
        """
        Try to determine the precompiled wheel URL or path to use.
        The order of preference is:
        1. user-specified wheel location (can be either local or remote,
           via VLLM_PRECOMPILED_WHEEL_LOCATION)
        2. user-specified variant (VLLM_PRECOMPILED_WHEEL_VARIANT)
           from nightly repo
        3. the variant corresponding to VLLM_MAIN_CUDA_VERSION
           from nightly repo
        4. the default variant from nightly repo

        If downloading from the nightly repo, the commit can be specified via
        VLLM_PRECOMPILED_WHEEL_COMMIT; otherwise, the head commit in the main
        branch is used.
        """
        wheel_location = os.getenv("VLLM_PRECOMPILED_WHEEL_LOCATION", None)
        if wheel_location is not None:
            print(f"Using user-specified precompiled wheel location: {wheel_location}")
            return wheel_location, None
        else:
            import platform

            arch = platform.machine()
            # try to fetch the wheel metadata from the nightly wheel repo
            main_variant = "cu" + envs.VLLM_MAIN_CUDA_VERSION.replace(".", "")
            variant = os.getenv("VLLM_PRECOMPILED_WHEEL_VARIANT", main_variant)
            commit = os.getenv("VLLM_PRECOMPILED_WHEEL_COMMIT", "").lower()
            if not commit or len(commit) != 40:
                print(
                    f"VLLM_PRECOMPILED_WHEEL_COMMIT not valid: {commit}"
                    ", trying to fetch base commit in main branch"
                )
                commit = precompiled_wheel_utils.get_base_commit_in_main_branch()
            print(f"Using precompiled wheel commit {commit} with variant {variant}")

            try_default = False
            wheels, repo_url, download_filename = None, None, None
            try:
                wheels, repo_url = precompiled_wheel_utils.fetch_metadata_for_variant(
                    commit, variant
                )
            except Exception as e:
                logger.warning(
                    "Failed to fetch precompiled wheel metadata for variant %s: %s",
                    variant,
                    e,
                )
                try_default = True

            # try outside handler to keep the stacktrace simple
            if try_default:
                print("Trying the default variant from remote")
                wheels, repo_url = precompiled_wheel_utils.fetch_metadata_for_variant(
                    commit, None
                )
                # if this also fails, then we have nothing more to try / cache

            assert wheels is not None and repo_url is not None, (
                "Failed to fetch precompiled wheel metadata"
            )

            # The metadata.json has the following format:
            # see .buildkite/scripts/generate-nightly-index.py for details
            """[{
                "package_name": "vllm",
                "version": "0.11.2.dev278+gdbc3d9991",
                "build_tag": null,
                "python_tag": "cp38",
                "abi_tag": "abi3",
                "platform_tag": "manylinux1_x86_64",
                "variant": null,
                "filename": "vllm-0.11.2.dev278+gdbc3d9991-cp38-abi3-manylinux1_x86_64.whl",
                "path": "../vllm-0.11.2.dev278%2Bgdbc3d9991-cp38-abi3-manylinux1_x86_64.whl"
            }, ...]"""

            from urllib.parse import urljoin

            for wheel in wheels:
                # TODO: maybe check more compatibility later?
                # (python_tag, abi_tag, etc)
                if wheel.get("package_name") == "vllm" and arch in wheel.get(
                    "platform_tag", ""
                ):
                    print(f"Found precompiled wheel metadata: {wheel}")
                    if "path" not in wheel:
                        raise ValueError(f"Wheel metadata missing path: {wheel}")
                    wheel_url = urljoin(repo_url, wheel["path"])
                    download_filename = wheel.get("filename")
                    print(f"Using precompiled wheel URL: {wheel_url}")
                    break
            else:
                raise ValueError(
                    f"No precompiled vllm wheel found for architecture {arch} "
                    f"from repo {repo_url}. All available wheels: {wheels}"
                )
            return wheel_url, download_filename

    @staticmethod
    def extract_precompiled_and_patch_package(
        wheel_url_or_path: str, download_filename: str | None
    ) -> dict:
        import tempfile
        import zipfile

        temp_dir = None
        try:
            if not os.path.isfile(wheel_url_or_path):
                # use provided filename first, then derive from URL
                wheel_filename = download_filename or wheel_url_or_path.split("/")[-1]
                temp_dir = tempfile.mkdtemp(prefix="vllm-wheels")
                wheel_path = os.path.join(temp_dir, wheel_filename)
                print(f"Downloading wheel from {wheel_url_or_path} to {wheel_path}")
                from urllib.request import urlretrieve

                urlretrieve(wheel_url_or_path, filename=wheel_path)
            else:
                wheel_path = wheel_url_or_path
                print(f"Using existing wheel at {wheel_path}")

            package_data_patch = {}

            with zipfile.ZipFile(wheel_path) as wheel:
                files_to_copy = [
                    "vllm/_C.abi3.so",
                    "vllm/_moe_C.abi3.so",
                    "vllm/_flashmla_C.abi3.so",
                    "vllm/_flashmla_extension_C.abi3.so",
                    "vllm/_sparse_flashmla_C.abi3.so",
                    "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
                    "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
                    "vllm/cumem_allocator.abi3.so",
                ]

                flash_attn_regex = re.compile(
                    r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py"
                )
                triton_kernels_regex = re.compile(
                    r"vllm/third_party/triton_kernels/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py"
                )

                file_members = list(
                    filter(lambda x: x.filename in files_to_copy, wheel.filelist)
                )
                file_members += list(
                    filter(lambda x: flash_attn_regex.match(x.filename), wheel.filelist)
                )
                file_members += list(
                    filter(
                        lambda x: triton_kernels_regex.match(x.filename),
                        wheel.filelist,
                    )
                )

                for file in file_members:
                    print(f"[extract] {file.filename}")
                    target_path = os.path.join(".", file.filename)
                    os.makedirs(os.path.dirname(target_path), exist_ok=True)
                    with (
                        wheel.open(file.filename) as src,
                        open(target_path, "wb") as dst,
                    ):
                        shutil.copyfileobj(src, dst)

                    pkg = os.path.dirname(file.filename).replace("/", ".")
                    package_data_patch.setdefault(pkg, []).append(
                        os.path.basename(file.filename)
                    )
            return package_data_patch
        finally:
            if temp_dir is not None:
                print(f"Removing temporary directory {temp_dir}")
                shutil.rmtree(temp_dir)

    @staticmethod
    def get_base_commit_in_main_branch() -> str:
        try:
            # Get the latest commit hash of the upstream main branch.
            resp_json = subprocess.check_output(
                [
                    "curl",
                    "-s",
                    "https://api.github.com/repos/vllm-project/vllm/commits/main",
                ]
            ).decode("utf-8")
            upstream_main_commit = json.loads(resp_json)["sha"]
            print(f"Upstream main branch latest commit: {upstream_main_commit}")

            # In Docker build context, .git may be immutable or missing.
            if envs.VLLM_DOCKER_BUILD_CONTEXT:
                return upstream_main_commit

            # Check if the upstream_main_commit exists in the local repo
            try:
                subprocess.check_output(
                    ["git", "cat-file", "-e", f"{upstream_main_commit}"]
                )
            except subprocess.CalledProcessError:
                # If not present, fetch it from the remote repository.
                # Note that this does not update any local branches,
                # but ensures that this commit ref and its history are
                # available in our local repo.
                subprocess.check_call(
                    ["git", "fetch", "https://github.com/vllm-project/vllm", "main"]
                )

            # Then get the commit hash of the current branch that is the same
            # as the upstream main commit.
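            # Example (hypothetical history): if the dev branch forked from
            # main at commit X, `git merge-base` returns X, so the wheel is
            # fetched for a commit whose compiled kernels match this checkout.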
            current_branch = (
                subprocess.check_output(["git", "branch", "--show-current"])
                .decode("utf-8")
                .strip()
            )

            base_commit = (
                subprocess.check_output(
                    ["git", "merge-base", f"{upstream_main_commit}", current_branch]
                )
                .decode("utf-8")
                .strip()
            )
            return base_commit
        except ValueError as err:
            raise ValueError(err) from None
        except Exception as err:
            logger.warning(
                "Failed to get the base commit in the main branch. "
                "Using the nightly wheel. The libraries in this "
                "wheel may not be compatible with your dev branch: %s",
                err,
            )
            return "nightly"


def _no_device() -> bool:
    return VLLM_TARGET_DEVICE == "empty"


def _is_cuda() -> bool:
    has_cuda = torch.version.cuda is not None
    return VLLM_TARGET_DEVICE == "cuda" and has_cuda and not _is_tpu()


def _is_hip() -> bool:
    return (
        VLLM_TARGET_DEVICE == "cuda" or VLLM_TARGET_DEVICE == "rocm"
    ) and torch.version.hip is not None


def _is_tpu() -> bool:
    return VLLM_TARGET_DEVICE == "tpu"


def _is_cpu() -> bool:
    return VLLM_TARGET_DEVICE == "cpu"


def _is_xpu() -> bool:
    return VLLM_TARGET_DEVICE == "xpu"


def _build_custom_ops() -> bool:
    return _is_cuda() or _is_hip() or _is_cpu()


def get_rocm_version():
    # Get the ROCm version from ROCM_HOME/lib/librocm-core.so
    # see https://github.com/ROCm/rocm-core/blob/d11f5c20d500f729c393680a01fa902ebf92094b/rocm_version.cpp#L21
    try:
        librocm_core_file = Path(ROCM_HOME) / "lib" / "librocm-core.so"
        if not librocm_core_file.is_file():
            return None
        librocm_core = ctypes.CDLL(librocm_core_file)
        VerErrors = ctypes.c_uint32
        get_rocm_core_version = librocm_core.getROCmVersion
        get_rocm_core_version.restype = VerErrors
        get_rocm_core_version.argtypes = [
            ctypes.POINTER(ctypes.c_uint32),
            ctypes.POINTER(ctypes.c_uint32),
            ctypes.POINTER(ctypes.c_uint32),
        ]
        major = ctypes.c_uint32()
        minor = ctypes.c_uint32()
        patch = ctypes.c_uint32()

        if (
            get_rocm_core_version(
                ctypes.byref(major), ctypes.byref(minor), ctypes.byref(patch)
            )
            == 0
        ):
            return f"{major.value}.{minor.value}.{patch.value}"
        return None
    except Exception:
        return None


def get_nvcc_cuda_version() -> Version:
    """Get the CUDA version from nvcc.

    Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
    """
    assert CUDA_HOME is not None, "CUDA_HOME is not set"
    nvcc_output = subprocess.check_output(
        [CUDA_HOME + "/bin/nvcc", "-V"], universal_newlines=True
    )
    output = nvcc_output.split()
    release_idx = output.index("release") + 1
    nvcc_cuda_version = parse(output[release_idx].split(",")[0])
    return nvcc_cuda_version


def get_vllm_version() -> str:
    # Allow overriding the version. This is useful to build platform-specific
    # wheels (e.g. CPU, TPU) without modifying the source.
    if env_version := os.getenv("VLLM_VERSION_OVERRIDE"):
        print(f"Overriding VLLM version with {env_version} from VLLM_VERSION_OVERRIDE")
        os.environ["SETUPTOOLS_SCM_PRETEND_VERSION"] = env_version
        return get_version(write_to="vllm/_version.py")

    version = get_version(write_to="vllm/_version.py")
    sep = "+" if "+" not in version else "."  # dev versions might contain +
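    # Example (hypothetical version): "0.11.2.dev278+gdbc3d9991" already
    # contains "+", so a platform suffix such as "cu118" is appended with "."
    # instead, since PEP 440 allows only one "+" in a version string.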
    if _no_device():
        if envs.VLLM_TARGET_DEVICE == "empty":
            version += f"{sep}empty"
    elif _is_cuda():
        if envs.VLLM_USE_PRECOMPILED and not envs.VLLM_SKIP_PRECOMPILED_VERSION_SUFFIX:
            version += f"{sep}precompiled"
        else:
            cuda_version = str(get_nvcc_cuda_version())
            if cuda_version != envs.VLLM_MAIN_CUDA_VERSION:
                cuda_version_str = cuda_version.replace(".", "")[:3]
                # skip this for source tarball, required for pypi
                if "sdist" not in sys.argv:
                    version += f"{sep}cu{cuda_version_str}"
    elif _is_hip():
        # Get the ROCm version
        rocm_version = get_rocm_version() or torch.version.hip
        if rocm_version and rocm_version != envs.VLLM_MAIN_CUDA_VERSION:
            version += f"{sep}rocm{rocm_version.replace('.', '')[:3]}"
    elif _is_tpu():
        version += f"{sep}tpu"
    elif _is_cpu():
        if envs.VLLM_TARGET_DEVICE == "cpu":
            version += f"{sep}cpu"
    elif _is_xpu():
        version += f"{sep}xpu"
    else:
        raise RuntimeError("Unknown runtime environment")

    return version


def get_requirements() -> list[str]:
    """Get Python package dependencies from requirements.txt."""
    requirements_dir = ROOT_DIR / "requirements"

    def _read_requirements(filename: str) -> list[str]:
        with open(requirements_dir / filename) as f:
            requirements = f.read().strip().split("\n")
        resolved_requirements = []
        for line in requirements:
            if line.startswith("-r "):
                resolved_requirements += _read_requirements(line.split()[1])
            elif (
                not line.startswith("--")
                and not line.startswith("#")
                and line.strip() != ""
            ):
                resolved_requirements.append(line)
        return resolved_requirements

    if _no_device():
        requirements = _read_requirements("common.txt")
    elif _is_cuda():
        requirements = _read_requirements("cuda.txt")
        cuda_major, cuda_minor = torch.version.cuda.split(".")
        modified_requirements = []
        for req in requirements:
            if "vllm-flash-attn" in req and cuda_major != "12":
                # vllm-flash-attn is built only for CUDA 12.x.
                # Skip for other versions.
                continue
            modified_requirements.append(req)
        requirements = modified_requirements
    elif _is_hip():
        requirements = _read_requirements("rocm.txt")
    elif _is_tpu():
        requirements = _read_requirements("tpu.txt")
    elif _is_cpu():
        requirements = _read_requirements("cpu.txt")
    elif _is_xpu():
        requirements = _read_requirements("xpu.txt")
    else:
        raise ValueError("Unsupported platform, please use CUDA, ROCm, or CPU.")
    return requirements


ext_modules = []

if _is_cuda() or _is_hip():
    ext_modules.append(CMakeExtension(name="vllm._moe_C"))
    ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
    # Optional since this doesn't get built (no .so file produced); it just
    # copies the relevant .py files from the source repository.
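    # setuptools treats optional=True extensions as non-fatal: if the target
    # produces no artifact, the build still succeeds with a warning.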
    ext_modules.append(CMakeExtension(name="vllm.triton_kernels", optional=True))

if _is_hip():
    ext_modules.append(CMakeExtension(name="vllm._rocm_C"))

if _is_cuda():
    ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
    if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.3"):
        # FA3 requires CUDA 12.3 or later
        ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
    # Optional since these don't get built (no .so file produced) when
    # not targeting a Hopper system
    ext_modules.append(CMakeExtension(name="vllm._flashmla_C", optional=True))
    ext_modules.append(
        CMakeExtension(name="vllm._flashmla_extension_C", optional=True)
    )

if _build_custom_ops():
    ext_modules.append(CMakeExtension(name="vllm._C"))

package_data = {
    "vllm": [
        "py.typed",
        "model_executor/layers/fused_moe/configs/*.json",
        "model_executor/layers/quantization/utils/configs/*.json",
    ]
}

# If using precompiled, extract and patch package_data (in advance of setup)
if envs.VLLM_USE_PRECOMPILED:
    wheel_url, download_filename = precompiled_wheel_utils.determine_wheel_url()
    patch = precompiled_wheel_utils.extract_precompiled_and_patch_package(
        wheel_url, download_filename
    )
    for pkg, files in patch.items():
        package_data.setdefault(pkg, []).extend(files)

if _no_device():
    ext_modules = []

if not ext_modules:
    cmdclass = {}
else:
    cmdclass = {
        "build_ext": precompiled_build_ext
        if envs.VLLM_USE_PRECOMPILED
        else cmake_build_ext
    }

setup(
    # static metadata should rather go in pyproject.toml
    version=get_vllm_version(),
    ext_modules=ext_modules,
    install_requires=get_requirements(),
    extras_require={
        "bench": ["pandas", "matplotlib", "seaborn", "datasets"],
        "tensorizer": ["tensorizer==2.10.1"],
        "fastsafetensors": ["fastsafetensors >= 0.1.10"],
        "runai": ["runai-model-streamer[s3,gcs] >= 0.15.3"],
        "audio": [
            "librosa",
            "soundfile",
            "mistral_common[audio]",
        ],  # Required for audio processing
        "video": [],  # Kept for backwards compatibility
        "flashinfer": [],  # Kept for backwards compatibility
        # Optional deps for AMD FP4 quantization support
        "petit-kernel": ["petit-kernel"],
    },
    cmdclass=cmdclass,
    package_data=package_data,
)
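
# Example invocations (typical usage; values are illustrative):
#   VLLM_TARGET_DEVICE=cpu pip install -e .        # CPU-only build
#   VLLM_USE_PRECOMPILED=1 pip install -e .        # reuse nightly binaries
#   MAX_JOBS=8 NVCC_THREADS=2 pip install -e .     # bound build parallelism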