[KVConnector] Remove v0-related kv connector components such as kv pipe and kv lookup buffer (#29705)

Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
Kuntai Du
2025-12-05 02:20:48 +08:00
committed by GitHub
parent 652ba93da3
commit ece2825a29
13 changed files with 0 additions and 1624 deletions

View File

@@ -1,160 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import random
import torch
from tqdm import tqdm
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import SimpleBuffer
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe
# TODO: the test depends on a lot of fields in the current implementation.
# We should have a standard interface instead of direct field access.
def test_run(my_rank, buffer, device):
# buffer should be empty in the beginning
if my_rank == 0:
assert buffer.buffer_size == 0
assert len(buffer.buffer) == 0
print(f"My rank: {my_rank}, device: {device}")
# insert
tokens = torch.tensor([1, 2, 3]).to(device)
roi = tokens > 0
if my_rank == 0:
key = 2.0 * torch.ones([5, 6]).to(device)
value = 3.0 * torch.ones([5, 6]).to(device)
placeholder = torch.tensor([1]).to(device)
buffer.insert(tokens, roi, key, value, placeholder)
torch.distributed.barrier()
# drop_select
if my_rank == 1:
tok, roi_, key, value, hidden = buffer.drop_select(tokens, roi)
assert torch.allclose(tokens, tok)
assert torch.allclose(roi, roi_)
assert torch.allclose(key, 2.0 * torch.ones([5, 6], device=device))
assert torch.allclose(value, 3.0 * torch.ones([5, 6], device=device))
torch.distributed.barrier()
if my_rank == 0:
assert buffer.buffer_size == 0
assert len(buffer.buffer) == 0
print(f"My rank: {my_rank}, Test run passed!")
def stress_test(my_rank, buf, device):
torch.distributed.barrier()
torch.manual_seed(100)
reqs = [
(
torch.rand(100).to(device), # tokens
torch.ones(100).bool().to(device), # roi
torch.rand(100).to(device), # key
torch.rand(100).to(device), # value
torch.rand(100).to(device), # hidden
)
for i in tqdm(range(200))
]
random.seed(my_rank)
random.shuffle(reqs)
torch.distributed.barrier()
n = 0
# the buffer can only hold 100 reqs at a time,
# so the sender will occasionally block to wait for the receiver.
for req in tqdm(reqs):
if my_rank == 0:
buf.insert(*req)
else:
tok, roi, k, v, h = req
tok_, roi_, k_, v_, h_ = buf.drop_select(tok, roi)
if tok_ is None:
assert roi_ is None
assert k_ is None
assert v_ is None
assert h_ is None
n += 1
else:
assert torch.allclose(tok, tok_)
assert torch.allclose(roi, roi_)
assert torch.allclose(k, k_)
assert torch.allclose(v, v_)
assert torch.allclose(h, h_)
print(f"Rank {my_rank} done")
torch.distributed.barrier()
if my_rank == 0:
x = torch.tensor([0])
torch.distributed.recv(x, 1)
# the number of Nones received equals the number of entries left unselected in the buffer
assert x.item() == len(buf.buffer)
# and each remaining entry occupies 1700 bytes
# (four float32 tensors of 100 elements plus one 100-element bool mask)
print(buf.buffer_size)
assert buf.buffer_size == 1700 * len(buf.buffer)
else:
torch.distributed.send(torch.tensor([n]), 0)
print(f"My rank: {my_rank}, Passed stress test!")
if __name__ == "__main__":
my_rank = int(os.environ["RANK"])
torch.distributed.init_process_group(
backend="gloo",
init_method="tcp://localhost:12398",
world_size=2,
rank=my_rank,
)
print(f"initialized! My rank is {my_rank}")
config = KVTransferConfig(
kv_connector="P2pNcclConnector",
kv_buffer_device="cuda",
kv_buffer_size=1e9,
kv_rank=my_rank,
kv_role="kv_both", # this arg doesn't matter in this test
kv_parallel_size=2,
kv_ip="127.0.0.1",
kv_port=12345,
)
data_pipe = PyNcclPipe(
local_rank=my_rank,
config=config,
device="cuda",
port_offset=0,
)
cpu_pipe = PyNcclPipe(
local_rank=my_rank,
config=config,
device="cpu",
port_offset=1,
)
buffer = SimpleBuffer(cpu_pipe, data_pipe, 170000)
test_run(my_rank, buffer, data_pipe.device)
stress_test(my_rank, buffer, data_pipe.device)
buffer.close()
data_pipe.close()
cpu_pipe.close()
print("Done")

View File

@@ -1,8 +0,0 @@
#!/bin/bash
RANK=0 python3 test_lookup_buffer.py &
PID0=$!
RANK=1 python3 test_lookup_buffer.py &
PID1=$!
wait $PID0
wait $PID1

View File

@@ -1,62 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import subprocess
import sys
import pytest
import torch
def run_python_script(script_name, timeout):
script_name = f"kv_transfer/{script_name}"
try:
# Start both processes asynchronously using Popen
process0 = subprocess.Popen(
[sys.executable, script_name],
env={"RANK": "0"}, # Set the RANK environment variable for process 0
stdout=sys.stdout, # Pipe stdout to current stdout
stderr=sys.stderr, # Pipe stderr to current stderr
)
process1 = subprocess.Popen(
[sys.executable, script_name],
env={"RANK": "1"}, # Set the RANK environment variable for process 1
stdout=sys.stdout, # Pipe stdout to current stdout
stderr=sys.stderr, # Pipe stderr to current stderr
)
# Wait for both processes to complete, with a timeout
process0.wait(timeout=timeout)
process1.wait(timeout=timeout)
# Check the return status of both processes
if process0.returncode != 0:
pytest.fail(f"Test {script_name} failed for RANK=0, {process0.returncode}")
if process1.returncode != 0:
pytest.fail(f"Test {script_name} failed for RANK=1, {process1.returncode}")
except subprocess.TimeoutExpired:
# If either process times out, terminate both and fail the test
process0.terminate()
process1.terminate()
pytest.fail(f"Test {script_name} timed out")
except Exception as e:
pytest.fail(f"Test {script_name} failed with error: {str(e)}")
# Define the test cases using pytest's parametrize
@pytest.mark.parametrize(
"script_name,timeout",
[
("test_lookup_buffer.py", 60), # Second test case with a 60-second timeout
("test_send_recv.py", 120), # First test case with a 120-second timeout
],
)
def test_run_python_script(script_name, timeout):
# Check the number of GPUs
if torch.cuda.device_count() < 2:
pytest.skip(f"Skipping test {script_name} because <2 GPUs are available")
# Run the test if there are at least 2 GPUs
run_python_script(script_name, timeout)

View File

@@ -1,154 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import time
import torch
from tqdm import tqdm
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe
def test_run(my_rank, pipe):
print(f"rank {my_rank} test_run starts....")
# test run
x = torch.tensor([1]).to(pipe.device)
y = torch.tensor([[2.0, 3.0, 4.0, 8.0]]).to(pipe.device)
if my_rank == 0:
pipe.send_tensor(x)
print(f"rank {my_rank} sent tensor x")
pipe.send_tensor(y)
print(f"rank {my_rank} sent tensor y")
x2 = pipe.recv_tensor()
print(f"rank {my_rank} received x2 = ", x2)
y2 = pipe.recv_tensor()
print(f"rank {my_rank} received y2 = ", y2)
else:
x2 = pipe.recv_tensor()
print(f"rank {my_rank} received x2 = ", x2)
y2 = pipe.recv_tensor()
print(f"rank {my_rank} received y2 = ", y2)
pipe.send_tensor(x)
print(f"rank {my_rank} sent tensor x")
pipe.send_tensor(y)
print(f"rank {my_rank} sent tensor y")
assert torch.allclose(x, x2)
assert torch.allclose(y, y2)
print(f"rank {my_rank} test_run passed!")
def stress_test(my_rank, pipe):
print(f"rank {my_rank} stress_test starts....")
tensors: list[torch.Tensor] = []
torch.distributed.barrier()
torch.manual_seed(0)
for i in tqdm(range(500)):
mean = torch.rand(1).item() * 100
std = torch.rand(1).item() * 100
size = torch.randint(900, 1000, (2,))
x = torch.normal(mean * 1.0, std * 1.0, size=size.tolist()).to(pipe.device)
# 5% probability of sending a None
if torch.rand(1).item() < 0.05:
tensors.append(None)
tensors.append(None)
tensors.append(None)
else:
tensors.append(x)
tensors.append(x.mean().unsqueeze(0))
tensors.append(x.std().unsqueeze(0))
torch.distributed.barrier()
for i in tqdm(range(500)):
if my_rank == int((i % 10) > 3):
pipe.send_tensor(tensors[3 * i])
pipe.send_tensor(tensors[3 * i + 1])
pipe.send_tensor(tensors[3 * i + 2])
else:
x = pipe.recv_tensor()
mean = pipe.recv_tensor()
std = pipe.recv_tensor()
if x is None:
assert mean is None
assert std is None
else:
assert torch.allclose(x, tensors[3 * i])
assert x.mean() == mean[0]
assert x.std() == std[0]
torch.distributed.barrier()
def latency_test(my_rank, pipe, nelement, ntensor):
latencies = []
torch.distributed.barrier()
for i in tqdm(range(500)):
tensors = []
if my_rank == 0:
# create tensor
tensors = [torch.rand(nelement).to(pipe.device) for _ in range(ntensor)]
torch.distributed.barrier()
if my_rank == 0:
t = torch.tensor([time.time()], dtype=torch.float64).to(pipe.device)
for tensor in tensors:
pipe.send_tensor(tensor)
pipe.send_tensor(t)
else:
for _ in range(ntensor):
pipe.recv_tensor()
t = pipe.recv_tensor()
latencies.append(time.time() - t.item())
torch.distributed.barrier()
print("Latency test passed.")
print("Latency:", torch.tensor(latencies).mean().item() * 1000, "ms")
if __name__ == "__main__":
my_rank = int(os.environ["RANK"])
torch.distributed.init_process_group(
backend="gloo",
init_method="tcp://localhost:12398",
world_size=2,
rank=my_rank,
)
config = KVTransferConfig(
kv_connector="P2pNcclConnector",
kv_buffer_device="cuda",
kv_buffer_size=1e9,
kv_rank=my_rank,
kv_role="kv_both", # this arg doesn't matter in this test
kv_parallel_size=2,
kv_ip="127.0.0.1",
kv_port=12345,
)
pipe = PyNcclPipe(
local_rank=my_rank,
config=config,
)
test_run(my_rank, pipe)
stress_test(my_rank, pipe)
# Use this function if you want to test the latency of pipe impl.
# latency_test(my_rank, pipe, 1024 * 8 * 128, 80)

View File

@@ -1,9 +0,0 @@
#!/bin/bash
RANK=0 python3 test_send_recv.py &
PID0=$!
RANK=1 python3 test_send_recv.py &
PID1=$!
wait $PID0
wait $PID1

View File

@@ -1,179 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file contains a new class `KVLookupBufferBase` that allows developers to
think of KV cache operations as inserting new KV cache entries (`insert`)
into the lookup buffer and querying existing KV caches (`drop_select`)
from the lookup buffer.
This file also contains a new class `KVStoreBufferBase` that allows developers
to manage the KVCache buffer as a simple key-value storage buffer with basic
put/get operations.
Both classes derive from the common abstract base class `KVCacheBufferBase`.
"""
from abc import ABC, abstractmethod
import torch
class KVCacheBufferBase(ABC):
"""
Abstract base class for a KVCache buffer.
"""
@abstractmethod
def close(self) -> None:
"""Close the buffer and release resources.
This method is responsible for cleaning up resources related to the
KVCache buffer when it is no longer needed.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
class KVLookupBufferBase(KVCacheBufferBase):
"""
Abstract base class for a KVCache lookup buffer.
This class provides an abstraction for a key-value (KV) cache lookup buffer.
The key of the lookup buffer:
- input_tokens: token IDs of the request
- roi: a binary mask on top of input_tokens.
- Purpose of roi: Since KV cache may only be available for a subset of
tokens in the input (for example, when vLLM is connected to an external
KV cache service), roi specifies the subset of tokens that the KV cache
is associated with.
- NOTE: roi can be further extended to describe which part of KV the
current process is holding (each process may only hold a part of KV
due to TP and PP). This is not implemented for now.
The value of the lookup buffer:
- key: the key tensor in the KV cache
- value: the value tensor in the KV cache
- hidden: the final hidden state generated by model forwarding. This allows
vLLM to bypass further model forwarding by transmitting the hidden state.
"""
@abstractmethod
def insert(
self,
input_tokens: torch.Tensor,
roi: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
hidden: torch.Tensor,
) -> None:
"""Insert into the lookup buffer.
The functionality is similar to the following python statement
```
buffer[input_tokens, roi] = [key, value, hidden]
```
FIXME: in the future, we should only have two arguments, key and value,
where key is a tensor dict and value is a tensor dict.
FIXME: we should transmit both sampler outputs and the hidden states.
Args:
input_tokens (torch.Tensor): token IDs.
roi (torch.Tensor): A binary mask on top of the input tokens
key (torch.Tensor): The key tensor in the KV cache.
value (torch.Tensor): The value tensor in the KV cache.
hidden (torch.Tensor): The final hidden state tensor generated
during model forwarding to bypass model
forwarding.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def drop_select(
self, input_tokens: torch.Tensor | None, roi: torch.Tensor | None
) -> list[torch.Tensor | None]:
"""Select and *drop* KV cache entries from the lookup buffer.
The functionality is similar to the following python statements
```
ret = buffer.pop(input_tokens, roi)
return ret
```
If `input_tokens` and `roi` are `None`, any KV cache entry in the buffer
may be selected, returned, and removed from the buffer. This is useful
when offloading KV cache to a KV cache storage service.
Args:
input_tokens (torch.Tensor): token IDs.
roi (torch.Tensor): A binary mask on top of the input tokens
Returns:
list[torch.Tensor | None]: A list of tensors; individual entries can be None.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
class KVStoreBufferBase(KVCacheBufferBase):
"""
Abstract base class for a KVCache storage buffer with key-value semantics.
This class provides a simple key-value storage buffer abstraction with basic
put/get operations, enabling fine-grained control over KVCache transfer.
The functionality is similar to a distributed key-value store, where:
- Key: A unique string identifier for the cached entry
- Value:
- Tensor to be stored and retrieved
- None (indicating deletion or empty value)
"""
@abstractmethod
def put(
self,
key: str,
value: torch.Tensor | None,
) -> None:
"""Store a key-value pair in the buffer.
Args:
key (str): Unique identifier for a tensor, this tensor could be the
key cache tensor, value cache tensor, or hidden state tensor
generated during model forwarding.
value (Optional[torch.Tensor]): Tensor to be stored.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def get(
self,
key: str,
) -> torch.Tensor | None:
"""Retrieve a value from the buffer by key.
Args:
key (str): Unique identifier for a tensor, this tensor could be the
key cache tensor, value cache tensor, or hidden state tensor
generated during model forwarding.
Returns:
Optional[torch.Tensor]: Stored tensor if exists, None otherwise.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
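
For reference, the put/get contract of the removed `KVStoreBufferBase` can be illustrated with a minimal in-memory sketch. The class name `InMemoryKVStoreBuffer` is hypothetical and was never part of vLLM; the import path is the module being deleted above.

import torch
from vllm.distributed.kv_transfer.kv_lookup_buffer.base import KVStoreBufferBase  # module deleted above


class InMemoryKVStoreBuffer(KVStoreBufferBase):
    """Toy dict-backed buffer following the put/get/close contract above."""

    def __init__(self) -> None:
        self._store: dict[str, torch.Tensor] = {}

    def put(self, key: str, value: torch.Tensor | None) -> None:
        if value is None:
            # A None value indicates deletion (or an empty entry).
            self._store.pop(key, None)
        else:
            self._store[key] = value.clone()

    def get(self, key: str) -> torch.Tensor | None:
        return self._store.get(key)

    def close(self) -> None:
        self._store.clear()


buffer = InMemoryKVStoreBuffer()
buffer.put("layer0.key", torch.zeros(4))
assert buffer.get("missing") is None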

View File

@@ -1,164 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file contains a new class `MooncakeStore` that allows developers to
think of KV cache transfer operations as putting new KV cache entries
into a remote KVStore-based lookup buffer and getting existing KV caches
from this remote lookup buffer.
"""
import json
import os
from dataclasses import dataclass
import torch
from safetensors.torch import load as safetensors_load
from safetensors.torch import save as safetensors_save
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_lookup_buffer.base import KVStoreBufferBase
from vllm.logger import init_logger
DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB
DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB
logger = init_logger(__name__)
@dataclass
class MooncakeStoreConfig:
local_hostname: str
metadata_server: str
global_segment_size: int
local_buffer_size: int
protocol: str
device_name: str
master_server_address: str
@staticmethod
def from_file(file_path: str) -> "MooncakeStoreConfig":
"""Load the config from a JSON file."""
with open(file_path) as fin:
config = json.load(fin)
return MooncakeStoreConfig(
local_hostname=config.get("local_hostname"),
metadata_server=config.get("metadata_server"),
global_segment_size=config.get(
"global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE
),
local_buffer_size=config.get(
"local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE
),
protocol=config.get("protocol", "tcp"),
device_name=config.get("device_name", ""),
master_server_address=config.get("master_server_address"),
)
@staticmethod
def load_from_env() -> "MooncakeStoreConfig":
"""Load config from a file specified in the environment variable."""
config_file_path = os.getenv("MOONCAKE_CONFIG_PATH")
if config_file_path is None:
raise ValueError(
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
)
return MooncakeStoreConfig.from_file(config_file_path)
class MooncakeStore(KVStoreBufferBase):
def __init__(
self,
config: VllmConfig,
):
try:
from mooncake.store import MooncakeDistributedStore
except ImportError as e:
raise ImportError(
"Please install mooncake by following the instructions at "
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
"to run vLLM with MooncakeConnector."
) from e
try:
self.store = MooncakeDistributedStore()
self.config = MooncakeStoreConfig.load_from_env()
logger.info("Mooncake Configuration loaded successfully.")
self.store.setup(
self.config.local_hostname,
self.config.metadata_server,
self.config.global_segment_size,
self.config.local_buffer_size,
self.config.protocol,
self.config.device_name,
self.config.master_server_address,
)
except ValueError as e:
logger.error("Configuration loading failed: %s", e)
raise
except Exception as exc:
logger.error("An error occurred while loading the configuration: %s", exc)
raise
def close(self):
# MooncakeDistributedStore will automatically call the destructor, so
# it is unnecessary to close it manually.
pass
def put(
self,
key: str,
value: torch.Tensor | None,
) -> None:
# A message queue needs to be introduced before making it asynchronous.
if value is not None:
self._put_impl(key, value)
def get(
self,
key: str,
) -> torch.Tensor | None:
# A message queue needs to be introduced before making it asynchronous.
value = self._get_impl(key)
return value
def _put_impl(
self,
key: str,
value: torch.Tensor,
) -> None:
"""Put KVCache to Mooncake Store"""
device_id = value.device.index if value.device.type == "cuda" else -1
device_tensor = torch.tensor(device_id, dtype=torch.int32)
value_bytes = safetensors_save({"tensor": value, "device_id": device_tensor})
try:
self.store.put(key, value_bytes)
except TypeError as err:
logger.error("Failed to put value into Mooncake Store: %s", err)
raise TypeError("Mooncake Store Put Type Error.") from err
def _get_impl(
self,
key: str,
) -> torch.Tensor | None:
"""Get KVCache from Mooncake Store"""
try:
data = self.store.get(key)
except TypeError as err:
logger.error("Failed to get value from Mooncake Store: %s", err)
raise TypeError("Mooncake Store Get Type Error.") from err
if data:
loaded_tensors = safetensors_load(data)
tensor = loaded_tensors["tensor"]
device_id_tensor = loaded_tensors["device_id"]
device_id = int(device_id_tensor.item())
device = (
torch.device("cuda", device_id)
if device_id >= 0
else torch.device("cpu")
)
return tensor.to(device)
return None
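
The removed `MooncakeStore` reads its configuration from a JSON file whose path is given by the `MOONCAKE_CONFIG_PATH` environment variable. Below is a hedged sketch of such a file, written from Python; the addresses are placeholder values, while the two size values are the defaults defined in the removed module.

import json
import os

# Keys mirror MooncakeStoreConfig above; all addresses are placeholders.
store_config = {
    "local_hostname": "127.0.0.1",
    "metadata_server": "127.0.0.1:2379",
    "global_segment_size": 3355443200,  # DEFAULT_GLOBAL_SEGMENT_SIZE (3.125 GiB)
    "local_buffer_size": 1073741824,  # DEFAULT_LOCAL_BUFFER_SIZE (1.0 GiB)
    "protocol": "tcp",
    "device_name": "",
    "master_server_address": "127.0.0.1:50051",
}

with open("mooncake_store_config.json", "w") as f:
    json.dump(store_config, f, indent=2)

# MooncakeStoreConfig.load_from_env() would then pick the file up via this variable.
os.environ["MOONCAKE_CONFIG_PATH"] = os.path.abspath("mooncake_store_config.json")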

View File

@@ -1,242 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Implements a distributed key-value (KV) cache transfer mechanism.
Key Features:
- Distributed KV cache transmission using PyNccl pipes.
- Non-blocking `insert`, blocking `drop_select`.
- Uses a CPU signal pipe to avoid race conditions.
- Handles buffer size constraints and provides a backpressure mechanism to
stop the prefill instance when the decode instance is slow.
"""
import threading
from collections import deque
import torch
from vllm.distributed.kv_transfer.kv_lookup_buffer.base import KVLookupBufferBase
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.logger import init_logger
logger = init_logger(__name__)
class SimpleBuffer(KVLookupBufferBase):
def __init__(
self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, buffer_size_thresh: float
):
"""
signal_pipe: on CPU
NOTE: on-device recv will block all threads in the process, making the
KV cache producer unable to listen to new requests while transmitting
KV cache. Luckily, CPU recv only blocks the current thread, so we use
CPU recv to listen to new requests.
data_pipe: on device (e.g. GPU)
"""
self.buffer: deque[list[torch.Tensor]] = deque()
self.buffer_size = 0
self.buffer_size_threshold = buffer_size_thresh
self.buffer_cv = threading.Condition()
self.signal_pipe = signal_pipe
self.data_pipe = data_pipe
self.request_handling_thread: threading.Thread | None = None
self.normal_signal = torch.tensor([0], device="cpu")
self.end_signal = None
def _matches(
self,
tokens_roi_sender: list[torch.Tensor],
tokens_roi_recver: list[torch.Tensor],
):
# tokens_roi_sender: tokens and roi of the producer (in the buffer)
# tokens_roi_recver: tokens and roi of the consumer (query)
tokens_sender = tokens_roi_sender[0]
tokens_recver = tokens_roi_recver[0]
roi_sender = tokens_roi_sender[1]
roi_recver = tokens_roi_recver[1]
if tokens_recver is None:
# consumer sends an empty request
# semantics: DROP SELECT * LIMIT 1
# so any of the data in the buffer can be drop-selected
return True
# Assuming that roi is a binary mask on tokens
tokens_sender = tokens_sender[roi_sender]
tokens_recver = tokens_recver[roi_recver]
# simple common prefix matching
min_length = min(len(tokens_sender), len(tokens_recver))
if torch.allclose(tokens_sender[:min_length], tokens_recver[:min_length]):
return min_length
return 0
def _send_tensor_and_dec_size(self, tensor: torch.Tensor | None) -> None:
assert tensor is not None, "Use self.data_pipe.send(None) instead"
self.buffer_size -= tensor.element_size() * tensor.numel()
if tensor.dtype == torch.bool:
tensor = tensor.float()
self.data_pipe.send_tensor(tensor)
def _get_element_size(self, data: list | torch.Tensor | None):
if isinstance(data, torch.Tensor):
return data.element_size() * data.numel()
if not data:
# cannot perform `not data` on a tensor
# so this check needs to go after the check above
return 0
raise AssertionError(f"Unknown data type {type(data)}")
def _add_to_buffer(
self,
input_tokens: torch.Tensor,
roi: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
hidden: torch.Tensor,
):
if isinstance(input_tokens, torch.Tensor):
input_tokens = input_tokens.clone()
if isinstance(roi, torch.Tensor):
roi = roi.clone()
if isinstance(key, torch.Tensor):
key = key.clone()
if isinstance(value, torch.Tensor):
value = value.clone()
if isinstance(hidden, torch.Tensor):
hidden = hidden.clone()
buffer_item = [input_tokens, roi, key, value, hidden]
data_size = sum([self._get_element_size(data) for data in buffer_item])
with self.buffer_cv:
if self.buffer_size + data_size > self.buffer_size_threshold:
# log outside the while loop to avoid this message being logged
# repeatedly.
logger.debug("KV transfer buffer is full. Handling...")
while self.buffer_size + data_size > self.buffer_size_threshold:
self.buffer_cv.wait()
self.buffer_size += data_size
self.buffer.append(buffer_item)
self.buffer_cv.notify()
def _is_end_signal(self, signal):
return signal is None
def drop_select_handler(self):
try:
while True:
signal = self.signal_pipe.recv_tensor()
if self._is_end_signal(signal):
logger.info("Received end signal!")
break
input_tokens = self.data_pipe.recv_tensor()
roi = self.data_pipe.recv_tensor()
assert roi is not None, (
"Please provide the roi when sending drop-select request"
)
roi = roi > 0.5
tokens_roi_recver = [input_tokens, roi]
def is_buffer_available(
tokens_roi_recver: list[torch.Tensor],
) -> bool:
# perform input tokens and roi matching
# FIXME: this matching is O(n), ideally it should be O(1)
# but this buffer size won't (and shouldn't) be too large so
# the fix is not urgent.
for _ in range(len(self.buffer)):
if self._matches(self.buffer[0], tokens_roi_recver) > 0:
return True
# rotate the element we just accessed to the end
self.buffer.rotate(-1)
return False
with self.buffer_cv:
while not is_buffer_available(tokens_roi_recver):
logger.debug("KV transfer buffer is not available. Waiting...")
self.buffer_cv.wait()
# need to clone the tensor
# in case the tensor is freed before sending finishes
matched_item = self.buffer.popleft()
for tensor in matched_item:
self._send_tensor_and_dec_size(tensor)
self.buffer_cv.notify()
except RuntimeError as e:
if "Connection closed by peer" not in str(e):
raise e
logger.debug("Closing drop_select_handler")
def drop_select(
self, input_tokens: torch.Tensor | None, roi: torch.Tensor | None
) -> list[torch.Tensor | None]:
assert self.request_handling_thread is None, (
"drop_select should be called by the KV cache consumer "
"(e.g. the decode vLLM instance)"
)
if isinstance(input_tokens, torch.Tensor):
input_tokens = input_tokens.clone()
if isinstance(roi, torch.Tensor):
roi = roi.clone().float()
self.signal_pipe.send_tensor(self.normal_signal)
self.data_pipe.send_tensor(input_tokens)
self.data_pipe.send_tensor(roi)
input_tokens = self.data_pipe.recv_tensor()
roi = self.data_pipe.recv_tensor()
if roi is not None:
# convert from float tensor to bool tensor
# as PyNccl does not support sending bool tensor
roi = roi > 0.5
key = self.data_pipe.recv_tensor()
value = self.data_pipe.recv_tensor()
hidden = self.data_pipe.recv_tensor()
return [input_tokens, roi, key, value, hidden]
def insert(
self,
input_tokens: torch.Tensor,
roi: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
hidden: torch.Tensor,
) -> None:
self._add_to_buffer(input_tokens, roi, key, value, hidden)
# when calling the insert, the current process is a sender
# need to launch the request handler and start listening to request.
if self.request_handling_thread is None:
self.request_handling_thread = threading.Thread(
target=self.drop_select_handler
)
self.request_handling_thread.start()
def close(self):
if (
hasattr(self, "request_handling_thread")
and self.request_handling_thread is not None
):
self.request_handling_thread.join()
else:
# TODO: have an explicit close signal and an explicit way to
# check whether this side is the requester
self.signal_pipe.send_tensor(self.end_signal)

View File

@@ -1,66 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file defines an interface `KVPipeBase`
that provides an abstraction for sending and receiving tensors, or None, via
distributed communications.
All classes instantiated from this interface are assumed to be a FIFO pipe.
If your distributed communication platform already supports key-value lookup,
you can bypass this interface and directly start from `kv_lookup_buffer`.
"""
from abc import ABC, abstractmethod
import torch
class KVPipeBase(ABC):
"""
This class provides an interface for sending and receiving tensors, or
None, by distributed communications.
"""
@abstractmethod
def send_tensor(self, tensor: torch.Tensor | None) -> None:
"""Send a tensor, or None, via the pipe.
Need to support sending None -- important for error handling.
TODO: add a `key` argument so that we can use traditional
key-value database as the distributed communication mechanism behind
the pipe.
Args:
tensor (Optional[torch.Tensor]): The tensor to be sent. Can be None.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def recv_tensor(self) -> torch.Tensor | None:
"""Receive a tensor (can be None) from the pipeline.
Returns:
Optional[torch.Tensor]: The tensor received from the pipeline. Can
be None.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def close(self) -> None:
"""Close the pipeline and release resources.
This method is responsible for closing the communication pipeline
and releasing any resources associated with it.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
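
For reference, the FIFO contract defined by the removed `KVPipeBase`, including the requirement that `None` is a legal payload, can be illustrated with a single-process, queue-backed sketch; `LocalQueuePipe` is a hypothetical name used only here.

import queue

import torch

from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase  # module deleted above


class LocalQueuePipe(KVPipeBase):
    """Toy FIFO pipe that never leaves the current process."""

    def __init__(self) -> None:
        self._queue: queue.Queue = queue.Queue()

    def send_tensor(self, tensor: torch.Tensor | None) -> None:
        # None must travel through the pipe just like a tensor.
        self._queue.put(tensor if tensor is None else tensor.clone())

    def recv_tensor(self) -> torch.Tensor | None:
        return self._queue.get()

    def close(self) -> None:
        pass


pipe = LocalQueuePipe()
pipe.send_tensor(torch.ones(3))
pipe.send_tensor(None)
assert torch.equal(pipe.recv_tensor(), torch.ones(3))
assert pipe.recv_tensor() is None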

View File

@@ -1,295 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import os
import struct
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
import torch
import zmq
from safetensors.torch import load as safetensors_load
from safetensors.torch import save as safetensors_save
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.logger import init_logger
from vllm.utils.network_utils import join_host_port, make_zmq_path, split_host_port
logger = init_logger(__name__)
NONE_INT = -150886311
@dataclass
class MooncakeTransferEngineConfig:
prefill_url: str
decode_url: str
metadata_backend: str | None
metadata_server: str
protocol: str
device_name: str
@staticmethod
def from_file(file_path: str) -> "MooncakeTransferEngineConfig":
"""Load the config from a JSON file."""
with open(file_path) as fin:
config = json.load(fin)
return MooncakeTransferEngineConfig(
prefill_url=config.get("prefill_url"),
decode_url=config.get("decode_url"),
metadata_backend=config.get("metadata_backend", None),
metadata_server=config.get("metadata_server"),
protocol=config.get("protocol", "tcp"),
device_name=config.get("device_name", ""),
)
@staticmethod
def load_from_env() -> "MooncakeTransferEngineConfig":
"""Load config from a file specified in the environment variable."""
config_file_path = os.getenv("MOONCAKE_CONFIG_PATH")
if config_file_path is None:
raise ValueError(
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
)
return MooncakeTransferEngineConfig.from_file(config_file_path)
class MooncakeTransferEngine:
"""Handles the transfer of data using mooncake_vllm_adaptor and ZeroMQ."""
def __init__(self, kv_rank: int, local_rank: int):
try:
from mooncake.engine import TransferEngine
except ImportError as e:
raise ImportError(
"Please install mooncake by following the instructions at "
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
"to run vLLM with MooncakeConnector."
) from e
self.engine = TransferEngine()
self.local_rank = local_rank
try:
self.config = MooncakeTransferEngineConfig.load_from_env()
logger.info("Mooncake Configuration loaded successfully.")
except ValueError as e:
logger.error(e)
raise
except Exception as exc:
logger.error("An error occurred while loading the configuration: %s", exc)
raise
prefill_host, base_prefill_port = split_host_port(self.config.prefill_url)
decode_host, base_decode_port = split_host_port(self.config.decode_url)
# Avoid ports conflict when running prefill and decode on the same node
if prefill_host == decode_host and base_prefill_port == base_decode_port:
base_decode_port = base_decode_port + 100
prefill_port = base_prefill_port + self.local_rank
decode_port = base_decode_port + self.local_rank
self.prefill_url = join_host_port(prefill_host, prefill_port)
self.decode_url = join_host_port(decode_host, decode_port)
self.initialize(
self.prefill_url if kv_rank == 0 else self.decode_url,
self.config.metadata_server,
self.config.protocol,
self.config.device_name,
self.config.metadata_backend,
)
self.remote_url = self.decode_url if kv_rank == 0 else self.prefill_url
# Initialize ZeroMQ context and sockets
self.context = zmq.Context() # type: ignore[attr-defined]
self.sender_socket = self.context.socket(zmq.constants.PUSH)
self.receiver_socket = self.context.socket(zmq.constants.PULL)
self.sender_ack = self.context.socket(zmq.constants.PULL)
self.receiver_ack = self.context.socket(zmq.constants.PUSH)
self.buffer_cleaner = ThreadPoolExecutor(max_workers=1)
self._setup_metadata_sockets(
kv_rank, prefill_host, base_prefill_port, decode_host, base_decode_port
)
def _setup_metadata_sockets(
self, kv_rank: int, p_host: str, p_port: int, d_host: str, d_port: int
) -> None:
"""Set up ZeroMQ sockets for sending and receiving data."""
# Offsets < 8 are left for initialization in case tp and pp are enabled
p_rank_offset = p_port + 8 + self.local_rank * 2
d_rank_offset = d_port + 8 + self.local_rank * 2
if kv_rank == 0:
self.sender_socket.bind(make_zmq_path("tcp", p_host, p_rank_offset + 1))
self.receiver_socket.connect(
make_zmq_path("tcp", d_host, d_rank_offset + 1)
)
self.sender_ack.connect(make_zmq_path("tcp", d_host, d_rank_offset + 2))
self.receiver_ack.bind(make_zmq_path("tcp", p_host, p_rank_offset + 2))
else:
self.receiver_socket.connect(
make_zmq_path("tcp", p_host, p_rank_offset + 1)
)
self.sender_socket.bind(make_zmq_path("tcp", d_host, d_rank_offset + 1))
self.receiver_ack.bind(make_zmq_path("tcp", d_host, d_rank_offset + 2))
self.sender_ack.connect(make_zmq_path("tcp", p_host, p_rank_offset + 2))
def initialize(
self,
local_hostname: str,
metadata_server: str,
protocol: str,
device_name: str,
metadata_backend: str | None,
) -> None:
"""Initialize the mooncake instance."""
if metadata_backend is None:
self.engine.initialize(
local_hostname, metadata_server, protocol, device_name
)
else:
supported_backend = ["etcd", "redis"]
metadata_backend = metadata_backend.lower()
if metadata_backend not in supported_backend:
raise ValueError(
"Mooncake Configuration error. `metadata_backend`"
f" should be one of {supported_backend}."
)
self.engine.initialize_ext(
local_hostname, metadata_server, protocol, device_name, metadata_backend
)
def allocate_managed_buffer(self, length: int) -> int:
"""Allocate a managed buffer of the specified length."""
ret = self.engine.allocate_managed_buffer(length)
if ret <= 0:
logger.error("Allocation Return Error")
raise Exception("Allocation Return Error")
return ret
def free_managed_buffer(self, buffer: int, length: int) -> int:
"""Free a previously allocated managed buffer."""
return self.engine.free_managed_buffer(buffer, length)
def transfer_sync(self, buffer: int, peer_buffer_address: int, length: int) -> int:
"""Synchronously transfer data to the specified address."""
ret = self.engine.transfer_sync_read(
self.remote_url, buffer, peer_buffer_address, length
)
if ret < 0:
logger.error("Transfer Return Error")
raise Exception("Transfer Return Error")
return ret
def write_bytes_to_buffer(self, buffer: int, user_data: bytes, length: int) -> int:
"""Write bytes to the allocated buffer."""
return self.engine.write_bytes_to_buffer(buffer, user_data, length)
def read_bytes_from_buffer(self, buffer: int, length: int) -> bytes:
"""Read bytes from the allocated buffer."""
return self.engine.read_bytes_from_buffer(buffer, length)
def wait_for_ack(self, src_ptr: int, length: int) -> None:
"""Asynchronously wait for ACK from the receiver."""
ack = self.sender_ack.recv()
if ack != b"ACK":
logger.error("Failed to receive ACK from the receiver")
self.free_managed_buffer(src_ptr, length)
def send_bytes(self, user_data: bytes) -> None:
"""Send bytes to the remote process."""
length = len(user_data)
src_ptr = self.allocate_managed_buffer(length)
self.write_bytes_to_buffer(src_ptr, user_data, length)
self.sender_socket.send_multipart(
[struct.pack("!Q", src_ptr), struct.pack("!Q", length)]
)
self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length)
def recv_bytes(self) -> bytes:
"""Receive bytes from the remote process."""
data = self.receiver_socket.recv_multipart()
src_ptr = struct.unpack("!Q", data[0])[0]
length = struct.unpack("!Q", data[1])[0]
dst_ptr = self.allocate_managed_buffer(length)
self.transfer_sync(dst_ptr, src_ptr, length)
ret = self.read_bytes_from_buffer(dst_ptr, length)
# Buffer cleanup
self.receiver_ack.send(b"ACK")
self.free_managed_buffer(dst_ptr, length)
return ret
class MooncakePipe(KVPipeBase):
"""MooncakeTransferEngine based Pipe implementation."""
def __init__(
self, local_rank: int, config: KVTransferConfig, device: str | None = None
):
"""Initialize the mooncake pipe and set related parameters."""
self.config = config
self.local_rank = local_rank
self.kv_rank = self.config.kv_rank
assert self.kv_rank is not None
if device is None:
self.device = self._select_device(self.config.kv_buffer_device)
else:
self.device = self._select_device(device)
self.transfer_engine = MooncakeTransferEngine(self.kv_rank, self.local_rank)
self.transport_thread: ThreadPoolExecutor | None = None
self.none_tensor = torch.tensor([NONE_INT], device=self.device)
def _select_device(self, device: str) -> torch.device:
"""Select available device (CUDA or CPU)."""
logger.info("Selecting device: %s", device)
if device == "cuda":
return torch.device(f"cuda:{self.local_rank}")
else:
return torch.device("cpu")
def tensor_hash(self, tensor: torch.Tensor) -> int:
"""Calculate the hash value of the tensor."""
return hash(tensor.data_ptr())
def _send_impl(self, tensor: torch.Tensor) -> None:
"""Implement the tensor sending logic using safetensors."""
self.transfer_engine.send_bytes(safetensors_save({"tensor": tensor}))
def _recv_impl(self) -> torch.Tensor:
"""Implement the tensor receiving logic using safetensors."""
data = self.transfer_engine.recv_bytes()
return safetensors_load(data)["tensor"].to(self.device)
def send_tensor(self, tensor: torch.Tensor | None) -> None:
"""Send tensor to the target process."""
if self.transport_thread is None:
self.transport_thread = ThreadPoolExecutor(max_workers=1)
tensor = tensor if tensor is not None else self.none_tensor
assert len(tensor.shape) > 0
self.transport_thread.submit(self._send_impl, tensor)
def recv_tensor(self) -> torch.Tensor | None:
"""Receive tensor from other processes."""
if self.transport_thread is None:
self.transport_thread = ThreadPoolExecutor(max_workers=1)
tensor = self.transport_thread.submit(self._recv_impl).result()
if tensor.numel() == 1 and tensor.item() == NONE_INT:
return None
else:
return tensor
def close(self) -> None:
"""Cleanup logic when closing the pipe."""
self.transfer_engine.sender_socket.close()
self.transfer_engine.receiver_socket.close()
self.transfer_engine.sender_ack.close()
self.transfer_engine.receiver_ack.close()
self.transfer_engine.context.term() # Terminate the ZMQ context
logger.info("Closed the transfer engine and cleaned up resources.")

View File

@@ -1,285 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This module implements a PyNccl pipe for sending and receiving
Optional[torch.Tensor] between distributed ranks with advanced
communication features.
Key Features:
- Supports sending and receiving tensors with metadata
- Handles both CUDA and CPU device communications
- Implements a non-blocking tensor transfer mechanism
- Manages buffer size and provides backpressure control
- Supports distributed process groups with configurable parameters
"""
import threading
import time
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
import torch
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
logger = init_logger(__name__)
class BrokenPipeException(Exception):
def __init__(self, message):
self.message = message
super().__init__(self.message)
Metadata = dict[str, torch.Tensor | None]
class PyNcclPipe(KVPipeBase):
METADATA_LENGTH = 16
MAX_TENSOR_DIMENSIONS = 14
METADATA_DTYPE = torch.int64
def __init__(
self,
local_rank: int,
config: KVTransferConfig,
device: str | None = None,
port_offset: int = 0,
):
self.config = config
self.local_rank = local_rank
self.kv_rank = self.config.kv_rank
assert self.kv_rank is not None
self.kv_parallel_size = self.config.kv_parallel_size
if device is None:
self.device = self._select_device(self.config.kv_buffer_device)
else:
self.device = self._select_device(device)
# build distributed connection and send/recv implementation
store_timeout = self.config.get_from_extra_config("store_timeout", 300)
self.group = StatelessProcessGroup.create(
host=self.config.kv_ip,
port=self.config.kv_port + port_offset,
rank=self.kv_rank,
world_size=self.kv_parallel_size,
store_timeout=store_timeout,
)
# add a barrier to make sure the connection is initiated properly
self.group.barrier()
impl = self._get_device_send_recv_impl(self.group)
self.device_send_func, self.device_recv_func = impl
# set target rank
self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size
self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size
# transportation-related variables
self.transport_thread: ThreadPoolExecutor | None = None
self.buffer_size = 0
self.buffer_size_lock = threading.Lock()
self.buffer_size_thresh = self.config.kv_buffer_size
def _get_device_send_recv_impl(
self, group: StatelessProcessGroup
) -> tuple[
Callable[[torch.Tensor, int], None], Callable[[torch.Tensor, int], None]
]:
send: Callable[[torch.Tensor, int], None]
recv: Callable[[torch.Tensor, int], None]
if self.device.type == "cuda":
# use PyNCCL for send / recv
comm = PyNcclCommunicator(group, device=self.local_rank)
comm.disabled = False
send, recv = comm.send, comm.recv # type: ignore
else:
# This send / recv implementation here is NOT intended to transfer
# KV caches (and should NOT be repurposed to transfer KV caches).
# Currently it is only used to transmit control-plane messages
# for the KV lookup buffer (SimpleBuffer).
send = group.send_obj
def my_recv(x, src):
x[...] = group.recv_obj(src)
recv = my_recv
return send, recv
def _select_device(self, device: str):
logger.info("Selecting device: %s", device)
if device == "cuda":
return torch.device(f"cuda:{self.local_rank}")
else:
return torch.device("cpu")
def _make_metadata(self, tensor: torch.Tensor | None) -> Metadata:
"""
Create the metadata as a dictionary based on the input tensor.
Args:
tensor: The input tensor or None if no tensor is provided.
Returns:
metadata: A dictionary with the following keys:
- "dtype": The data type of the tensor or None.
- "shape": The shape of the tensor or None.
"""
if tensor is None:
return {"dtype": None, "shape": None}
else:
return {"dtype": tensor.dtype, "shape": tensor.shape}
def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor:
"""
Create a buffer to receive the tensor based on the provided metadata.
Args:
metadata: A dictionary with keys "dtype" and "shape",
describing the tensor's data type and shape.
Returns:
buffer: A tensor of the specified type and shape,
allocated on `self.device`.
"""
return torch.empty(
metadata["shape"], dtype=metadata["dtype"], device=self.device
)
def _send_metadata(self, metadata: Metadata):
"""
Send the metadata dictionary to the target rank.
Args:
metadata: A dictionary with keys "dtype" and "shape".
"""
self.group.send_obj(metadata, self.target_rank_for_send)
def _recv_metadata(self) -> Metadata:
"""
Receive the metadata dictionary from the target rank.
Returns:
metadata: A dictionary with keys "dtype" and "shape"
describing the tensor.
"""
return self.group.recv_obj(self.target_rank_for_recv)
def _send_impl(self, tensor: torch.Tensor | None) -> None:
"""
The actual implementation of sending the tensor and its metadata to the
target rank.
Args:
tensor: The input tensor to be sent, or `None` if no tensor is
being sent.
"""
metadata = self._make_metadata(tensor)
self._send_metadata(metadata)
if tensor is not None:
self.device_send_func(tensor.to(self.device), self.target_rank_for_send)
def _recv_impl(self) -> torch.Tensor | None:
"""
The actual implementation of receiving a tensor and its metadata from
the target rank.
Returns:
buffer: The received tensor, or `None` if no tensor is received.
"""
metadata = self._recv_metadata()
if metadata["dtype"] is None:
return None
buffer = self._prepare_recv_buffer(metadata)
self.device_recv_func(buffer, self.target_rank_for_recv)
return buffer
def send_tensor_wrapper(
self, tensor: torch.Tensor | None, tensor_size: int
) -> None:
"""
Wrapper for _send_impl to handle exceptions and update buffer size.
"""
try:
self._send_impl(tensor)
with self.buffer_size_lock:
self.buffer_size -= tensor_size
except Exception as e:
logger.error(
"[rank%d]: Exception when trying to send %s, msg: %s",
torch.distributed.get_rank(),
str(tensor),
str(e),
)
import traceback
traceback.print_exc()
def block_if_full(self):
"""
Block the current thread if the buffer size is larger than the
threshold.
"""
while self.buffer_size > self.buffer_size_thresh:
logger.debug("KV cache transfer pipe is full. Waiting...")
time.sleep(0.05)
def send_tensor(self, tensor: torch.Tensor | None) -> None:
"""
Sends a tensor and its metadata to the destination rank in a
non-blocking way.
Args:
tensor: The tensor to send, or `None` if no tensor is being sent.
"""
if self.transport_thread is None:
self.transport_thread = ThreadPoolExecutor(max_workers=1)
if tensor is not None:
tensor_size = tensor.element_size() * tensor.numel()
else:
tensor_size = 0
self.block_if_full()
with self.buffer_size_lock:
self.buffer_size += tensor_size
self.transport_thread.submit(self.send_tensor_wrapper, tensor, tensor_size)
def recv_tensor(self) -> torch.Tensor | None:
"""
Receives a tensor and its metadata from the source rank. Blocking call.
Returns:
The received tensor, or `None` if no tensor is received.
"""
if self.transport_thread is None:
self.transport_thread = ThreadPoolExecutor(max_workers=1)
future = self.transport_thread.submit(self._recv_impl)
try:
tensor = future.result()
except Exception as e:
logger.error("Encountering exception in KV receiving thread")
logger.error("%s", e)
logger.error("My device: %s", self.device)
import traceback
traceback.print_exc()
raise e
return tensor
def close(self):
"""
Close the pipe and release associated resources.
"""
if hasattr(self, "transport_thread") and self.transport_thread is not None:
self.transport_thread.shutdown()
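
To close, a condensed two-rank sketch of how the removed `PyNcclPipe` was exercised. It is modeled directly on the removed test_send_recv.py above (same config values and init_method) and, like that test, is launched as two processes with RANK=0 and RANK=1.

import os

import torch

from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe  # module deleted above

rank = int(os.environ["RANK"])  # run two copies: RANK=0 and RANK=1
torch.distributed.init_process_group(
    backend="gloo", init_method="tcp://localhost:12398", world_size=2, rank=rank
)
config = KVTransferConfig(
    kv_connector="P2pNcclConnector",
    kv_buffer_device="cuda",
    kv_buffer_size=1e9,
    kv_rank=rank,
    kv_role="kv_both",
    kv_parallel_size=2,
    kv_ip="127.0.0.1",
    kv_port=12345,
)
pipe = PyNcclPipe(local_rank=rank, config=config)
if rank == 0:
    pipe.send_tensor(torch.arange(4).to(pipe.device))  # non-blocking: queued to a worker thread
    pipe.send_tensor(None)  # None travels as metadata only
else:
    print(pipe.recv_tensor())  # blocking: waits for metadata, then the payload
    assert pipe.recv_tensor() is None
pipe.close()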