[KVConnector] Remove v0-related kv connector components such as kv pipe and kv lookup buffer (#29705)
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
@@ -1,160 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
import random

import torch
from tqdm import tqdm

from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import SimpleBuffer
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe

# TODO: the test depends on a lot of fields in the current implementation.
# We should have a standard interface instead of direct field access.


def test_run(my_rank, buffer, device):
    # buffer should be empty in the beginning
    if my_rank == 0:
        assert buffer.buffer_size == 0
        assert len(buffer.buffer) == 0

    print(f"My rank: {my_rank}, device: {device}")

    # insert
    tokens = torch.tensor([1, 2, 3]).to(device)
    roi = tokens > 0
    if my_rank == 0:
        key = 2.0 * torch.ones([5, 6]).to(device)
        value = 3.0 * torch.ones([5, 6]).to(device)

        placeholder = torch.tensor([1]).to(device)

        buffer.insert(tokens, roi, key, value, placeholder)

    torch.distributed.barrier()

    # drop_select
    if my_rank == 1:
        tok, roi_, key, value, hidden = buffer.drop_select(tokens, roi)
        assert torch.allclose(tokens, tok)
        assert torch.allclose(roi, roi_)
        assert torch.allclose(key, 2.0 * torch.ones([5, 6], device=device))
        assert torch.allclose(value, 3.0 * torch.ones([5, 6], device=device))
    torch.distributed.barrier()

    if my_rank == 0:
        assert buffer.buffer_size == 0
        assert len(buffer.buffer) == 0

    print(f"My rank: {my_rank}, Test run passed!")


def stress_test(my_rank, buf, device):
    torch.distributed.barrier()
    torch.manual_seed(100)

    reqs = [
        (
            torch.rand(100).to(device),  # tokens
            torch.ones(100).bool().to(device),  # roi
            torch.rand(100).to(device),  # key
            torch.rand(100).to(device),  # value
            torch.rand(100).to(device),  # hidden
        )
        for i in tqdm(range(200))
    ]

    random.seed(my_rank)
    random.shuffle(reqs)

    torch.distributed.barrier()

    n = 0

    # the buffer can only hold about 100 requests at a time,
    # so the sender will occasionally block to wait for the receiver.
    for req in tqdm(reqs):
        if my_rank == 0:
            buf.insert(*req)
        else:
            tok, roi, k, v, h = req
            tok_, roi_, k_, v_, h_ = buf.drop_select(tok, roi)

            if tok_ is None:
                assert roi_ is None
                assert k_ is None
                assert v_ is None
                assert h_ is None
                n += 1
            else:
                assert torch.allclose(tok, tok_)
                assert torch.allclose(roi, roi_)
                assert torch.allclose(k, k_)
                assert torch.allclose(v, v_)
                assert torch.allclose(h, h_)
    print(f"Rank {my_rank} done")
    torch.distributed.barrier()

    if my_rank == 0:
        x = torch.tensor([0])
        torch.distributed.recv(x, 1)
        # the number of Nones received equals the KV entries that were never
        # drop-selected and therefore remain in the buffer
        assert x.item() == len(buf.buffer)
        # each buffered request occupies 1700 bytes, so the buffer size
        # should be 1700 * (number of remaining entries)
        print(buf.buffer_size)
        assert buf.buffer_size == 1700 * len(buf.buffer)
    else:
        torch.distributed.send(torch.tensor([n]), 0)

    print(f"My rank: {my_rank}, Passed stress test!")


if __name__ == "__main__":
    my_rank = int(os.environ["RANK"])

    torch.distributed.init_process_group(
        backend="gloo",
        init_method="tcp://localhost:12398",
        world_size=2,
        rank=my_rank,
    )

    print(f"initialized! My rank is {my_rank}")

    config = KVTransferConfig(
        kv_connector="P2pNcclConnector",
        kv_buffer_device="cuda",
        kv_buffer_size=1e9,
        kv_rank=my_rank,
        kv_role="kv_both",  # this arg doesn't matter in this test
        kv_parallel_size=2,
        kv_ip="127.0.0.1",
        kv_port=12345,
    )

    data_pipe = PyNcclPipe(
        local_rank=my_rank,
        config=config,
        device="cuda",
        port_offset=0,
    )
    cpu_pipe = PyNcclPipe(
        local_rank=my_rank,
        config=config,
        device="cpu",
        port_offset=1,
    )

    buffer = SimpleBuffer(cpu_pipe, data_pipe, 170000)

    test_run(my_rank, buffer, data_pipe.device)

    stress_test(my_rank, buffer, data_pipe.device)

    buffer.close()
    data_pipe.close()
    cpu_pipe.close()
    print("Done")
@@ -1,8 +0,0 @@
#!/bin/bash
RANK=0 python3 test_lookup_buffer.py &
PID0=$!
RANK=1 python3 test_lookup_buffer.py &
PID1=$!

wait $PID0
wait $PID1
@@ -1,62 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import subprocess
import sys

import pytest
import torch


def run_python_script(script_name, timeout):
    script_name = f"kv_transfer/{script_name}"
    try:
        # Start both processes asynchronously using Popen
        process0 = subprocess.Popen(
            [sys.executable, script_name],
            env={"RANK": "0"},  # Set the RANK environment variable for process 0
            stdout=sys.stdout,  # Pipe stdout to current stdout
            stderr=sys.stderr,  # Pipe stderr to current stderr
        )

        process1 = subprocess.Popen(
            [sys.executable, script_name],
            env={"RANK": "1"},  # Set the RANK environment variable for process 1
            stdout=sys.stdout,  # Pipe stdout to current stdout
            stderr=sys.stderr,  # Pipe stderr to current stderr
        )

        # Wait for both processes to complete, with a timeout
        process0.wait(timeout=timeout)
        process1.wait(timeout=timeout)

        # Check the return status of both processes
        if process0.returncode != 0:
            pytest.fail(f"Test {script_name} failed for RANK=0, {process0.returncode}")
        if process1.returncode != 0:
            pytest.fail(f"Test {script_name} failed for RANK=1, {process1.returncode}")

    except subprocess.TimeoutExpired:
        # If either process times out, terminate both and fail the test
        process0.terminate()
        process1.terminate()
        pytest.fail(f"Test {script_name} timed out")
    except Exception as e:
        pytest.fail(f"Test {script_name} failed with error: {str(e)}")


# Define the test cases using pytest's parametrize
@pytest.mark.parametrize(
    "script_name,timeout",
    [
        ("test_lookup_buffer.py", 60),  # lookup buffer test with a 60-second timeout
        ("test_send_recv.py", 120),  # send/recv test with a 120-second timeout
    ],
)
def test_run_python_script(script_name, timeout):
    # Check the number of GPUs
    if torch.cuda.device_count() < 2:
        pytest.skip(f"Skipping test {script_name} because <2 GPUs are available")

    # Run the test if there are at least 2 GPUs
    run_python_script(script_name, timeout)
@@ -1,154 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
import time

import torch
from tqdm import tqdm

from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe


def test_run(my_rank, pipe):
    print(f"rank {my_rank} test_run starts....")
    # test run
    x = torch.tensor([1]).to(pipe.device)
    y = torch.tensor([[2.0, 3.0, 4.0, 8.0]]).to(pipe.device)
    if my_rank == 0:
        pipe.send_tensor(x)
        print(f"rank {my_rank} sent tensor x")
        pipe.send_tensor(y)
        print(f"rank {my_rank} sent tensor y")
        x2 = pipe.recv_tensor()
        print(f"rank {my_rank} received x2 = ", x2)
        y2 = pipe.recv_tensor()
        print(f"rank {my_rank} received y2 = ", y2)

    else:
        x2 = pipe.recv_tensor()
        print(f"rank {my_rank} received x2 = ", x2)
        y2 = pipe.recv_tensor()
        print(f"rank {my_rank} received y2 = ", y2)
        pipe.send_tensor(x)
        print(f"rank {my_rank} sent tensor x")
        pipe.send_tensor(y)
        print(f"rank {my_rank} sent tensor y")

    assert torch.allclose(x, x2)
    assert torch.allclose(y, y2)

    print(f"rank {my_rank} test_run passed!")


def stress_test(my_rank, pipe):
    print(f"rank {my_rank} stress_test starts....")

    tensors: list[torch.Tensor] = []

    torch.distributed.barrier()
    torch.manual_seed(0)

    for i in tqdm(range(500)):
        mean = torch.rand(1).item() * 100
        std = torch.rand(1).item() * 100
        size = torch.randint(900, 1000, (2,))
        x = torch.normal(mean * 1.0, std * 1.0, size=size.tolist()).to(pipe.device)

        # 5% probability of sending a None
        if torch.rand(1).item() < 0.05:
            tensors.append(None)
            tensors.append(None)
            tensors.append(None)
        else:
            tensors.append(x)
            tensors.append(x.mean().unsqueeze(0))
            tensors.append(x.std().unsqueeze(0))

    torch.distributed.barrier()

    for i in tqdm(range(500)):
        if my_rank == int((i % 10) > 3):
            pipe.send_tensor(tensors[3 * i])
            pipe.send_tensor(tensors[3 * i + 1])
            pipe.send_tensor(tensors[3 * i + 2])
        else:
            x = pipe.recv_tensor()
            mean = pipe.recv_tensor()
            std = pipe.recv_tensor()

            if x is None:
                assert mean is None
                assert std is None
            else:
                assert torch.allclose(x, tensors[3 * i])
                assert x.mean() == mean[0]
                assert x.std() == std[0]

    torch.distributed.barrier()


def latency_test(my_rank, pipe, nelement, ntensor):
    latencies = []

    torch.distributed.barrier()

    for i in tqdm(range(500)):
        tensors = []

        if my_rank == 0:
            # create tensor
            tensors = [torch.rand(nelement).to(pipe.device) for _ in range(ntensor)]

        torch.distributed.barrier()

        if my_rank == 0:
            t = torch.tensor([time.time()], dtype=torch.float64).to(pipe.device)
            for tensor in tensors:
                pipe.send_tensor(tensor)
            pipe.send_tensor(t)
        else:
            for _ in range(ntensor):
                pipe.recv_tensor()
            t = pipe.recv_tensor()
            latencies.append(time.time() - t.item())

    torch.distributed.barrier()

    print("Latency test passed.")
    print("Latency:", torch.tensor(latencies).mean().item() * 1000, "ms")


if __name__ == "__main__":
    my_rank = int(os.environ["RANK"])

    torch.distributed.init_process_group(
        backend="gloo",
        init_method="tcp://localhost:12398",
        world_size=2,
        rank=my_rank,
    )

    config = KVTransferConfig(
        kv_connector="P2pNcclConnector",
        kv_buffer_device="cuda",
        kv_buffer_size=1e9,
        kv_rank=my_rank,
        kv_role="kv_both",  # this arg doesn't matter in this test
        kv_parallel_size=2,
        kv_ip="127.0.0.1",
        kv_port=12345,
    )

    pipe = PyNcclPipe(
        local_rank=my_rank,
        config=config,
    )

    test_run(my_rank, pipe)

    stress_test(my_rank, pipe)

    # Use this function if you want to test the latency of pipe impl.
    # latency_test(my_rank, pipe, 1024 * 8 * 128, 80)
@@ -1,9 +0,0 @@
#!/bin/bash

RANK=0 python3 test_send_recv.py &
PID0=$!
RANK=1 python3 test_send_recv.py &
PID1=$!

wait $PID0
wait $PID1
@@ -1,179 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file contains a new class `KVLookupBufferBase` that allows developers to
think of KV cache operations as inserting new KV cache entries (`insert`)
into the lookup buffer and querying existing KV caches (`drop_select`)
from the lookup buffer.

This file also contains a new class `KVStoreBufferBase` that allows developers
to manage the KVCache buffer as a simple key-value storage buffer with basic
put/get operations.

Both classes are abstracted behind the common base class `KVCacheBufferBase`.
"""

from abc import ABC, abstractmethod

import torch


class KVCacheBufferBase(ABC):
    """
    Abstract base class for a KVCache buffer.
    """

    @abstractmethod
    def close(self) -> None:
        """Close the buffer and release resources.

        This method is responsible for cleaning up resources related to the
        KVCache buffer when it is no longer needed.

        Raises:
            NotImplementedError: This method must be implemented in subclasses.
        """
        raise NotImplementedError


class KVLookupBufferBase(KVCacheBufferBase):
    """
    Abstract base class for a KVCache lookup buffer.

    This class provides an abstraction for a key-value (KV) cache lookup buffer.

    The key of the lookup buffer:
    - input_tokens: token IDs of the request
    - roi: a binary mask on top of input_tokens.
      - Purpose of roi: Since KV cache may only be available for a subset of
        tokens in the input (for example, when vLLM is connected to an external
        KV cache service), roi specifies the subset of tokens that the KV cache
        is associated with.
      - NOTE: roi can be further extended to describe which part of KV the
        current process is holding (each process may only hold a part of KV
        due to TP and PP). This is not implemented for now.

    The value of the lookup buffer:
    - key: the key tensor in the KV cache
    - value: the value tensor in the KV cache
    - hidden: the final hidden state generated by model forwarding. This allows
      vLLM to bypass further model forwarding by transmitting the hidden state.
    """

    @abstractmethod
    def insert(
        self,
        input_tokens: torch.Tensor,
        roi: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        hidden: torch.Tensor,
    ) -> None:
        """Insert into the lookup buffer.

        The functionality is similar to the following python statement
        ```
        buffer[input_tokens, roi] = [key, value, hidden]
        ```

        FIXME: in the future, we should only have two arguments, key and value,
        where key is a tensor dict and value is a tensor dict.

        FIXME: we should transmit both sampler outputs and the hidden states.

        Args:
            input_tokens (torch.Tensor): token IDs.
            roi (torch.Tensor): A binary mask on top of the input tokens.
            key (torch.Tensor): The key tensor in the KV cache.
            value (torch.Tensor): The value tensor in the KV cache.
            hidden (torch.Tensor): The final hidden state tensor generated
                                   during model forwarding, transmitted so the
                                   receiver can bypass model forwarding.

        Raises:
            NotImplementedError: This method must be implemented in subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def drop_select(
        self, input_tokens: torch.Tensor | None, roi: torch.Tensor | None
    ) -> list[torch.Tensor | None]:
        """Select and *drop* KV cache entries from the lookup buffer.

        The functionality is similar to the following python statements
        ```
        ret = buffer.pop(input_tokens, roi)
        return ret
        ```

        If `input_tokens` and `roi` are `None`, any of the KV caches in the
        buffer may be selected, returned, and removed from the buffer. This is
        useful when offloading KV cache to a KV cache storage service.

        Args:
            input_tokens (torch.Tensor): token IDs.
            roi (torch.Tensor): A binary mask on top of the input tokens.

        Returns:
            list[Optional[torch.Tensor]]: A list of tensors. Can be None.

        Raises:
            NotImplementedError: This method must be implemented in subclasses.
        """
        raise NotImplementedError


class KVStoreBufferBase(KVCacheBufferBase):
    """
    Abstract base class for a KVCache storage buffer with key-value semantics.
    This class provides a simple key-value storage buffer abstraction with
    basic put/get operations, which enables fine-grained control over KVCache
    transfer.

    The functionality is similar to a distributed key-value store, where:
    - Key: A unique string identifier for the cached entry
    - Value:
        - Tensor to be stored and retrieved
        - None (indicating deletion or empty value)
    """

    @abstractmethod
    def put(
        self,
        key: str,
        value: torch.Tensor | None,
    ) -> None:
        """Store a key-value pair in the buffer.

        Args:
            key (str): Unique identifier for a tensor; this tensor could be the
                key cache tensor, value cache tensor, or hidden state tensor
                generated during model forwarding.

            value (Optional[torch.Tensor]): Tensor to be stored.

        Raises:
            NotImplementedError: This method must be implemented in subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def get(
        self,
        key: str,
    ) -> torch.Tensor | None:
        """Retrieve a value from the buffer by key.

        Args:
            key (str): Unique identifier for a tensor; this tensor could be the
                key cache tensor, value cache tensor, or hidden state tensor
                generated during model forwarding.

        Returns:
            Optional[torch.Tensor]: Stored tensor if exists, None otherwise.

        Raises:
            NotImplementedError: This method must be implemented in subclasses.
        """
        raise NotImplementedError
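For intuition about the interface removed above, the following is a minimal, single-process sketch of a KVLookupBufferBase subclass. It is illustrative only and is not the removed SimpleBuffer: the InMemoryLookupBuffer name, the list-based storage, and the exact-match lookup are assumptions made for the example.

import torch

from vllm.distributed.kv_transfer.kv_lookup_buffer.base import KVLookupBufferBase


class InMemoryLookupBuffer(KVLookupBufferBase):
    """Toy in-memory buffer; real implementations transfer entries over pipes."""

    def __init__(self):
        self.entries: list[list[torch.Tensor]] = []

    def insert(self, input_tokens, roi, key, value, hidden) -> None:
        # Semantically: buffer[input_tokens, roi] = [key, value, hidden]
        self.entries.append([input_tokens, roi, key, value, hidden])

    def drop_select(self, input_tokens, roi):
        for i, (tok, roi_, key, value, hidden) in enumerate(self.entries):
            # None means "select any entry"; otherwise match the masked tokens.
            if input_tokens is None or torch.equal(tok[roi_], input_tokens[roi]):
                return self.entries.pop(i)
        return [None, None, None, None, None]

    def close(self) -> None:
        self.entries.clear()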
@@ -1,164 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file contains a new class `MooncakeStore` that allows developers to
think of KV cache transfer operations as putting new KV cache entries
into a remote KVStore-based lookup buffer and getting existing KV caches
from this remote lookup buffer.
"""

import json
import os
from dataclasses import dataclass

import torch
from safetensors.torch import load as safetensors_load
from safetensors.torch import save as safetensors_save

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_lookup_buffer.base import KVStoreBufferBase
from vllm.logger import init_logger

DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200  # 3.125 GiB
DEFAULT_LOCAL_BUFFER_SIZE = 1073741824  # 1.0 GiB

logger = init_logger(__name__)


@dataclass
class MooncakeStoreConfig:
    local_hostname: str
    metadata_server: str
    global_segment_size: int
    local_buffer_size: int
    protocol: str
    device_name: str
    master_server_address: str

    @staticmethod
    def from_file(file_path: str) -> "MooncakeStoreConfig":
        """Load the config from a JSON file."""
        with open(file_path) as fin:
            config = json.load(fin)
        return MooncakeStoreConfig(
            local_hostname=config.get("local_hostname"),
            metadata_server=config.get("metadata_server"),
            global_segment_size=config.get(
                "global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE
            ),
            local_buffer_size=config.get(
                "local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE
            ),
            protocol=config.get("protocol", "tcp"),
            device_name=config.get("device_name", ""),
            master_server_address=config.get("master_server_address"),
        )

    @staticmethod
    def load_from_env() -> "MooncakeStoreConfig":
        """Load config from a file specified in the environment variable."""
        config_file_path = os.getenv("MOONCAKE_CONFIG_PATH")
        if config_file_path is None:
            raise ValueError(
                "The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
            )
        return MooncakeStoreConfig.from_file(config_file_path)


class MooncakeStore(KVStoreBufferBase):
    def __init__(
        self,
        config: VllmConfig,
    ):
        try:
            from mooncake.store import MooncakeDistributedStore
        except ImportError as e:
            raise ImportError(
                "Please install mooncake by following the instructions at "
                "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "  # noqa: E501
                "to run vLLM with MooncakeConnector."
            ) from e

        try:
            self.store = MooncakeDistributedStore()
            self.config = MooncakeStoreConfig.load_from_env()
            logger.info("Mooncake Configuration loaded successfully.")

            self.store.setup(
                self.config.local_hostname,
                self.config.metadata_server,
                self.config.global_segment_size,
                self.config.local_buffer_size,
                self.config.protocol,
                self.config.device_name,
                self.config.master_server_address,
            )

        except ValueError as e:
            logger.error("Configuration loading failed: %s", e)
            raise
        except Exception as exc:
            logger.error("An error occurred while loading the configuration: %s", exc)
            raise

    def close(self):
        # MooncakeDistributedStore will automatically call the destructor, so
        # it is unnecessary to close it manually.
        pass

    def put(
        self,
        key: str,
        value: torch.Tensor | None,
    ) -> None:
        # A message queue needs to be introduced before making it asynchronous.
        if value is not None:
            self._put_impl(key, value)

    def get(
        self,
        key: str,
    ) -> torch.Tensor | None:
        # A message queue needs to be introduced before making it asynchronous.
        value = self._get_impl(key)
        return value

    def _put_impl(
        self,
        key: str,
        value: torch.Tensor,
    ) -> None:
        """Put KVCache to Mooncake Store"""
        device_id = value.device.index if value.device.type == "cuda" else -1
        device_tensor = torch.tensor(device_id, dtype=torch.int32)
        value_bytes = safetensors_save({"tensor": value, "device_id": device_tensor})
        try:
            self.store.put(key, value_bytes)
        except TypeError as err:
            logger.error("Failed to put value into Mooncake Store: %s", err)
            raise TypeError("Mooncake Store Put Type Error.") from err

    def _get_impl(
        self,
        key: str,
    ) -> torch.Tensor | None:
        """Get KVCache from Mooncake Store"""
        try:
            data = self.store.get(key)
        except TypeError as err:
            logger.error("Failed to get value from Mooncake Store: %s", err)
            raise TypeError("Mooncake Store Get Type Error.") from err

        if data:
            loaded_tensors = safetensors_load(data)
            tensor = loaded_tensors["tensor"]
            device_id_tensor = loaded_tensors["device_id"]
            device_id = int(device_id_tensor.item())
            device = (
                torch.device("cuda", device_id)
                if device_id >= 0
                else torch.device("cpu")
            )
            return tensor.to(device)

        return None
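MooncakeStoreConfig.load_from_env reads a JSON file named by MOONCAKE_CONFIG_PATH with exactly the fields of the dataclass above. The sketch below writes a hypothetical config from Python; only the key names come from the code, while the hostnames, addresses, and sizes are placeholder assumptions rather than values required by Mooncake.

import json

# Placeholder values; only the field names mirror MooncakeStoreConfig.
example_config = {
    "local_hostname": "192.168.0.137",
    "metadata_server": "192.168.0.137:2379",
    "global_segment_size": 3355443200,  # optional, defaults to DEFAULT_GLOBAL_SEGMENT_SIZE
    "local_buffer_size": 1073741824,  # optional, defaults to DEFAULT_LOCAL_BUFFER_SIZE
    "protocol": "tcp",  # optional, defaults to "tcp"
    "device_name": "",  # optional, defaults to ""
    "master_server_address": "192.168.0.137:50051",
}

with open("mooncake_store_config.json", "w") as f:
    json.dump(example_config, f, indent=2)
# Then point the store at it: export MOONCAKE_CONFIG_PATH=mooncake_store_config.json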
@@ -1,242 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Implements a distributed key-value (KV) cache transfer mechanism.

Key Features:
- Distributed KV cache transmission using PyNccl pipes.
- Non-blocking `insert`, blocking `drop_select`.
- Uses a CPU signal pipe to avoid race conditions.
- Handles buffer size constraints and provides a backpressure mechanism to
  stop the prefill instance when the decode instance is slow.
"""

import threading
from collections import deque

import torch

from vllm.distributed.kv_transfer.kv_lookup_buffer.base import KVLookupBufferBase
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.logger import init_logger

logger = init_logger(__name__)


class SimpleBuffer(KVLookupBufferBase):
    def __init__(
        self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, buffer_size_thresh: float
    ):
        """
        signal_pipe: on CPU

        NOTE: on-device recv will block all threads in the process, making the
        KV cache producer unable to listen to new requests while transmitting
        KV cache. Luckily CPU recv only blocks the current thread, so we use
        CPU recv to listen for new requests.

        data_pipe: on device (e.g. GPU)
        """

        self.buffer: deque[list[torch.Tensor]] = deque()

        self.buffer_size = 0
        self.buffer_size_threshold = buffer_size_thresh
        self.buffer_cv = threading.Condition()
        self.signal_pipe = signal_pipe
        self.data_pipe = data_pipe
        self.request_handling_thread: threading.Thread | None = None

        self.normal_signal = torch.tensor([0], device="cpu")
        self.end_signal = None

    def _matches(
        self,
        tokens_roi_sender: list[torch.Tensor],
        tokens_roi_recver: list[torch.Tensor],
    ):
        # tokens_roi_sender: tokens and roi of the producer (in the buffer)
        # tokens_roi_recver: tokens and roi of the consumer (query)

        tokens_sender = tokens_roi_sender[0]
        tokens_recver = tokens_roi_recver[0]
        roi_sender = tokens_roi_sender[1]
        roi_recver = tokens_roi_recver[1]

        if tokens_recver is None:
            # consumer sends an empty request
            # semantics: DROP SELECT * LIMIT 1
            # so any of the data in the buffer can be drop-selected
            return True

        # Assuming that roi is a binary mask on tokens
        tokens_sender = tokens_sender[roi_sender]
        tokens_recver = tokens_recver[roi_recver]

        # simple common prefix matching
        min_length = min(len(tokens_sender), len(tokens_recver))
        if torch.allclose(tokens_sender[:min_length], tokens_recver[:min_length]):
            return min_length

        return 0

    def _send_tensor_and_dec_size(self, tensor: torch.Tensor | None) -> None:
        assert tensor is not None, "Use self.data_pipe.send(None) instead"
        self.buffer_size -= tensor.element_size() * tensor.numel()
        if tensor.dtype == torch.bool:
            tensor = tensor.float()
        self.data_pipe.send_tensor(tensor)

    def _get_element_size(self, data: list | torch.Tensor | None):
        if isinstance(data, torch.Tensor):
            return data.element_size() * data.numel()
        if not data:
            # cannot perform `not data` on a tensor
            # so this check needs to go after the check above
            return 0

        raise AssertionError(f"Unknown data type {type(data)}")

    def _add_to_buffer(
        self,
        input_tokens: torch.Tensor,
        roi: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        hidden: torch.Tensor,
    ):
        if isinstance(input_tokens, torch.Tensor):
            input_tokens = input_tokens.clone()
        if isinstance(roi, torch.Tensor):
            roi = roi.clone()
        if isinstance(key, torch.Tensor):
            key = key.clone()
        if isinstance(value, torch.Tensor):
            value = value.clone()
        if isinstance(hidden, torch.Tensor):
            hidden = hidden.clone()

        buffer_item = [input_tokens, roi, key, value, hidden]
        data_size = sum([self._get_element_size(data) for data in buffer_item])

        with self.buffer_cv:
            if self.buffer_size + data_size > self.buffer_size_threshold:
                # log outside the while loop to avoid this message being logged
                # repeatedly.
                logger.debug("KV transfer buffer is full. Handling...")
                while self.buffer_size + data_size > self.buffer_size_threshold:
                    self.buffer_cv.wait()

            self.buffer_size += data_size
            self.buffer.append(buffer_item)
            self.buffer_cv.notify()

    def _is_end_signal(self, signal):
        return signal is None

    def drop_select_handler(self):
        try:
            while True:
                signal = self.signal_pipe.recv_tensor()
                if self._is_end_signal(signal):
                    logger.info("Received end signal!")
                    break

                input_tokens = self.data_pipe.recv_tensor()

                roi = self.data_pipe.recv_tensor()
                assert roi is not None, (
                    "Please provide the roi when sending drop-select request"
                )
                roi = roi > 0.5
                tokens_roi_recver = [input_tokens, roi]

                def is_buffer_available(
                    tokens_roi_recver: list[torch.Tensor],
                ) -> bool:
                    # perform input tokens and roi matching
                    # FIXME: this matching is O(n), ideally it should be O(1)
                    # but this buffer size won't (and shouldn't) be too large so
                    # the fix is not urgent.
                    for _ in range(len(self.buffer)):
                        if self._matches(self.buffer[0], tokens_roi_recver) > 0:
                            return True
                        # rotate the element we just accessed to the end
                        self.buffer.rotate(-1)
                    return False

                with self.buffer_cv:
                    while not is_buffer_available(tokens_roi_recver):
                        logger.debug("KV transfer buffer is not available. Waiting...")
                        self.buffer_cv.wait()
                    # need to clone the tensor
                    # in case the tensor is freed before sending finishes
                    matched_item = self.buffer.popleft()
                    for tensor in matched_item:
                        self._send_tensor_and_dec_size(tensor)
                    self.buffer_cv.notify()

        except RuntimeError as e:
            if "Connection closed by peer" not in str(e):
                raise e

        logger.debug("Closing drop_select_handler")

    def drop_select(
        self, input_tokens: torch.Tensor | None, roi: torch.Tensor | None
    ) -> list[torch.Tensor | None]:
        assert self.request_handling_thread is None, (
            "drop_select should be called by the KV cache consumer "
            "(e.g. the decode vLLM instance)"
        )

        if isinstance(input_tokens, torch.Tensor):
            input_tokens = input_tokens.clone()
        if isinstance(roi, torch.Tensor):
            roi = roi.clone().float()

        self.signal_pipe.send_tensor(self.normal_signal)
        self.data_pipe.send_tensor(input_tokens)
        self.data_pipe.send_tensor(roi)

        input_tokens = self.data_pipe.recv_tensor()
        roi = self.data_pipe.recv_tensor()
        if roi is not None:
            # convert from float tensor to bool tensor
            # as PyNccl does not support sending bool tensor
            roi = roi > 0.5
        key = self.data_pipe.recv_tensor()
        value = self.data_pipe.recv_tensor()
        hidden = self.data_pipe.recv_tensor()

        return [input_tokens, roi, key, value, hidden]

    def insert(
        self,
        input_tokens: torch.Tensor,
        roi: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        hidden: torch.Tensor,
    ) -> None:
        self._add_to_buffer(input_tokens, roi, key, value, hidden)

        # when insert is called, the current process is a sender, so we need
        # to launch the request handler and start listening for requests.
        if self.request_handling_thread is None:
            self.request_handling_thread = threading.Thread(
                target=self.drop_select_handler
            )
            self.request_handling_thread.start()

    def close(self):
        if (
            hasattr(self, "request_handling_thread")
            and self.request_handling_thread is not None
        ):
            self.request_handling_thread.join()

        else:
            # TODO: have an explicit close signal and an explicit way to
            # check whether this process is the requester
            self.signal_pipe.send_tensor(self.end_signal)
@@ -1,66 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file defines an interface `KVPipeBase`
that provides an abstraction for sending and receiving tensors, or None, via
distributed communications.

All classes instantiated from this interface are assumed to be a FIFO pipe.

If your distributed communication platform already supports key-value lookup,
you can bypass this interface and directly start from `kv_lookup_buffer`.
"""

from abc import ABC, abstractmethod

import torch


class KVPipeBase(ABC):
    """
    This class provides an interface for sending and receiving tensors, or
    None, by distributed communications.
    """

    @abstractmethod
    def send_tensor(self, tensor: torch.Tensor | None) -> None:
        """Send a tensor, or None, via the pipe.

        Need to support sending None -- important for error handling.

        TODO: add a `key` argument so that we can use a traditional
        key-value database as the distributed communication mechanism behind
        the pipe.

        Args:
            tensor (Optional[torch.Tensor]): The tensor to be sent. Can be None.

        Raises:
            NotImplementedError: This method must be implemented in subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def recv_tensor(self) -> torch.Tensor | None:
        """Receive a tensor (can be None) from the pipeline.

        Returns:
            Optional[torch.Tensor]: The tensor received from the pipeline. Can
            be None.

        Raises:
            NotImplementedError: This method must be implemented in subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def close(self) -> None:
        """Close the pipeline and release resources.

        This method is responsible for closing the communication pipeline
        and releasing any resources associated with it.

        Raises:
            NotImplementedError: This method must be implemented in subclasses.
        """
        raise NotImplementedError
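To make the FIFO contract of the removed KVPipeBase concrete, here is a toy in-process implementation. It is a sketch made up for illustration (the QueuePipe name and the two queue.Queue objects are assumptions); real implementations such as PyNcclPipe and MooncakePipe transfer tensors across processes.

import queue

import torch

from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase


class QueuePipe(KVPipeBase):
    """Toy FIFO pipe between two endpoints living in the same process."""

    def __init__(self, send_q: queue.Queue, recv_q: queue.Queue):
        self.send_q = send_q
        self.recv_q = recv_q

    def send_tensor(self, tensor: torch.Tensor | None) -> None:
        # None must be transmissible, matching the interface contract.
        self.send_q.put(tensor)

    def recv_tensor(self) -> torch.Tensor | None:
        return self.recv_q.get()

    def close(self) -> None:
        pass


# The two endpoints share the same queues, crossed over.
a_to_b, b_to_a = queue.Queue(), queue.Queue()
end_a, end_b = QueuePipe(a_to_b, b_to_a), QueuePipe(b_to_a, a_to_b)
end_a.send_tensor(torch.ones(3))
assert torch.equal(end_b.recv_tensor(), torch.ones(3))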
@@ -1,295 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
import os
import struct
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass

import torch
import zmq
from safetensors.torch import load as safetensors_load
from safetensors.torch import save as safetensors_save

from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.logger import init_logger
from vllm.utils.network_utils import join_host_port, make_zmq_path, split_host_port

logger = init_logger(__name__)
NONE_INT = -150886311


@dataclass
class MooncakeTransferEngineConfig:
    prefill_url: str
    decode_url: str
    metadata_backend: str | None
    metadata_server: str
    protocol: str
    device_name: str

    @staticmethod
    def from_file(file_path: str) -> "MooncakeTransferEngineConfig":
        """Load the config from a JSON file."""
        with open(file_path) as fin:
            config = json.load(fin)
        return MooncakeTransferEngineConfig(
            prefill_url=config.get("prefill_url"),
            decode_url=config.get("decode_url"),
            metadata_backend=config.get("metadata_backend", None),
            metadata_server=config.get("metadata_server"),
            protocol=config.get("protocol", "tcp"),
            device_name=config.get("device_name", ""),
        )

    @staticmethod
    def load_from_env() -> "MooncakeTransferEngineConfig":
        """Load config from a file specified in the environment variable."""
        config_file_path = os.getenv("MOONCAKE_CONFIG_PATH")
        if config_file_path is None:
            raise ValueError(
                "The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
            )
        return MooncakeTransferEngineConfig.from_file(config_file_path)


class MooncakeTransferEngine:
    """Handles the transfer of data using mooncake_vllm_adaptor and ZeroMQ."""

    def __init__(self, kv_rank: int, local_rank: int):
        try:
            from mooncake.engine import TransferEngine
        except ImportError as e:
            raise ImportError(
                "Please install mooncake by following the instructions at "
                "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "  # noqa: E501
                "to run vLLM with MooncakeConnector."
            ) from e

        self.engine = TransferEngine()
        self.local_rank = local_rank

        try:
            self.config = MooncakeTransferEngineConfig.load_from_env()
            logger.info("Mooncake Configuration loaded successfully.")
        except ValueError as e:
            logger.error(e)
            raise
        except Exception as exc:
            logger.error("An error occurred while loading the configuration: %s", exc)
            raise
        prefill_host, base_prefill_port = split_host_port(self.config.prefill_url)
        decode_host, base_decode_port = split_host_port(self.config.decode_url)

        # Avoid ports conflict when running prefill and decode on the same node
        if prefill_host == decode_host and base_prefill_port == base_decode_port:
            base_decode_port = base_decode_port + 100

        prefill_port = base_prefill_port + self.local_rank
        decode_port = base_decode_port + self.local_rank
        self.prefill_url = join_host_port(prefill_host, prefill_port)
        self.decode_url = join_host_port(decode_host, decode_port)

        self.initialize(
            self.prefill_url if kv_rank == 0 else self.decode_url,
            self.config.metadata_server,
            self.config.protocol,
            self.config.device_name,
            self.config.metadata_backend,
        )

        self.remote_url = self.decode_url if kv_rank == 0 else self.prefill_url

        # Initialize ZeroMQ context and sockets
        self.context = zmq.Context()  # type: ignore[attr-defined]
        self.sender_socket = self.context.socket(zmq.constants.PUSH)
        self.receiver_socket = self.context.socket(zmq.constants.PULL)
        self.sender_ack = self.context.socket(zmq.constants.PULL)
        self.receiver_ack = self.context.socket(zmq.constants.PUSH)

        self.buffer_cleaner = ThreadPoolExecutor(max_workers=1)
        self._setup_metadata_sockets(
            kv_rank, prefill_host, base_prefill_port, decode_host, base_decode_port
        )

    def _setup_metadata_sockets(
        self, kv_rank: int, p_host: str, p_port: int, d_host: str, d_port: int
    ) -> None:
        """Set up ZeroMQ sockets for sending and receiving data."""
        # Offsets < 8 are left for initialization in case tp and pp are enabled
        p_rank_offset = p_port + 8 + self.local_rank * 2
        d_rank_offset = d_port + 8 + self.local_rank * 2
        if kv_rank == 0:
            self.sender_socket.bind(make_zmq_path("tcp", p_host, p_rank_offset + 1))
            self.receiver_socket.connect(
                make_zmq_path("tcp", d_host, d_rank_offset + 1)
            )
            self.sender_ack.connect(make_zmq_path("tcp", d_host, d_rank_offset + 2))
            self.receiver_ack.bind(make_zmq_path("tcp", p_host, p_rank_offset + 2))
        else:
            self.receiver_socket.connect(
                make_zmq_path("tcp", p_host, p_rank_offset + 1)
            )
            self.sender_socket.bind(make_zmq_path("tcp", d_host, d_rank_offset + 1))
            self.receiver_ack.bind(make_zmq_path("tcp", d_host, d_rank_offset + 2))
            self.sender_ack.connect(make_zmq_path("tcp", p_host, p_rank_offset + 2))

    def initialize(
        self,
        local_hostname: str,
        metadata_server: str,
        protocol: str,
        device_name: str,
        metadata_backend: str | None,
    ) -> None:
        """Initialize the mooncake instance."""
        if metadata_backend is None:
            self.engine.initialize(
                local_hostname, metadata_server, protocol, device_name
            )
        else:
            supported_backend = ["etcd", "redis"]
            metadata_backend = metadata_backend.lower()
            if metadata_backend not in supported_backend:
                raise ValueError(
                    "Mooncake Configuration error. `metadata_backend`"
                    f" should be one of {supported_backend}."
                )

            self.engine.initialize_ext(
                local_hostname, metadata_server, protocol, device_name, metadata_backend
            )

    def allocate_managed_buffer(self, length: int) -> int:
        """Allocate a managed buffer of the specified length."""
        ret = self.engine.allocate_managed_buffer(length)
        if ret <= 0:
            logger.error("Allocation Return Error")
            raise Exception("Allocation Return Error")
        return ret

    def free_managed_buffer(self, buffer: int, length: int) -> int:
        """Free a previously allocated managed buffer."""
        return self.engine.free_managed_buffer(buffer, length)

    def transfer_sync(self, buffer: int, peer_buffer_address: int, length: int) -> int:
        """Synchronously transfer data to the specified address."""
        ret = self.engine.transfer_sync_read(
            self.remote_url, buffer, peer_buffer_address, length
        )
        if ret < 0:
            logger.error("Transfer Return Error")
            raise Exception("Transfer Return Error")
        return ret

    def write_bytes_to_buffer(self, buffer: int, user_data: bytes, length: int) -> int:
        """Write bytes to the allocated buffer."""
        return self.engine.write_bytes_to_buffer(buffer, user_data, length)

    def read_bytes_from_buffer(self, buffer: int, length: int) -> bytes:
        """Read bytes from the allocated buffer."""
        return self.engine.read_bytes_from_buffer(buffer, length)

    def wait_for_ack(self, src_ptr: int, length: int) -> None:
        """Asynchronously wait for ACK from the receiver."""
        ack = self.sender_ack.recv()
        if ack != b"ACK":
            logger.error("Failed to receive ACK from the receiver")

        self.free_managed_buffer(src_ptr, length)

    def send_bytes(self, user_data: bytes) -> None:
        """Send bytes to the remote process."""
        length = len(user_data)
        src_ptr = self.allocate_managed_buffer(length)
        self.write_bytes_to_buffer(src_ptr, user_data, length)
        self.sender_socket.send_multipart(
            [struct.pack("!Q", src_ptr), struct.pack("!Q", length)]
        )
        self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length)

    def recv_bytes(self) -> bytes:
        """Receive bytes from the remote process."""
        data = self.receiver_socket.recv_multipart()
        src_ptr = struct.unpack("!Q", data[0])[0]
        length = struct.unpack("!Q", data[1])[0]
        dst_ptr = self.allocate_managed_buffer(length)
        self.transfer_sync(dst_ptr, src_ptr, length)
        ret = self.read_bytes_from_buffer(dst_ptr, length)

        # Buffer cleanup
        self.receiver_ack.send(b"ACK")
        self.free_managed_buffer(dst_ptr, length)

        return ret


class MooncakePipe(KVPipeBase):
    """MooncakeTransferEngine based Pipe implementation."""

    def __init__(
        self, local_rank: int, config: KVTransferConfig, device: str | None = None
    ):
        """Initialize the mooncake pipe and set related parameters."""
        self.config = config
        self.local_rank = local_rank
        self.kv_rank = self.config.kv_rank
        assert self.kv_rank is not None
        if device is None:
            self.device = self._select_device(self.config.kv_buffer_device)
        else:
            self.device = self._select_device(device)

        self.transfer_engine = MooncakeTransferEngine(self.kv_rank, self.local_rank)
        self.transport_thread: ThreadPoolExecutor | None = None
        self.none_tensor = torch.tensor([NONE_INT], device=self.device)

    def _select_device(self, device: str) -> torch.device:
        """Select available device (CUDA or CPU)."""
        logger.info("Selecting device: %s", device)
        if device == "cuda":
            return torch.device(f"cuda:{self.local_rank}")
        else:
            return torch.device("cpu")

    def tensor_hash(self, tensor: torch.Tensor) -> int:
        """Calculate the hash value of the tensor."""
        return hash(tensor.data_ptr())

    def _send_impl(self, tensor: torch.Tensor) -> None:
        """Implement the tensor sending logic using safetensors."""
        self.transfer_engine.send_bytes(safetensors_save({"tensor": tensor}))

    def _recv_impl(self) -> torch.Tensor:
        """Implement the tensor receiving logic using safetensors."""
        data = self.transfer_engine.recv_bytes()
        return safetensors_load(data)["tensor"].to(self.device)

    def send_tensor(self, tensor: torch.Tensor | None) -> None:
        """Send tensor to the target process."""
        if self.transport_thread is None:
            self.transport_thread = ThreadPoolExecutor(max_workers=1)
        tensor = tensor if tensor is not None else self.none_tensor
        assert len(tensor.shape) > 0
        self.transport_thread.submit(self._send_impl, tensor)

    def recv_tensor(self) -> torch.Tensor | None:
        """Receive tensor from other processes."""
        if self.transport_thread is None:
            self.transport_thread = ThreadPoolExecutor(max_workers=1)
        tensor = self.transport_thread.submit(self._recv_impl).result()
        if tensor.numel() == 1 and tensor.item() == NONE_INT:
            return None
        else:
            return tensor

    def close(self) -> None:
        """Cleanup logic when closing the pipe."""
        self.transfer_engine.sender_socket.close()
        self.transfer_engine.receiver_socket.close()
        self.transfer_engine.sender_ack.close()
        self.transfer_engine.receiver_ack.close()
        self.transfer_engine.context.term()  # Terminate the ZMQ context
        logger.info("Closed the transfer engine and cleaned up resources.")
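As with MooncakeStoreConfig, MooncakeTransferEngineConfig is loaded from the JSON file named by MOONCAKE_CONFIG_PATH. A hypothetical example follows; only the key names mirror the dataclass fields above, and every value is a placeholder assumption.

import json

# Placeholder endpoints; only the keys come from MooncakeTransferEngineConfig.
example_config = {
    "prefill_url": "192.168.0.137:13003",
    "decode_url": "192.168.0.139:13003",
    "metadata_backend": "etcd",  # optional; must be "etcd" or "redis" when set
    "metadata_server": "192.168.0.137:2379",
    "protocol": "tcp",  # optional, defaults to "tcp"
    "device_name": "",  # optional, defaults to ""
}

with open("mooncake_pipe_config.json", "w") as f:
    json.dump(example_config, f, indent=2)
# Then: export MOONCAKE_CONFIG_PATH=mooncake_pipe_config.json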
@@ -1,285 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This module implements a PyNccl pipe for sending and receiving
Optional[torch.Tensor] between distributed ranks with advanced
communication features.

Key Features:
- Supports sending and receiving tensors with metadata
- Handles both CUDA and CPU device communications
- Implements a non-blocking tensor transfer mechanism
- Manages buffer size and provides backpressure control
- Supports distributed process groups with configurable parameters
"""

import threading
import time
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor

import torch

from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger

logger = init_logger(__name__)


class BrokenPipeException(Exception):
    def __init__(self, message):
        self.message = message
        super().__init__(self.message)


Metadata = dict[str, torch.Tensor | None]


class PyNcclPipe(KVPipeBase):
    METADATA_LENGTH = 16
    MAX_TENSOR_DIMENSIONS = 14
    METADATA_DTYPE = torch.int64

    def __init__(
        self,
        local_rank: int,
        config: KVTransferConfig,
        device: str | None = None,
        port_offset: int = 0,
    ):
        self.config = config
        self.local_rank = local_rank
        self.kv_rank = self.config.kv_rank
        assert self.kv_rank is not None
        self.kv_parallel_size = self.config.kv_parallel_size
        if device is None:
            self.device = self._select_device(self.config.kv_buffer_device)
        else:
            self.device = self._select_device(device)

        # build distributed connection and send/recv implementation
        store_timeout = self.config.get_from_extra_config("store_timeout", 300)
        self.group = StatelessProcessGroup.create(
            host=self.config.kv_ip,
            port=self.config.kv_port + port_offset,
            rank=self.kv_rank,
            world_size=self.kv_parallel_size,
            store_timeout=store_timeout,
        )
        # add a barrier to make sure the connection is initiated properly
        self.group.barrier()
        impl = self._get_device_send_recv_impl(self.group)
        self.device_send_func, self.device_recv_func = impl
        # set target rank
        self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size
        self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size

        # transportation-related variables
        self.transport_thread: ThreadPoolExecutor | None = None
        self.buffer_size = 0
        self.buffer_size_lock = threading.Lock()
        self.buffer_size_thresh = self.config.kv_buffer_size

    def _get_device_send_recv_impl(
        self, group: StatelessProcessGroup
    ) -> tuple[
        Callable[[torch.Tensor, int], None], Callable[[torch.Tensor, int], None]
    ]:
        send: Callable[[torch.Tensor, int], None]
        recv: Callable[[torch.Tensor, int], None]
        if self.device.type == "cuda":
            # use PyNCCL for send / recv
            comm = PyNcclCommunicator(group, device=self.local_rank)
            comm.disabled = False
            send, recv = comm.send, comm.recv  # type: ignore
        else:
            # This send / recv implementation here is NOT intended to transfer
            # KV caches (and should NOT be repurposed to transfer KV caches).
            # Currently it is only used to transmit control-plane messages
            # for PyNcclBuffer.
            send = group.send_obj

            def my_recv(x, src):
                x[...] = group.recv_obj(src)

            recv = my_recv

        return send, recv

    def _select_device(self, device: str):
        logger.info("Selecting device: %s", device)
        if device == "cuda":
            return torch.device(f"cuda:{self.local_rank}")
        else:
            return torch.device("cpu")

    def _make_metadata(self, tensor: torch.Tensor | None) -> Metadata:
        """
        Create the metadata as a dictionary based on the input tensor.

        Args:
            tensor: The input tensor or None if no tensor is provided.

        Returns:
            metadata: A dictionary with the following keys:
                - "dtype": The data type of the tensor or None.
                - "shape": The shape of the tensor or None.
        """
        if tensor is None:
            return {"dtype": None, "shape": None}
        else:
            return {"dtype": tensor.dtype, "shape": tensor.shape}

    def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor:
        """
        Create a buffer to receive the tensor based on the provided metadata.

        Args:
            metadata: A dictionary with keys "dtype" and "shape",
                describing the tensor's data type and shape.

        Returns:
            buffer: A tensor of the specified type and shape,
                allocated on `self.device`.
        """
        return torch.empty(
            metadata["shape"], dtype=metadata["dtype"], device=self.device
        )

    def _send_metadata(self, metadata: Metadata):
        """
        Send the metadata dictionary to the target rank.

        Args:
            metadata: A dictionary with keys "dtype" and "shape".
        """
        self.group.send_obj(metadata, self.target_rank_for_send)

    def _recv_metadata(self) -> Metadata:
        """
        Receive the metadata dictionary from the target rank.

        Returns:
            metadata: A dictionary with keys "dtype" and "shape"
                describing the tensor.
        """
        return self.group.recv_obj(self.target_rank_for_recv)

    def _send_impl(self, tensor: torch.Tensor | None) -> None:
        """
        The actual implementation of sending the tensor and its metadata to the
        target rank.

        Args:
            tensor: The input tensor to be sent, or `None` if no tensor is
                being sent.
        """
        metadata = self._make_metadata(tensor)
        self._send_metadata(metadata)
        if tensor is not None:
            self.device_send_func(tensor.to(self.device), self.target_rank_for_send)

    def _recv_impl(self) -> torch.Tensor | None:
        """
        The actual implementation of receiving a tensor and its metadata from
        the target rank.

        Returns:
            buffer: The received tensor, or `None` if no tensor is received.
        """
        metadata = self._recv_metadata()
        if metadata["dtype"] is None:
            return None
        buffer = self._prepare_recv_buffer(metadata)
        self.device_recv_func(buffer, self.target_rank_for_recv)

        return buffer

    def send_tensor_wrapper(
        self, tensor: torch.Tensor | None, tensor_size: int
    ) -> None:
        """
        Wrapper for _send_impl to handle exceptions and update buffer size.
        """
        try:
            self._send_impl(tensor)

            with self.buffer_size_lock:
                self.buffer_size -= tensor_size
        except Exception as e:
            logger.error(
                "[rank%d]: Exception when trying to send %s, msg: %s",
                torch.distributed.get_rank(),
                str(tensor),
                str(e),
            )
            import traceback

            traceback.print_exc()

    def block_if_full(self):
        """
        Block the current thread if the buffer size is larger than the
        threshold.
        """
        while self.buffer_size > self.buffer_size_thresh:
            logger.debug("KV cache transfer pipe is full. Waiting...")
            time.sleep(0.05)

    def send_tensor(self, tensor: torch.Tensor | None) -> None:
        """
        Sends a tensor and its metadata to the destination rank in a
        non-blocking way.

        Args:
            tensor: The tensor to send, or `None` if no tensor is being sent.
        """
        if self.transport_thread is None:
            self.transport_thread = ThreadPoolExecutor(max_workers=1)

        if tensor is not None:
            tensor_size = tensor.element_size() * tensor.numel()
        else:
            tensor_size = 0

        self.block_if_full()

        with self.buffer_size_lock:
            self.buffer_size += tensor_size

        self.transport_thread.submit(self.send_tensor_wrapper, tensor, tensor_size)

    def recv_tensor(self) -> torch.Tensor | None:
        """
        Receives a tensor and its metadata from the source rank. Blocking call.

        Returns:
            The received tensor, or `None` if no tensor is received.
        """
        if self.transport_thread is None:
            self.transport_thread = ThreadPoolExecutor(max_workers=1)

        future = self.transport_thread.submit(self._recv_impl)

        try:
            tensor = future.result()
        except Exception as e:
            logger.error("Encountering exception in KV receiving thread")
            logger.error("%s", e)
            logger.error("My device: %s", self.device)
            import traceback

            traceback.print_exc()
            raise e

        return tensor

    def close(self):
        """
        Close the pipe and release associated resources.
        """
        if hasattr(self, "transport_thread") and self.transport_thread is not None:
            self.transport_thread.shutdown()