[Core][AMD] Migrate fully transparent sleep mode to ROCm platform (#12695)

Signed-off-by: Hollow Man <hollowman@opensuse.org>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: kliuae <kuanfu.liu@embeddedllm.com>
ℍ𝕠𝕝𝕝𝕠𝕨 𝕄𝕒𝕟
2025-11-13 01:24:12 +02:00
committed by GitHub
parent 10f01d5a3a
commit 4ca5cd5740
11 changed files with 582 additions and 31 deletions

View File

@@ -39,6 +39,13 @@ set(PYTHON_SUPPORTED_VERSIONS "3.10" "3.11" "3.12" "3.13")
# Supported AMD GPU architectures.
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201;gfx1150;gfx1151")
# ROCm installation prefix. Default to /opt/rocm but allow override via
# -DROCM_PATH=/your/rocm/path when invoking cmake.
if(NOT DEFINED ROCM_PATH)
set(ROCM_PATH "/opt/rocm" CACHE PATH "ROCm installation prefix")
else()
set(ROCM_PATH ${ROCM_PATH} CACHE PATH "ROCm installation prefix" FORCE)
endif()
#
# Supported/expected torch versions for CUDA/ROCm.
#
@@ -237,10 +244,27 @@ set_gencode_flags_for_srcs(
SRCS "${VLLM_CUMEM_EXT_SRC}"
CUDA_ARCHS "${CUDA_ARCHS}")
if(VLLM_GPU_LANG STREQUAL "CUDA")
if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
message(STATUS "Enabling cumem allocator extension.")
# link against cuda driver library
list(APPEND CUMEM_LIBS CUDA::cuda_driver)
if(VLLM_GPU_LANG STREQUAL "CUDA")
# link against cuda driver library
list(APPEND CUMEM_LIBS CUDA::cuda_driver)
else()
# link against rocm driver library. Prefer an absolute path to
# libamdhip64.so inside ${ROCM_PATH}/lib if available, otherwise fall
# back to linking by name "amdhip64".
find_library(AMDHIP64_LIB
NAMES amdhip64 libamdhip64.so
PATHS ${ROCM_PATH}/lib
NO_DEFAULT_PATH)
if(AMDHIP64_LIB)
message(STATUS "Found libamdhip64 at ${AMDHIP64_LIB}")
list(APPEND CUMEM_LIBS ${AMDHIP64_LIB})
else()
message(WARNING "libamdhip64 not found in ${ROCM_PATH}/lib; falling back to linking 'amdhip64' by name")
list(APPEND CUMEM_LIBS amdhip64)
endif()
endif()
define_extension_target(
cumem_allocator
DESTINATION vllm

View File

@@ -3,14 +3,58 @@
// need to be unsigned long long
#include <iostream>
#include "cumem_allocator_compat.h"
#ifndef USE_ROCM
static const char* PYARGS_PARSE = "KKKK";
#else
#include <cstdlib>
#include <cerrno>
#include <climits>
// Default chunk size 256MB for ROCm. Can be overridden at runtime by the
// environment variable VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE, specified in megabytes
// (MB). The env value is parsed with strtoull as an integer number of MB
// (decimal or 0x hex). The parsed MB value is converted to bytes. If
// parsing fails, if the value is 0, or if the multiplication would
// overflow, the default (256MB) is used.
static const unsigned long long DEFAULT_MEMCREATE_CHUNK_SIZE =
(256ULL * 1024ULL * 1024ULL);
static unsigned long long get_memcreate_chunk_size() {
const char* env = getenv("VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE");
if (!env) return DEFAULT_MEMCREATE_CHUNK_SIZE;
char* endptr = nullptr;
errno = 0;
unsigned long long val_mb = strtoull(env, &endptr, 0);
if (endptr == env || errno != 0) {
// parsing failed, fallback to default
return DEFAULT_MEMCREATE_CHUNK_SIZE;
}
if (val_mb == 0) return DEFAULT_MEMCREATE_CHUNK_SIZE;
const unsigned long long MB = 1024ULL * 1024ULL;
// guard against overflow when converting MB -> bytes
if (val_mb > (ULLONG_MAX / MB)) {
return DEFAULT_MEMCREATE_CHUNK_SIZE;
}
return val_mb * MB;
}
static inline unsigned long long my_min(unsigned long long a,
unsigned long long b) {
return a < b ? a : b;
}
static const char* PYARGS_PARSE = "KKKO";
#endif
extern "C" {
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <sys/types.h>
#include <cuda_runtime_api.h>
#include <cuda.h>
char error_msg[10240]; // 10KB buffer to store error messages
CUresult no_error = CUresult(0);
@@ -49,7 +93,12 @@ void ensure_context(unsigned long long device) {
}
void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
#ifndef USE_ROCM
CUmemGenericAllocationHandle* p_memHandle) {
#else
CUmemGenericAllocationHandle** p_memHandle,
unsigned long long* chunk_sizes, size_t num_chunks) {
#endif
ensure_context(device);
// Define memory allocation properties
CUmemAllocationProp prop = {};
@@ -58,6 +107,7 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
prop.location.id = device;
prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE;
#ifndef USE_ROCM
// Allocate memory using cuMemCreate
CUDA_CHECK(cuMemCreate(p_memHandle, size, &prop, 0));
if (error_code != 0) {
@@ -67,6 +117,39 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
if (error_code != 0) {
return;
}
#else
for (auto i = 0; i < num_chunks; ++i) {
CUDA_CHECK(cuMemCreate(p_memHandle[i], chunk_sizes[i], &prop, 0));
if (error_code != 0) {
// Clean up previously created handles
for (auto j = 0; j < i; ++j) {
cuMemRelease(*(p_memHandle[j]));
}
return;
}
}
unsigned long long allocated_size = 0;
for (auto i = 0; i < num_chunks; ++i) {
void* map_addr = (void*)((uintptr_t)d_mem + allocated_size);
CUDA_CHECK(cuMemMap(map_addr, chunk_sizes[i], 0, *(p_memHandle[i]), 0));
if (error_code != 0) {
// unmap previously mapped chunks
unsigned long long unmapped_size = 0;
for (auto j = 0; j < i; ++j) {
void* unmap_addr = (void*)((uintptr_t)d_mem + unmapped_size);
cuMemUnmap(unmap_addr, chunk_sizes[j]);
unmapped_size += chunk_sizes[j];
}
// release all created handles
for (auto j = 0; j < num_chunks; ++j) {
cuMemRelease(*(p_memHandle[j]));
}
return;
}
allocated_size += chunk_sizes[i];
}
#endif
CUmemAccessDesc accessDesc = {};
accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
accessDesc.location.id = device;
@@ -82,10 +165,16 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
void unmap_and_release(unsigned long long device, ssize_t size,
CUdeviceptr d_mem,
#ifndef USE_ROCM
CUmemGenericAllocationHandle* p_memHandle) {
#else
CUmemGenericAllocationHandle** p_memHandle,
unsigned long long* chunk_sizes, size_t num_chunks) {
#endif
// std::cout << "unmap_and_release: device=" << device << ", size=" << size <<
// ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl;
ensure_context(device);
#ifndef USE_ROCM
CUDA_CHECK(cuMemUnmap(d_mem, size));
if (error_code != 0) {
return;
@@ -94,6 +183,30 @@ void unmap_and_release(unsigned long long device, ssize_t size,
if (error_code != 0) {
return;
}
#else
unsigned long long allocated_size = 0;
CUresult first_error = no_error;
for (auto i = 0; i < num_chunks; ++i) {
void* map_addr = (void*)((uintptr_t)d_mem + allocated_size);
CUresult status = cuMemUnmap(map_addr, chunk_sizes[i]);
if (status != no_error && first_error == no_error) {
first_error = status;
}
allocated_size += chunk_sizes[i];
}
for (auto i = 0; i < num_chunks; ++i) {
CUresult status = cuMemRelease(*(p_memHandle[i]));
if (status != no_error && first_error == no_error) {
first_error = status;
}
}
if (first_error != no_error) {
CUDA_CHECK(first_error);
}
#endif
}
PyObject* create_tuple_from_c_integers(unsigned long long a,
@@ -120,6 +233,36 @@ PyObject* create_tuple_from_c_integers(unsigned long long a,
return tuple; // Return the created tuple
}
PyObject* create_tuple_from_c_mixed(unsigned long long a, unsigned long long b,
unsigned long long c,
CUmemGenericAllocationHandle** vec,
unsigned long long* chunk_sizes,
size_t num_chunks) {
PyObject* tuple = PyTuple_New(4);
if (!tuple) {
return NULL;
}
// PyObject* list = PyList_New(vec.size());
PyObject* list = PyList_New(num_chunks);
for (auto i = 0; i < num_chunks; ++i) {
PyObject* addr_size_pair = PyTuple_New(2);
PyObject* addr = PyLong_FromUnsignedLongLong((unsigned long long)(vec[i]));
PyObject* size =
PyLong_FromUnsignedLongLong((unsigned long long)(chunk_sizes[i]));
PyTuple_SetItem(addr_size_pair, 0, addr);
PyTuple_SetItem(addr_size_pair, 1, size);
PyList_SetItem(list, i, addr_size_pair);
}
PyTuple_SetItem(tuple, 0, PyLong_FromUnsignedLongLong(a));
PyTuple_SetItem(tuple, 1, PyLong_FromUnsignedLongLong(b));
PyTuple_SetItem(tuple, 2, PyLong_FromUnsignedLongLong(c));
PyTuple_SetItem(tuple, 3, list);
return tuple;
}
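As a side note, a minimal Python sketch (purely illustrative; the callback name and unpacking are hypothetical, not part of this commit) of the 4-tuple that create_tuple_from_c_mixed hands to the Python malloc callback on ROCm, where the fourth element is a list of (handle_addr, chunk_size) pairs instead of a single integer handle:
def on_malloc(arg_tuple):
    # (device, aligned_size, d_mem, [(handle_addr, chunk_bytes), ...])
    device, aligned_size, d_mem, chunks = arg_tuple
    # The per-chunk sizes cover the whole reserved range.
    assert sum(size for _handle, size in chunks) == aligned_size
    # Recording the tuple as-is lets the free path rebuild the C arrays
    # from the same list later.
    return device, aligned_size, d_mem, chunks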
// ---------------------------------------------------------------------------
// Our exported C functions that call Python:
@@ -147,14 +290,55 @@ void* my_malloc(ssize_t size, int device, CUstream stream) {
size_t alignedSize = ((size + granularity - 1) / granularity) * granularity;
CUdeviceptr d_mem;
#ifndef USE_ROCM
CUDA_CHECK(cuMemAddressReserve(&d_mem, alignedSize, 0, 0, 0));
if (error_code != 0) {
return nullptr;
}
#else
CUDA_CHECK(cuMemAddressReserve(&d_mem, alignedSize, granularity, 0, 0));
if (error_code != 0) {
return nullptr;
}
#endif
#ifndef USE_ROCM
// allocate the CUmemGenericAllocationHandle
CUmemGenericAllocationHandle* p_memHandle =
(CUmemGenericAllocationHandle*)malloc(
sizeof(CUmemGenericAllocationHandle));
#else
// Make sure chunk size is aligned with hardware granularity. The base
// chunk size can be configured via environment variable
// ``VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE``; otherwise
// DEFAULT_MEMCREATE_CHUNK_SIZE is used.
size_t base_chunk = (size_t)get_memcreate_chunk_size();
size_t aligned_chunk_size =
((base_chunk + granularity - 1) / granularity) * granularity;
size_t num_chunks =
(alignedSize + aligned_chunk_size - 1) / aligned_chunk_size;
CUmemGenericAllocationHandle** p_memHandle =
(CUmemGenericAllocationHandle**)malloc(
num_chunks * sizeof(CUmemGenericAllocationHandle*));
unsigned long long* chunk_sizes =
(unsigned long long*)malloc(num_chunks * sizeof(unsigned long long));
for (auto i = 0; i < num_chunks; ++i) {
p_memHandle[i] = (CUmemGenericAllocationHandle*)malloc(
sizeof(CUmemGenericAllocationHandle));
if (p_memHandle[i] == nullptr) {
std::cerr << "ERROR: malloc failed for p_memHandle[" << i << "].\n";
for (auto j = 0; j < i; ++j) {
free(p_memHandle[j]);
}
free(p_memHandle);
free(chunk_sizes);
return nullptr;
}
chunk_sizes[i] = (unsigned long long)my_min(
(unsigned long long)(alignedSize - i * aligned_chunk_size),
(unsigned long long)aligned_chunk_size);
}
#endif
if (!g_python_malloc_callback) {
std::cerr << "ERROR: g_python_malloc_callback not set.\n";
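To make the chunking above concrete, a small worked sketch (Python, purely illustrative; the granularity and sizes are assumed values) that mirrors the aligned_chunk_size / num_chunks / chunk_sizes arithmetic in my_malloc:
# Assume a 2 MiB hardware granularity, a 1.5 GiB aligned reservation and the
# default 256 MiB base chunk (VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE unset).
granularity = 2 * 1024 * 1024
aligned_size = 1536 * 1024 * 1024
base_chunk = 256 * 1024 * 1024
aligned_chunk = (base_chunk + granularity - 1) // granularity * granularity
num_chunks = (aligned_size + aligned_chunk - 1) // aligned_chunk
chunk_sizes = [min(aligned_size - i * aligned_chunk, aligned_chunk)
               for i in range(num_chunks)]
assert sum(chunk_sizes) == aligned_size  # 6 chunks of 256 MiB in this example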
@@ -164,9 +348,15 @@ void* my_malloc(ssize_t size, int device, CUstream stream) {
// Acquire GIL (not in stable ABI officially, but often works)
PyGILState_STATE gstate = PyGILState_Ensure();
#ifndef USE_ROCM
PyObject* arg_tuple = create_tuple_from_c_integers(
(unsigned long long)device, (unsigned long long)alignedSize,
(unsigned long long)d_mem, (unsigned long long)p_memHandle);
#else
PyObject* arg_tuple = create_tuple_from_c_mixed(
(unsigned long long)device, (unsigned long long)alignedSize,
(unsigned long long)d_mem, p_memHandle, chunk_sizes, num_chunks);
#endif
// Call g_python_malloc_callback
PyObject* py_result =
@@ -182,7 +372,27 @@ void* my_malloc(ssize_t size, int device, CUstream stream) {
PyGILState_Release(gstate);
// do the final mapping
#ifndef USE_ROCM
create_and_map(device, alignedSize, d_mem, p_memHandle);
#else
create_and_map(device, alignedSize, d_mem, p_memHandle, chunk_sizes,
num_chunks);
free(chunk_sizes);
#endif
if (error_code != 0) {
// free address and the handle
CUDA_CHECK(cuMemAddressFree(d_mem, alignedSize));
#ifndef USE_ROCM
free(p_memHandle);
#else
for (size_t i = 0; i < num_chunks; ++i) {
free(p_memHandle[i]);
}
free(p_memHandle);
#endif
return nullptr;
}
return (void*)d_mem;
}
@@ -206,36 +416,96 @@ void my_free(void* ptr, ssize_t size, int device, CUstream stream) {
if (!py_result || !PyTuple_Check(py_result) || PyTuple_Size(py_result) != 4) {
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
Py_XDECREF(py_result);
Py_XDECREF(py_ptr);
return;
}
unsigned long long recv_device, recv_size;
unsigned long long recv_d_mem, recv_p_memHandle;
unsigned long long recv_d_mem;
#ifndef USE_ROCM
unsigned long long recv_p_memHandle;
#else
PyObject* recv_p_memHandle;
#endif
// Unpack the tuple into four C integers
if (!PyArg_ParseTuple(py_result, "KKKK", &recv_device, &recv_size,
if (!PyArg_ParseTuple(py_result, PYARGS_PARSE, &recv_device, &recv_size,
&recv_d_mem, &recv_p_memHandle)) {
// PyArg_ParseTuple sets an error if it fails
Py_XDECREF(py_result);
Py_XDECREF(py_ptr);
return;
}
// For ROCm, copy the Python list of (addr,size) pairs into C arrays while
// holding the GIL. Then release the GIL and call the unmap/release helper
// using the copied arrays. This avoids calling PyList_* APIs without the
// GIL (which is undefined behavior and can crash when called from other
// threads).
CUdeviceptr d_mem = (CUdeviceptr)recv_d_mem;
#ifdef USE_ROCM
Py_ssize_t num_chunks = PyList_Size(recv_p_memHandle);
CUmemGenericAllocationHandle** p_memHandle =
(CUmemGenericAllocationHandle**)malloc(
num_chunks * sizeof(CUmemGenericAllocationHandle*));
if (p_memHandle == nullptr) {
Py_DECREF(py_ptr);
Py_DECREF(py_result);
PyGILState_Release(gstate);
std::cerr << "ERROR: malloc failed for p_memHandle in my_free."
<< std::endl;
return;
}
unsigned long long* chunk_sizes =
(unsigned long long*)malloc(num_chunks * sizeof(unsigned long long));
if (chunk_sizes == nullptr) {
free(p_memHandle);
Py_DECREF(py_ptr);
Py_DECREF(py_result);
PyGILState_Release(gstate);
std::cerr << "ERROR: malloc failed for chunk_sizes in my_free."
<< std::endl;
return;
}
for (Py_ssize_t i = 0; i < num_chunks; ++i) {
PyObject* item = PyList_GetItem(recv_p_memHandle, i);
PyObject* addr_py = PyTuple_GetItem(item, 0);
PyObject* size_py = PyTuple_GetItem(item, 1);
p_memHandle[i] =
(CUmemGenericAllocationHandle*)PyLong_AsUnsignedLongLong(addr_py);
chunk_sizes[i] = (unsigned long long)PyLong_AsUnsignedLongLong(size_py);
}
// Drop temporary Python refs, then release the GIL before calling into
// non-Python APIs.
Py_DECREF(py_ptr);
Py_DECREF(py_result);
PyGILState_Release(gstate);
// recv_size == size
// recv_device == device
unmap_and_release(device, size, d_mem, p_memHandle, chunk_sizes, num_chunks);
#else
// Non-ROCm path: simple integer handle already extracted; drop temporary
// Python refs while still holding the GIL, then release it.
Py_DECREF(py_ptr);
Py_DECREF(py_result);
PyGILState_Release(gstate);
// Free memory
CUdeviceptr d_mem = (CUdeviceptr)recv_d_mem;
CUmemGenericAllocationHandle* p_memHandle =
(CUmemGenericAllocationHandle*)recv_p_memHandle;
unmap_and_release(device, size, d_mem, p_memHandle);
#endif
// free address and the handle
CUDA_CHECK(cuMemAddressFree(d_mem, size));
if (error_code != 0) {
return;
#ifndef USE_ROCM
free(p_memHandle);
#else
for (auto i = 0; i < num_chunks; ++i) {
free(p_memHandle[i]);
}
free(p_memHandle);
free(chunk_sizes);
#endif
}
// ---------------------------------------------------------------------------
@@ -271,19 +541,87 @@ static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) {
}
unsigned long long recv_device, recv_size;
unsigned long long recv_d_mem, recv_p_memHandle;
unsigned long long recv_d_mem;
#ifndef USE_ROCM
unsigned long long recv_p_memHandle;
#else
PyObject* recv_p_memHandle;
#endif
// Unpack the tuple into four C integers
if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem,
&recv_p_memHandle)) {
if (!PyArg_ParseTuple(args, PYARGS_PARSE, &recv_device, &recv_size,
&recv_d_mem, &recv_p_memHandle)) {
// PyArg_ParseTuple sets an error if it fails
return nullptr;
}
CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem;
#ifndef USE_ROCM
CUmemGenericAllocationHandle* p_memHandle =
(CUmemGenericAllocationHandle*)recv_p_memHandle;
unmap_and_release(recv_device, recv_size, d_mem_ptr, p_memHandle);
#else
if (!PyList_Check(recv_p_memHandle)) {
PyErr_SetString(PyExc_TypeError,
"Expected a list for the 4th argument on ROCm");
return nullptr;
}
Py_ssize_t num_chunks = PyList_Size(recv_p_memHandle);
if (num_chunks < 0) {
return nullptr; // PyList_Size sets an exception on error.
}
CUmemGenericAllocationHandle** p_memHandle =
(CUmemGenericAllocationHandle**)malloc(
num_chunks * sizeof(CUmemGenericAllocationHandle*));
if (p_memHandle == nullptr) {
PyErr_SetString(PyExc_MemoryError, "malloc failed for p_memHandle");
return nullptr;
}
unsigned long long* chunk_sizes =
(unsigned long long*)malloc(num_chunks * sizeof(unsigned long long));
if (chunk_sizes == nullptr) {
free(p_memHandle);
PyErr_SetString(PyExc_MemoryError, "malloc failed for chunk_sizes");
return nullptr;
}
for (Py_ssize_t i = 0; i < num_chunks; ++i) {
PyObject* item = PyList_GetItem(recv_p_memHandle, i);
if (item == nullptr || !PyTuple_Check(item) || PyTuple_Size(item) != 2) {
free(p_memHandle);
free(chunk_sizes);
PyErr_SetString(
PyExc_TypeError,
"List items must be tuples of size 2 (handle_addr, size)");
return nullptr;
}
PyObject* addr_py = PyTuple_GetItem(item, 0);
PyObject* size_py = PyTuple_GetItem(item, 1);
if (addr_py == nullptr || size_py == nullptr) {
free(p_memHandle);
free(chunk_sizes);
return nullptr; // PyTuple_GetItem sets an exception
}
p_memHandle[i] =
(CUmemGenericAllocationHandle*)PyLong_AsUnsignedLongLong(addr_py);
if (PyErr_Occurred()) {
free(p_memHandle);
free(chunk_sizes);
return nullptr;
}
chunk_sizes[i] = (unsigned long long)PyLong_AsUnsignedLongLong(size_py);
if (PyErr_Occurred()) {
free(p_memHandle);
free(chunk_sizes);
return nullptr;
}
}
unmap_and_release(recv_device, recv_size, d_mem_ptr, p_memHandle, chunk_sizes,
num_chunks);
free(p_memHandle);
free(chunk_sizes);
#endif
if (error_code != 0) {
error_code = no_error;
@@ -301,19 +639,56 @@ static PyObject* python_create_and_map(PyObject* self, PyObject* args) {
}
unsigned long long recv_device, recv_size;
unsigned long long recv_d_mem, recv_p_memHandle;
unsigned long long recv_d_mem;
#ifndef USE_ROCM
unsigned long long recv_p_memHandle;
#else
PyObject* recv_p_memHandle;
#endif
// Unpack the tuple into four C integers
if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem,
&recv_p_memHandle)) {
if (!PyArg_ParseTuple(args, PYARGS_PARSE, &recv_device, &recv_size,
&recv_d_mem, &recv_p_memHandle)) {
// PyArg_ParseTuple sets an error if it fails
return nullptr;
}
CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem;
#ifndef USE_ROCM
CUmemGenericAllocationHandle* p_memHandle =
(CUmemGenericAllocationHandle*)recv_p_memHandle;
create_and_map(recv_device, recv_size, d_mem_ptr, p_memHandle);
#else
Py_ssize_t num_chunks = PyList_Size(recv_p_memHandle);
CUmemGenericAllocationHandle** p_memHandle =
(CUmemGenericAllocationHandle**)malloc(
num_chunks * sizeof(CUmemGenericAllocationHandle*));
if (p_memHandle == nullptr) {
PyErr_SetString(PyExc_MemoryError, "malloc failed for p_memHandle");
return nullptr;
}
unsigned long long* chunk_sizes =
(unsigned long long*)malloc(num_chunks * sizeof(unsigned long long));
if (chunk_sizes == nullptr) {
free(p_memHandle);
PyErr_SetString(PyExc_MemoryError, "malloc failed for chunk_sizes");
return nullptr;
}
for (auto i = 0; i < num_chunks; ++i) {
PyObject* item = PyList_GetItem(recv_p_memHandle, i);
PyObject* addr_py = PyTuple_GetItem(item, 0);
PyObject* size_py = PyTuple_GetItem(item, 1);
p_memHandle[i] =
(CUmemGenericAllocationHandle*)PyLong_AsUnsignedLongLong(addr_py);
chunk_sizes[i] = PyLong_AsUnsignedLongLong(size_py);
}
create_and_map(recv_device, recv_size, d_mem_ptr, p_memHandle, chunk_sizes,
num_chunks);
free(p_memHandle);
free(chunk_sizes);
#endif
if (error_code != 0) {
error_code = no_error;

View File

@@ -0,0 +1,109 @@
#pragma once
#ifdef USE_ROCM
////////////////////////////////////////
// For compatibility with CUDA and ROCm
////////////////////////////////////////
#include <hip/hip_runtime_api.h>
extern "C" {
#ifndef CUDA_SUCCESS
#define CUDA_SUCCESS hipSuccess
#endif // CUDA_SUCCESS
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html
typedef unsigned long long CUdevice;
typedef hipDeviceptr_t CUdeviceptr;
typedef hipError_t CUresult;
typedef hipCtx_t CUcontext;
typedef hipStream_t CUstream;
typedef hipMemGenericAllocationHandle_t CUmemGenericAllocationHandle;
typedef hipMemAllocationGranularity_flags CUmemAllocationGranularity_flags;
typedef hipMemAllocationProp CUmemAllocationProp;
typedef hipMemAccessDesc CUmemAccessDesc;
#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
#define CU_MEM_ALLOC_GRANULARITY_MINIMUM hipMemAllocationGranularityMinimum
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
#define CU_MEM_ALLOCATION_COMP_NONE 0x0
// Error Handling
// https://docs.nvidia.com/cuda/archive/11.4.4/cuda-driver-api/group__CUDA__ERROR.html
CUresult cuGetErrorString(CUresult hipError, const char** pStr) {
*pStr = hipGetErrorString(hipError);
return CUDA_SUCCESS;
}
// Context Management
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html
CUresult cuCtxGetCurrent(CUcontext* ctx) {
// This API is deprecated on the AMD platform; it is kept only to mirror
// the equivalent cuCtx driver API on the NVIDIA platform.
return hipCtxGetCurrent(ctx);
}
CUresult cuCtxSetCurrent(CUcontext ctx) {
// This API is deprecated on the AMD platform; it is kept only to mirror
// the equivalent cuCtx driver API on the NVIDIA platform.
return hipCtxSetCurrent(ctx);
}
// Primary Context Management
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__PRIMARY__CTX.html
CUresult cuDevicePrimaryCtxRetain(CUcontext* ctx, CUdevice dev) {
return hipDevicePrimaryCtxRetain(ctx, dev);
}
// Virtual Memory Management
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html
CUresult cuMemAddressFree(CUdeviceptr ptr, size_t size) {
return hipMemAddressFree(ptr, size);
}
CUresult cuMemAddressReserve(CUdeviceptr* ptr, size_t size, size_t alignment,
CUdeviceptr addr, unsigned long long flags) {
return hipMemAddressReserve(ptr, size, alignment, addr, flags);
}
CUresult cuMemCreate(CUmemGenericAllocationHandle* handle, size_t size,
const CUmemAllocationProp* prop,
unsigned long long flags) {
return hipMemCreate(handle, size, prop, flags);
}
CUresult cuMemGetAllocationGranularity(
size_t* granularity, const CUmemAllocationProp* prop,
CUmemAllocationGranularity_flags option) {
return hipMemGetAllocationGranularity(granularity, prop, option);
}
CUresult cuMemMap(CUdeviceptr dptr, size_t size, size_t offset,
CUmemGenericAllocationHandle handle,
unsigned long long flags) {
return hipMemMap(dptr, size, offset, handle, flags);
}
CUresult cuMemRelease(CUmemGenericAllocationHandle handle) {
return hipMemRelease(handle);
}
CUresult cuMemSetAccess(CUdeviceptr ptr, size_t size,
const CUmemAccessDesc* desc, size_t count) {
return hipMemSetAccess(ptr, size, desc, count);
}
CUresult cuMemUnmap(CUdeviceptr ptr, size_t size) {
return hipMemUnmap(ptr, size);
}
} // extern "C"
#else
////////////////////////////////////////
// Import CUDA headers for NVIDIA GPUs
////////////////////////////////////////
#include <cuda_runtime_api.h>
#include <cuda.h>
#endif

View File

@@ -11,7 +11,7 @@ Key benefits:
- **Fine-grained control**: Optionally wake up only model weights or KV cache to avoid OOM during weight updates.
!!! note
This feature is only supported on CUDA platform.
This feature is now supported on the CUDA and ROCm platforms.
!!! note
For more information, see this [Blog Post](https://blog.vllm.ai/2025/10/26/sleep-mode.html).
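For instance, a minimal offline sketch of the fine-grained control described above (the model name is only an example; the calls follow the existing sleep-mode interface):
from vllm import LLM

llm = LLM("facebook/opt-125m", enable_sleep_mode=True)
llm.sleep(level=1)             # offload weights to CPU, discard KV cache
llm.wake_up(tags=["weights"])  # restore only the weights first
# ... update weights here without holding KV-cache memory ...
llm.wake_up(tags=["kv_cache"])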
@@ -116,3 +116,7 @@ curl -X POST 'http://localhost:8000/wake_up?tags=kv_cache'
!!! note
These endpoints are only available when passing `VLLM_SERVER_DEV_MODE=1`.
## Limitation
On ROCm, virtual memory for sleep mode is reserved through chunked memory allocation. You can control the chunk size (in MB) through `VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE`; the default is 256 MB. Larger chunks generally perform better, but a value that is too large can cause OOM. If you hit OOM while using sleep mode, try reducing the chunk size. It is recommended to pick a power of 2.
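For example, to shrink the chunk size to 128 MB before starting an engine on ROCm (a sketch; the model name is illustrative):
import os
os.environ["VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE"] = "128"  # in MB, a power of 2

from vllm import LLM
llm = LLM("facebook/opt-125m", enable_sleep_mode=True)
llm.sleep(level=1)
llm.wake_up()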

View File

@@ -208,6 +208,8 @@ class cmake_build_ext(build_ext):
# Make sure we use the nvcc from CUDA_HOME
if _is_cuda():
cmake_args += [f"-DCMAKE_CUDA_COMPILER={CUDA_HOME}/bin/nvcc"]
elif _is_hip():
cmake_args += [f"-DROCM_PATH={ROCM_HOME}"]
other_cmake_args = os.environ.get("CMAKE_ARGS")
if other_cmake_args:
@@ -628,6 +630,7 @@ ext_modules = []
if _is_cuda() or _is_hip():
ext_modules.append(CMakeExtension(name="vllm._moe_C"))
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
if _is_hip():
ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
@@ -643,7 +646,6 @@ if _is_cuda():
ext_modules.append(
CMakeExtension(name="vllm._flashmla_extension_C", optional=True)
)
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
if _build_custom_ops():
ext_modules.append(CMakeExtension(name="vllm._C"))

View File

@@ -8,12 +8,13 @@ import torch
from vllm import LLM, AsyncEngineArgs, AsyncLLMEngine, SamplingParams
from vllm.device_allocator.cumem import CuMemAllocator
from vllm.platforms import current_platform
from vllm.utils.mem_constants import GiB_bytes
from ..utils import create_new_process_for_each_test
@create_new_process_for_each_test()
@create_new_process_for_each_test("fork" if not current_platform.is_rocm() else "spawn")
def test_python_error():
"""
Test if Python error occurs when there's low-level
@@ -39,7 +40,7 @@ def test_python_error():
allocator.wake_up()
@create_new_process_for_each_test()
@create_new_process_for_each_test("fork" if not current_platform.is_rocm() else "spawn")
def test_basic_cumem():
# some tensors from default memory pool
shape = (1024, 1024)
@@ -72,7 +73,7 @@ def test_basic_cumem():
assert torch.allclose(output, torch.ones_like(output) * 3)
@create_new_process_for_each_test()
@create_new_process_for_each_test("fork" if not current_platform.is_rocm() else "spawn")
def test_cumem_with_cudagraph():
allocator = CuMemAllocator.get_instance()
with allocator.use_memory_pool():
@@ -117,7 +118,7 @@ def test_cumem_with_cudagraph():
assert torch.allclose(y, x + 1)
@create_new_process_for_each_test()
@create_new_process_for_each_test("fork" if not current_platform.is_rocm() else "spawn")
@pytest.mark.parametrize(
"model",
[

View File

@@ -264,7 +264,8 @@ class ModelConfig:
merged with the default config from the model. If used with
`--generation-config vllm`, only the override parameters are used."""
enable_sleep_mode: bool = False
"""Enable sleep mode for the engine (only cuda platform is supported)."""
"""Enable sleep mode for the engine (only cuda and
hip platforms are supported)."""
model_impl: str | ModelImpl = "auto"
"""Which implementation of the model to use:\n
- "auto" will try to use the vLLM implementation, if it exists, and fall

View File

@@ -63,7 +63,7 @@ try:
libcudart = CudaRTLibrary()
cumem_available = True
except ModuleNotFoundError:
# rocm platform does not support cumem allocator
# only CUDA and ROCm platforms support the cumem allocator
init_module = None
python_create_and_map = None
python_unmap_and_release = None

View File

@@ -14,6 +14,7 @@ import torch # noqa
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__)
@@ -105,6 +106,20 @@ class CudaRTLibrary:
),
]
# https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Runtime_API_functions_supported_by_HIP.html # noqa
cuda_to_hip_mapping = {
"cudaSetDevice": "hipSetDevice",
"cudaDeviceSynchronize": "hipDeviceSynchronize",
"cudaDeviceReset": "hipDeviceReset",
"cudaGetErrorString": "hipGetErrorString",
"cudaMalloc": "hipMalloc",
"cudaFree": "hipFree",
"cudaMemset": "hipMemset",
"cudaMemcpy": "hipMemcpy",
"cudaIpcGetMemHandle": "hipIpcGetMemHandle",
"cudaIpcOpenMemHandle": "hipIpcOpenMemHandle",
}
# class attribute to store the mapping from the path to the library
# to avoid loading the same library multiple times
path_to_library_cache: dict[str, Any] = {}
@@ -117,7 +132,13 @@ class CudaRTLibrary:
if so_file is None:
so_file = find_loaded_library("libcudart")
if so_file is None:
so_file = envs.VLLM_CUDART_SO_PATH # fallback to env var
# libcudart is not loaded in the current process; try the HIP runtime
so_file = find_loaded_library("libamdhip64")
# it should be safe to assume we are on ROCm now, since the
# assertion below will fail if the HIP runtime library is not
# loaded either
if so_file is None:
so_file = envs.VLLM_CUDART_SO_PATH # fallback to env var
assert so_file is not None, (
"libcudart is not loaded in the current process, "
"try setting VLLM_CUDART_SO_PATH"
@@ -130,7 +151,12 @@ class CudaRTLibrary:
if so_file not in CudaRTLibrary.path_to_dict_mapping:
_funcs = {}
for func in CudaRTLibrary.exported_functions:
f = getattr(self.lib, func.name)
f = getattr(
self.lib,
CudaRTLibrary.cuda_to_hip_mapping[func.name]
if current_platform.is_rocm()
else func.name,
)
f.restype = func.restype
f.argtypes = func.argtypes
_funcs[func.name] = f
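To illustrate the symbol translation, a standalone ctypes sketch (an assumption-laden example, not vLLM code; it presumes libamdhip64.so is resolvable by the dynamic loader):
import ctypes

# On ROCm the CUDA-runtime-style name is looked up under its HIP twin.
mapping = {"cudaSetDevice": "hipSetDevice"}
lib = ctypes.CDLL("libamdhip64.so")
hip_set_device = getattr(lib, mapping["cudaSetDevice"])
hip_set_device.restype = ctypes.c_int    # hipError_t
hip_set_device.argtypes = [ctypes.c_int]
hip_set_device(0)  # same call pattern as cudaSetDevice(0)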

View File

@@ -18,6 +18,7 @@ if TYPE_CHECKING:
VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60
VLLM_NCCL_SO_PATH: str | None = None
LD_LIBRARY_PATH: str | None = None
VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE: int = 256
VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False
VLLM_FLASH_ATTN_VERSION: int | None = None
LOCAL_RANK: int = 0
@@ -520,6 +521,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
# when `VLLM_NCCL_SO_PATH` is not set, vllm will try to find the nccl
# library file in the locations specified by `LD_LIBRARY_PATH`
"LD_LIBRARY_PATH": lambda: os.environ.get("LD_LIBRARY_PATH", None),
# flag to control the chunk size (in MB) for sleep-mode memory allocations on ROCm
"VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE": lambda: int(
os.environ.get("VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE", "256")
),
# Use separate prefill and decode kernels for V1 attention instead of
# the unified triton kernel.
"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": lambda: (

View File

@@ -171,7 +171,11 @@ class Platform:
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
def is_sleep_mode_available(self) -> bool:
return self._enum == PlatformEnum.CUDA
# TODO: Only MI3xx GPUs currently support sleep mode on ROCm,
# but we have no way to detect the exact GPU model statelessly
# here, so we return True for all ROCm platforms for now.
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
@classmethod
def device_id_to_physical_device_id(cls, device_id: int):