[P/D] Introduce Mooncake Transfer Engine as kv_connector (#24718)

Signed-off-by: Tianchen Ding <dtcccc@linux.alibaba.com>
Signed-off-by: dtc <dtcccc@linux.alibaba.com>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
Author: dtc
Date: 2025-12-04 17:51:36 +08:00
Committed by: GitHub
Parent: f2f4cea6cc
Commit: 842aba501d
6 changed files with 1114 additions and 125 deletions


@@ -0,0 +1,58 @@
# MooncakeConnector Usage Guide
## About Mooncake
Mooncake aims to enhance the inference efficiency of large language models (LLMs), especially in slow object storage environments, by constructing a multi-level caching pool on high-speed interconnected DRAM/SSD resources. Compared to traditional caching systems, Mooncake utilizes (GPUDirect) RDMA technology to transfer data directly in a zero-copy manner, while maximizing the use of multi-NIC resources on a single machine.
For more details about Mooncake, please refer to [Mooncake project](https://github.com/kvcache-ai/Mooncake) and [Mooncake documents](https://kvcache-ai.github.io/Mooncake/).
## Prerequisites
### Installation
Install Mooncake from PyPI: `uv pip install mooncake-transfer-engine`.
Refer to the [Mooncake official repository](https://github.com/kvcache-ai/Mooncake) for more installation instructions.
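As a quick sanity check after installation, confirm that the package is importable. This is only an illustrative snippet; the connector itself imports `TransferEngine` from `mooncake.engine` and raises an `ImportError` with installation hints if it is missing.
```bash
python -c "from mooncake.engine import TransferEngine; print('Mooncake Transfer Engine OK')"
```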
## Usage
### Prefiller Node (192.168.0.2)
```bash
vllm serve Qwen/Qwen2.5-7B-Instruct --port 8010 --kv-transfer-config '{"kv_connector":"MooncakeConnector","kv_role":"kv_producer"}'
```
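The prefiller's sender thread pool size can optionally be tuned via `kv_connector_extra_config`; the worker reads `num_workers` (default 10). An illustrative variant of the command above:
```bash
vllm serve Qwen/Qwen2.5-7B-Instruct --port 8010 \
    --kv-transfer-config '{"kv_connector":"MooncakeConnector","kv_role":"kv_producer","kv_connector_extra_config":{"num_workers":16}}'
```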
### Decoder Node (192.168.0.3)
```bash
vllm serve Qwen/Qwen2.5-7B-Instruct --port 8020 --kv-transfer-config '{"kv_connector":"MooncakeConnector","kv_role":"kv_consumer"}'
```
### Proxy
```bash
python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --prefiller-host 192.168.0.2 --prefiller-port 8010 --decoder-host 192.168.0.3 --decoder-port 8020
```
> NOTE: The Mooncake Connector currently uses the proxy from nixl_integration. This will be replaced with a self-developed proxy in the future.
Now you can send requests to the proxy server through port 8000.
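For example, assuming the toy proxy listens on its default port 8000 and forwards the standard OpenAI-compatible `/v1/completions` route, a request can be sent like this (illustrative values):
```bash
curl -s http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
          "model": "Qwen/Qwen2.5-7B-Instruct",
          "prompt": "Explain prefill/decode disaggregation in one sentence.",
          "max_tokens": 64
        }'
```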
## Environment Variables
- `VLLM_MOONCAKE_BOOTSTRAP_PORT`: Port for the Mooncake bootstrap server
    - Default: 8998
    - Required only for prefiller instances
    - Each vLLM worker needs a unique port on its host; reusing the same port number across different hosts is fine
    - For TP/DP deployments, each worker's port on a node is computed as `base_port + dp_rank * tp_size + tp_rank` (see the example after this list)
    - Also used by the decoder to notify the prefiller
- `VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT`: Timeout (in seconds) for automatically releasing the prefiller's KV cache for a particular request. (Optional)
    - Default: 480
    - If a request is aborted and the decoder has not yet notified the prefiller, the prefill instance releases its KV-cache blocks after this timeout to avoid holding them indefinitely.
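As a concrete example of the port formula above (illustrative shell arithmetic, not a vLLM command): with the default base port, TP rank 1 of DP rank 1 in a TP=2 deployment binds port 9001.
```bash
# Worker side-channel port = base_port + dp_rank * tp_size + tp_rank
BASE_PORT=8998   # VLLM_MOONCAKE_BOOTSTRAP_PORT default
DP_RANK=1; TP_SIZE=2; TP_RANK=1
echo $((BASE_PORT + DP_RANK * TP_SIZE + TP_RANK))   # prints 9001
```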
## KV Role Options
- **kv_producer**: For prefiller instances that generate KV caches
- **kv_consumer**: For decoder instances that consume KV caches from the prefiller
- **kv_both**: Enables symmetric functionality where the connector can act as both producer and consumer. This provides flexibility for experimental setups and scenarios where the role distinction is not predetermined.
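For a symmetric setup, the same instance can be started with `kv_both` so it can both produce and consume KV caches (illustrative command):
```bash
vllm serve Qwen/Qwen2.5-7B-Instruct --port 8010 \
    --kv-transfer-config '{"kv_connector":"MooncakeConnector","kv_role":"kv_both"}'
```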


@@ -190,3 +190,8 @@ KVConnectorFactory.register_connector(
"vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector",
"DecodeBenchConnector",
)
KVConnectorFactory.register_connector(
"MooncakeConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake_connector",
"MooncakeConnector",
)


@@ -4,10 +4,13 @@
KV cache helper for store.
"""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import get_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.logger import init_logger
@@ -181,3 +184,124 @@ def copy_kv_blocks(
src_tensor = src_kv_caches[layer_name]
dst_tensor = dst_kv_caches[layer_name]
copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
@dataclass
class TpKVTopology:
"""
Helper class for tensor parallel and KV topology information for
mapping between local and remote TP workers.
"""
tp_rank: int
remote_tp_size: dict[str, int]
is_mla: bool
total_num_kv_heads: int
attn_backend: type[AttentionBackend]
engine_id: str
remote_block_size: dict[str, int]
def __post_init__(self):
# Figure out whether the first dimension of the cache is K/V
# or num_blocks. This is used to register the memory regions correctly.
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
# Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D],
# we just mock num_blocks to 1 for the dimension check below.
self._is_kv_layout_blocks_first = (
len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
)
attn_backend = AttentionBackendEnum[self.attn_backend.get_name()]
self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS
@property
def is_kv_layout_blocks_first(self) -> bool:
return self._is_kv_layout_blocks_first
@property
def split_k_and_v(self) -> bool:
# Whether to register regions for K and V separately (when present).
return not (self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first)
@property
def tp_size(self) -> int:
return self.remote_tp_size[self.engine_id]
@property
def block_size(self) -> int:
return self.remote_block_size[self.engine_id]
def tp_ratio(
self,
remote_tp_size: int,
) -> int:
"""
Calculate the tensor parallel ratio between local and remote TP.
We can think of it as the number of local TP workers-per-remote TP
workers. Local workers will read from the same remote TP worker in
groups of size `tp_ratio`.
"""
assert self.tp_size % remote_tp_size == 0, (
f"Local tensor parallel size {self.tp_size} is not divisible "
f"by remote tensor parallel size {remote_tp_size}."
)
return self.tp_size // remote_tp_size
def block_size_ratio(
self,
remote_block_size: int,
) -> float:
"""
Calculate the block size ratio between local and remote TP.
"""
assert self.block_size % remote_block_size == 0, (
f"Local block size {self.block_size} is not divisible "
f"by remote block size {remote_block_size} or vice versa."
)
return self.block_size // remote_block_size
def tp_ratio_from_engine_id(
self,
remote_engine_id: str,
) -> int:
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.tp_ratio(remote_tp_size)
def block_size_ratio_from_engine_id(
self,
remote_engine_id: str,
) -> float:
remote_block_size = self.remote_block_size[remote_engine_id]
return self.block_size_ratio(remote_block_size)
def is_kv_replicated(self, engine_id: str) -> bool:
"""
Whether the KV cache is replicated across TP workers due to the
number of TP workers being greater than the number of KV heads.
"""
tp_size = self.remote_tp_size[engine_id]
return tp_size // self.total_num_kv_heads >= 1
def replicates_kv_cache(self, remote_engine_id: str) -> bool:
# MLA is always replicated as the hidden dim can't be split.
return self.is_mla or self.is_kv_replicated(remote_engine_id)
def get_target_remote_rank(
self,
remote_tp_size: int,
) -> int:
"""
Get the remote TP rank (on P) that the current local TP rank
(on D) will read from.
"""
tp_ratio = self.tp_ratio(remote_tp_size)
return self.tp_rank // tp_ratio
def get_target_remote_rank_from_engine_id(
self,
remote_engine_id: str,
) -> int:
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.get_target_remote_rank(remote_tp_size)


@@ -0,0 +1,914 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import threading
import time
import uuid
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
import msgspec
import numpy as np
import torch
import zmq
import zmq.asyncio
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group,
)
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus
try:
from mooncake.engine import TransferEngine
except ImportError as e:
raise ImportError(
"Please install mooncake by following the instructions at "
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
"to run VLLM with MooncakeTransferEngine."
) from e
if TYPE_CHECKING:
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
EngineId = str
ReqId = str
TRANS_DONE = b"trans_done"
TRANS_ERROR = b"trans_error"
logger = init_logger(__name__)
class MooncakeAgentMetadata(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
# required for @cached_property.
dict=True,
):
remote_hostname: str
remote_port: int
request_ids: list[ReqId]
kv_caches_base_addr: list[int]
block_ids: list[list[int]]
@dataclass
class RecvReqMeta:
local_block_ids: list[int]
remote_host: str
remote_port: int
@dataclass
class SendBlockMeta:
local_block_ids: list[int]
ready: threading.Event
expire_time: float = float("inf")
@dataclass
class SendReqMeta:
reqs: dict[ReqId, SendBlockMeta]
lock: threading.Lock
@dataclass
class FinishedSendReqSet:
set: set[ReqId]
lock: threading.Lock
@dataclass
class FinishedReceiveReqSet:
set: set[ReqId]
lock: asyncio.Lock
class MooncakeConnectorMetadata(KVConnectorMetadata):
def __init__(self):
self.reqs_to_recv: dict[ReqId, RecvReqMeta] = {}
self.reqs_to_send: dict[ReqId, list[int]] = {}
def add_new_req(
self,
request_id: ReqId,
local_block_ids: list[int],
kv_transfer_params: dict[str, Any],
load_remote_cache: bool = True,
):
if load_remote_cache:
self.reqs_to_recv[request_id] = RecvReqMeta(
local_block_ids=local_block_ids,
remote_host=kv_transfer_params["remote_host"],
remote_port=kv_transfer_params["remote_port"],
)
else:
self.reqs_to_send[request_id] = local_block_ids
class MooncakeConnector(KVConnectorBase_V1):
def __init__(
self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
):
super().__init__(vllm_config, role, kv_cache_config)
assert vllm_config.kv_transfer_config is not None
assert vllm_config.kv_transfer_config.engine_id is not None
self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
if role == KVConnectorRole.SCHEDULER:
self.connector_scheduler: MooncakeConnectorScheduler | None = (
MooncakeConnectorScheduler(vllm_config, self.engine_id)
)
self.connector_worker: MooncakeConnectorWorker | None = None
elif role == KVConnectorRole.WORKER:
self.connector_scheduler = None
self.connector_worker = MooncakeConnectorWorker(vllm_config, self.engine_id)
############################################################
# Scheduler Side Methods
############################################################
def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int
) -> tuple[int, bool]:
assert self.connector_scheduler is not None
return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens
)
def update_state_after_alloc(
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
assert self.connector_scheduler is not None
return self.connector_scheduler.update_state_after_alloc(
request, blocks, num_external_tokens
)
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
assert self.connector_scheduler is not None
return self.connector_scheduler.build_connector_meta(scheduler_output)
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids)
############################################################
# Worker Side Methods
############################################################
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches)
def get_finished(
self, finished_req_ids: set[str]
) -> tuple[set[str] | None, set[str] | None]:
"""Get the finished recving and sending requests."""
assert self.connector_worker is not None
return self.connector_worker.get_finished()
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
assert self.connector_worker is not None
assert isinstance(self._connector_metadata, MooncakeConnectorMetadata)
self.connector_worker.start_load_kv(self._connector_metadata)
def wait_for_layer_load(self, layer_name: str) -> None:
"""MooncakeConnector does not do layerwise saving."""
pass
def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: AttentionMetadata,
**kwargs,
) -> None:
"""MooncakeConnector does not save explicitly."""
pass
def wait_for_save(self):
pass
class MooncakeConnectorScheduler:
"""Implementation of Scheduler side methods"""
def __init__(self, vllm_config: VllmConfig, engine_id: str):
self.vllm_config = vllm_config
self.engine_id: EngineId = engine_id
self.side_channel_host = get_ip()
self.side_channel_port = get_mooncake_side_channel_port(vllm_config)
assert vllm_config.kv_transfer_config
self.kv_role = vllm_config.kv_transfer_config.kv_role
logger.info("Initializing Mooncake Transfer Engine Scheduler %s", engine_id)
# Requests that need to start recv/send.
# New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker.
self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
self._reqs_need_send: dict[ReqId, list[int]] = {}
def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int
) -> tuple[int, bool]:
"""
For remote prefill, pull all prompt blocks from remote
asynchronously relative to engine execution.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
* the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
* true if the external KV cache tokens will be loaded
asynchronously (between scheduler steps).
"""
params = request.kv_transfer_params
logger.debug(
"MooncakeConnector get_num_new_matched_tokens: "
"num_computed_tokens=%s, kv_transfer_params=%s",
num_computed_tokens,
params,
)
if params is not None and params.get("do_remote_prefill"):
# Remote prefill: get all prompt blocks from remote.
token_ids = request.prompt_token_ids or []
count = len(token_ids) - num_computed_tokens
if count > 0:
return count, True
# No remote prefill for this request.
return 0, False
def update_state_after_alloc(
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
params = request.kv_transfer_params
logger.debug(
"MooncakeConnector update_state_after_alloc: "
"num_external_tokens=%s, kv_transfer_params=%s",
num_external_tokens,
params,
)
if not params:
return
if params.get("do_remote_prefill"):
assert self.kv_role != "kv_producer"
if all(p in params for p in ("remote_host", "remote_port")):
# If num_external_tokens is 0, we have a full prefix cache
# hit on the D worker. We still record the request with an
# empty block list so the worker notifies the P instance to
# free its blocks.
local_block_ids = (
blocks.get_unhashed_block_ids() if num_external_tokens > 0 else []
)
# Get unhashed blocks to pull from remote.
self._reqs_need_recv[request.request_id] = (request, local_block_ids)
else:
logger.warning(
"Got invalid KVTransferParams: %s. This "
"request will not utilize KVTransfer",
params,
)
# Only trigger 1 KV transfer per request.
params["do_remote_prefill"] = False
elif params.get("do_remote_decode"):
# Add an empty list to worker to create event.
self._reqs_need_send[request.request_id] = []
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
meta = MooncakeConnectorMetadata()
# Loop through scheduled reqs and convert to RecvReqMeta.
if self.kv_role != "kv_producer":
for req_id, (req, block_ids) in self._reqs_need_recv.items():
assert req.kv_transfer_params is not None
meta.add_new_req(
request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params,
)
self._reqs_need_recv.clear()
if self.kv_role != "kv_consumer":
for req_id, block_ids in self._reqs_need_send.items():
meta.add_new_req(
request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params={},
load_remote_cache=False,
)
self._reqs_need_send.clear()
return meta
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
"""
Once a request is finished, determine whether request blocks
should be freed now or will be sent asynchronously and freed later.
"""
params = request.kv_transfer_params
logger.debug(
"MooncakeConnector request_finished, request_status=%s, "
"kv_transfer_params=%s",
request.status,
params,
)
if not params:
return False, None
if params.get("do_remote_prefill"):
# If do_remote_prefill is still True when the request is finished,
# update_state_after_alloc must not have been called (the request
# must have been aborted before it was scheduled).
# To avoid stranding the prefill blocks in the prefill instance,
# we must add empty block_ids to _reqs_need_recv so that our
# worker side will notify and free blocks in the prefill instance.
assert self.kv_role != "kv_producer"
self._reqs_need_recv[request.request_id] = (request, [])
params["do_remote_prefill"] = False
return False, None
if (
not params.get("do_remote_decode")
or request.status != RequestStatus.FINISHED_LENGTH_CAPPED
):
return False, None
assert self.kv_role != "kv_consumer"
# TODO: check whether block_ids can actually ever be empty.
# If not, we could remove the conditional below.
delay_free_blocks = len(block_ids) > 0
if delay_free_blocks:
self._reqs_need_send[request.request_id] = block_ids
return delay_free_blocks, dict(
do_remote_prefill=True,
do_remote_decode=False,
remote_host=self.side_channel_host,
remote_port=self.side_channel_port,
)
class MooncakeConnectorWorker:
"""Implementation of Worker side methods"""
def __init__(self, vllm_config: VllmConfig, engine_id: str):
logger.info("Initializing Mooncake Transfer Engine worker %s", engine_id)
self.vllm_config = vllm_config
self.engine = TransferEngine()
self.hostname = get_ip()
ret_value = self.engine.initialize(self.hostname, "P2PHANDSHAKE", "rdma", "")
if ret_value != 0:
raise RuntimeError("Mooncake Transfer Engine initialization failed.")
self.rpc_port = self.engine.get_rpc_port()
logger.debug(
"Mooncake Transfer Engine initialized at %s:%d",
self.hostname,
self.rpc_port,
)
# Mooncake handshake port.
self.side_channel_port: int = get_mooncake_side_channel_port(vllm_config)
self.engine_id: EngineId = engine_id
self.tp_rank = get_tensor_model_parallel_rank()
self.world_size = get_tensor_model_parallel_world_size()
self.tp_group = get_tp_group()
self.num_blocks = 0
assert vllm_config.kv_transfer_config
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.num_workers = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"num_workers", 10
)
self.kv_caches_base_addr: list[int] = []
self.device_kv_caches: dict[str, torch.Tensor] = {}
self.reqs_need_send: SendReqMeta = SendReqMeta(reqs={}, lock=threading.Lock())
# For kv_both, we act as both prefiller and decoder.
if self.kv_role != "kv_consumer":
# Background thread for sending kvcaches to D.
self._mooncake_sender_t: threading.Thread | None = None
# Thread pool for processing new send requests.
self._sender_executor = ThreadPoolExecutor(
max_workers=self.num_workers, thread_name_prefix="vllm-mooncake-sender"
)
logger.debug(
"Mooncake Prefiller: use %d workers to send kvcaches", self.num_workers
)
if self.kv_role != "kv_producer":
self.receiver_loop = asyncio.new_event_loop()
self._mooncake_receiver_t = threading.Thread(
target=self._receiver_loop, args=(self.receiver_loop,), daemon=True
)
self._mooncake_receiver_t.start()
logger.debug("Mooncake Decoder: start receiver thread")
self.finished_sending_reqs: FinishedSendReqSet = FinishedSendReqSet(
set(), threading.Lock()
)
self.finished_recving_reqs: FinishedReceiveReqSet = FinishedReceiveReqSet(
set(), asyncio.Lock()
)
self.block_size = vllm_config.cache_config.block_size
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.use_mla = self.model_config.use_mla
backend = get_attn_backend(
self.model_config.get_head_size(),
self.model_config.dtype,
self.cache_config.cache_dtype,
self.block_size,
use_mla=self.use_mla,
)
self.backend_name = backend.get_name()
self.kv_cache_layout = get_kv_cache_layout()
logger.debug("Detected attention backend %s", self.backend_name)
logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size}
self.kv_topo = TpKVTopology(
tp_rank=self.tp_rank,
engine_id=self.engine_id,
remote_tp_size=self._tp_size, # shared state
remote_block_size=self._block_size, # shared state
is_mla=self.use_mla,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backend=backend,
)
self._use_pallas = self.kv_topo._use_pallas
self.zmq_ctx = zmq.Context()
self.async_zmq_ctx = zmq.asyncio.Context()
self._encoder = msgspec.msgpack.Encoder()
self._decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata)
def __del__(self):
self.shutdown()
def shutdown(self):
"""Cleanup background threads on destruction."""
self.zmq_ctx.term()
self.async_zmq_ctx.term()
if self.kv_role != "kv_consumer":
self._sender_executor.shutdown(wait=False)
if self._mooncake_sender_t:
self._mooncake_sender_t.join()
if self.kv_role != "kv_producer" and self.receiver_loop.is_running():
self.receiver_loop.call_soon_threadsafe(self.receiver_loop.stop)
self._mooncake_receiver_t.join()
def _receiver_loop(self, loop: asyncio.AbstractEventLoop):
asyncio.set_event_loop(loop)
loop.run_forever()
def _mooncake_sender(
self, ready_event: threading.Event, base_port: int, tp_rank: int
):
"""
Background thread that listens for Mooncake requests, dispatches them
to a thread pool, and sends acknowledgments upon completion.
"""
frontend_path = make_zmq_path("tcp", self.hostname, base_port + tp_rank)
frontend = make_zmq_socket(self.zmq_ctx, frontend_path, zmq.ROUTER)
logger.debug("Mooncake sender starting listening on path: %s", frontend_path)
backend_path = make_zmq_path("inproc", str(uuid.uuid4()))
backend = make_zmq_socket(self.zmq_ctx, backend_path, zmq.PULL)
poller = zmq.Poller()
poller.register(frontend, zmq.POLLIN)
poller.register(backend, zmq.POLLIN)
ready_event.set()
try:
while True:
sockets = dict(poller.poll())
if frontend in sockets:
identity, _, metadata_bytes = frontend.recv_multipart()
self._sender_executor.submit(
self._sender_worker,
identity,
metadata_bytes,
backend_path,
)
if backend in sockets:
identity, status = backend.recv_multipart()
frontend.send_multipart((identity, b"", status))
except zmq.ContextTerminated:
logger.debug("ZMQ context terminated, exiting Mooncake sender thread.")
except Exception as e:
logger.error("Error in Mooncake sender thread: %s. Exiting thread.", str(e))
finally:
frontend.close()
backend.close()
def _sender_worker(
self, identity: bytes, metadata_bytes: bytes, worker_channel_path: str
):
status = TRANS_ERROR
try:
metadata = self._decoder.decode(metadata_bytes)
self.send_kv_to_decode(metadata)
status = TRANS_DONE
except Exception as e:
logger.error("Error processing Mooncake handshake: %s", e)
finally:
pusher = make_zmq_socket(self.zmq_ctx, worker_channel_path, zmq.PUSH)
try:
pusher.send_multipart((identity, status))
except zmq.ZMQError as e:
logger.warning(
"Internal error, maybe the server is shutting down. Error: %s",
e,
)
finally:
pusher.close()
def send_kv_to_decode(self, meta: MooncakeAgentMetadata):
send_reqs: list[tuple[ReqId, SendBlockMeta]] = []
with self.reqs_need_send.lock:
for req_id in meta.request_ids:
send_meta = self.reqs_need_send.reqs.get(req_id)
if send_meta is None:
logger.warning("Request %s not found in reqs_need_send", req_id)
return
# Mark it as not expired. We will send it now.
send_meta.expire_time = float("inf")
send_reqs.append((req_id, send_meta))
self._send_blocks(send_reqs, meta)
with self.reqs_need_send.lock:
for req_id in meta.request_ids:
del self.reqs_need_send.reqs[req_id]
with self.finished_sending_reqs.lock:
self.finished_sending_reqs.set.update(meta.request_ids)
def _send_blocks(
self,
send_reqs: list[tuple[ReqId, SendBlockMeta]],
agent_meta: MooncakeAgentMetadata,
):
src_ptrs = []
dst_ptrs = []
lengths = []
local_base_addr = self.kv_caches_base_addr
remote_base_addr = agent_meta.kv_caches_base_addr
block_len = self.block_len
remote_session = f"{agent_meta.remote_hostname}:{agent_meta.remote_port}"
assert len(send_reqs) == len(agent_meta.block_ids)
for (req_id, send_meta), remote_block_ids in zip(
send_reqs, agent_meta.block_ids
):
send_meta.ready.wait()
num_remote_blocks = len(remote_block_ids)
if num_remote_blocks == 0:
continue
local_block_ids = send_meta.local_block_ids
# Partial prefix cache hit: just read uncomputed blocks.
num_local_blocks = len(local_block_ids)
assert num_local_blocks >= num_remote_blocks
if num_local_blocks > num_remote_blocks:
local_block_ids = local_block_ids[-num_remote_blocks:]
# Group by indices
group_local_block_ids, group_remote_block_ids = group_concurrent_contiguous(
local_block_ids, remote_block_ids
)
for local_layer_addr, remote_layer_addr in zip(
local_base_addr, remote_base_addr
):
for group_local_block_id, group_remote_block_id in zip(
group_local_block_ids, group_remote_block_ids
):
src_ptrs.append(
local_layer_addr + group_local_block_id[0] * block_len
)
dst_ptrs.append(
remote_layer_addr + group_remote_block_id[0] * block_len
)
lengths.append(block_len * len(group_local_block_id))
logger.debug(
"Sending kv_caches for request %s (%d blocks) to %s",
req_id,
num_remote_blocks,
remote_session,
)
start_time = time.perf_counter()
ret_value = self.engine.batch_transfer_sync_write(
remote_session, src_ptrs, dst_ptrs, lengths
)
if ret_value != 0:
raise RuntimeError(f"Error in batch_transfer_sync_write: {ret_value}")
logger.debug(
"Sending to %s done, took %s",
remote_session,
time.perf_counter() - start_time,
)
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data in mooncake."""
logger.info("Registering KV_Caches. use_mla: %s", self.use_mla)
kv_data_ptrs = []
kv_data_lens = []
seen_base_addresses = []
split_k_and_v = self.kv_topo.split_k_and_v
tensor_size_bytes = None
for layer_name, cache_or_caches in kv_caches.items():
logger.debug(
"registering layer %s with shape %s", layer_name, cache_or_caches.shape
)
cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
for cache in cache_list:
base_addr = cache.data_ptr()
if base_addr in seen_base_addresses:
continue
seen_base_addresses.append(base_addr)
curr_tensor_size_bytes = cache.nbytes
if tensor_size_bytes is None:
tensor_size_bytes = curr_tensor_size_bytes
self.num_blocks = cache.shape[0]
assert tensor_size_bytes == curr_tensor_size_bytes, (
"All kv cache tensors must have the same size"
)
kernel_block_size = cache.shape[-2 if self.use_mla else -3]
assert self.block_size == kernel_block_size
kv_data_ptrs.append(base_addr)
kv_data_lens.append(tensor_size_bytes)
self.kv_caches_base_addr = seen_base_addresses
ret_value = self.engine.batch_register_memory(kv_data_ptrs, kv_data_lens)
if ret_value != 0:
raise RuntimeError("Mooncake batch memory registration failed.")
assert tensor_size_bytes is not None
assert self.num_blocks != 0
assert tensor_size_bytes % self.num_blocks == 0
self.block_len = tensor_size_bytes // self.num_blocks
self.device_kv_caches = kv_caches
logger.debug(
"registered num_blocks=%d block_len=%d", self.num_blocks, self.block_len
)
# No need to launch server for D node.
if self.kv_role == "kv_consumer":
return
ready_event = threading.Event()
self._mooncake_sender_t = threading.Thread(
target=self._mooncake_sender,
args=(ready_event, self.side_channel_port, self.tp_rank),
daemon=True,
name="mooncake_sender",
)
self._mooncake_sender_t.start()
ready_event.wait() # Wait for listener ZMQ socket to be ready.
async def fetch_finished_recving_reqs(self) -> set[ReqId]:
async with self.finished_recving_reqs.lock:
finished_recving_reqs = self.finished_recving_reqs.set
self.finished_recving_reqs.set = set()
return finished_recving_reqs
def get_finished(self) -> tuple[set[str] | None, set[str] | None]:
"""
Get requests that are done sending or recving on this specific worker.
The scheduler process (via the MultiprocExecutor) will use this output
to track which workers are done.
"""
fut = None
if self.kv_role != "kv_producer":
fut = asyncio.run_coroutine_threadsafe(
self.fetch_finished_recving_reqs(), self.receiver_loop
)
if self.kv_role != "kv_consumer":
with self.finished_sending_reqs.lock:
finished_sending_reqs = self.finished_sending_reqs.set
self.finished_sending_reqs.set = set()
else:
finished_sending_reqs = set()
finished_recving_reqs = fut.result() if fut else set()
if finished_sending_reqs or finished_recving_reqs:
logger.debug(
"Rank %s, get_finished: %s requests done sending "
"and %s requests done recving",
self.tp_rank,
len(finished_sending_reqs),
len(finished_recving_reqs),
)
# Handle timeout to avoid stranding blocks on remote.
now = time.perf_counter()
with self.reqs_need_send.lock:
expired_reqs = [
req_id
for req_id, send_meta in self.reqs_need_send.reqs.items()
if send_meta.expire_time < now
]
for req_id in expired_reqs:
logger.warning(
"Request %s timed out after %d seconds without "
"being sent. Freeing its blocks on the producer side.",
req_id,
envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT,
)
del self.reqs_need_send.reqs[req_id]
if expired_reqs:
finished_sending_reqs.update(expired_reqs)
return finished_sending_reqs or None, finished_recving_reqs or None
async def receive_kv(self, path: str, req_blocks: list[tuple[str, list[int]]]):
req_ids, block_ids = map(list, zip(*req_blocks))
metadata = MooncakeAgentMetadata(
remote_hostname=self.hostname,
remote_port=self.rpc_port,
request_ids=req_ids,
kv_caches_base_addr=self.kv_caches_base_addr,
block_ids=block_ids,
)
encoded_data = self._encoder.encode(metadata)
logger.debug(
"Size of encoded MooncakeAgentMetadata: %d bytes", len(encoded_data)
)
logger.debug("Sending kv transfer request for %s on path: %s", req_ids, path)
# Send query for the request.
sock: zmq.asyncio.Socket = make_zmq_socket(
self.async_zmq_ctx, path, zmq.REQ, bind=False, linger=0
)
sock.setsockopt(zmq.RCVTIMEO, 60000)
try:
await sock.send(encoded_data)
ret_msg = await sock.recv()
if ret_msg != TRANS_DONE:
logger.error(
"Error happens during tranfering kvcache for %s, see logs in prefiller.", # noqa: E501
req_ids,
)
return
except zmq.ContextTerminated:
logger.debug("ZMQ context terminated, exiting Mooncake receiver thread.")
except Exception as e:
logger.error("MooncakeAgentMetadata transfer failed for %s: %s", req_ids, e)
return
finally:
sock.close()
async with self.finished_recving_reqs.lock:
self.finished_recving_reqs.set.update(req_ids)
logger.debug("pulling kv_caches for %s finished", req_ids)
def group_kv_pull(self, metadata: MooncakeConnectorMetadata):
kv_pulls = defaultdict(list)
for req_id, meta in metadata.reqs_to_recv.items():
logger.debug(
"start_load_kv for request %s from remote engine. "
"Num local_block_ids: %s.",
req_id,
len(meta.local_block_ids),
)
path = make_zmq_path(
"tcp", meta.remote_host, meta.remote_port + self.tp_rank
)
kv_pulls[path].append((req_id, meta.local_block_ids))
return kv_pulls
def start_load_kv(self, metadata: MooncakeConnectorMetadata):
if self.kv_role != "kv_producer":
kv_pulls = self.group_kv_pull(metadata)
for path, req_blocks in kv_pulls.items():
asyncio.run_coroutine_threadsafe(
self.receive_kv(path, req_blocks), self.receiver_loop
)
if self.kv_role != "kv_consumer":
with self.reqs_need_send.lock:
for req_id, block_ids in metadata.reqs_to_send.items():
if block_ids:
# Already gone through request_finished()
send_meta = self.reqs_need_send.reqs[req_id]
send_meta.local_block_ids = block_ids
send_meta.ready.set()
send_meta.expire_time = (
time.perf_counter()
+ envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT
)
else:
# Came from update_state_after_alloc(),
# but has not reached request_finished() yet
self.reqs_need_send.reqs[req_id] = SendBlockMeta(
local_block_ids=[], ready=threading.Event()
)
def group_concurrent_contiguous(
src_indices: list[int], dst_indices: list[int]
) -> tuple[list[list[int]], list[list[int]]]:
"""Vectorised NumPy implementation."""
if len(src_indices) == 0:
return [], []
brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
src_groups = np.split(src_indices, brk)
dst_groups = np.split(dst_indices, brk)
src_groups = [g.tolist() for g in src_groups]
dst_groups = [g.tolist() for g in dst_groups]
return src_groups, dst_groups
def get_mooncake_side_channel_port(vllm_config: VllmConfig) -> int:
# Base side-channel port for this DP rank; each TP worker adds its tp_rank.
return (
envs.VLLM_MOONCAKE_BOOTSTRAP_PORT
+ vllm_config.parallel_config.data_parallel_rank
* vllm_config.parallel_config.tensor_parallel_size
)


@@ -20,10 +20,10 @@ import torch
import zmq
from vllm import envs
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
CopyBlocksOp,
KVConnectorBase_V1,
@@ -668,128 +668,6 @@ class NixlConnectorScheduler:
class NixlConnectorWorker:
"""Implementation of Worker side methods"""
@dataclass
class TpKVTopology:
"""
Helper class for tensor parallel and KV topology information for
mapping between local and remote TP workers.
"""
tp_rank: int
remote_tp_size: dict[EngineId, int]
is_mla: bool
total_num_kv_heads: int
attn_backend: type[AttentionBackend]
engine_id: EngineId
remote_block_size: dict[EngineId, int]
def __post_init__(self):
# Figure out whether the first dimension of the cache is K/V
# or num_blocks. This is used to register the memory regions correctly.
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
# Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D],
# we just mock num_blocks to 1 for the dimension check below.
self._is_kv_layout_blocks_first = (
len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
)
attn_backend = AttentionBackendEnum[self.attn_backend.get_name()]
self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS
@property
def is_kv_layout_blocks_first(self) -> bool:
return self._is_kv_layout_blocks_first
@property
def split_k_and_v(self) -> bool:
# Whether to register regions for K and V separately (when present).
return not (
self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first
)
@property
def tp_size(self) -> int:
return self.remote_tp_size[self.engine_id]
@property
def block_size(self) -> int:
return self.remote_block_size[self.engine_id]
def tp_ratio(
self,
remote_tp_size: int,
) -> int:
"""
Calculate the tensor parallel ratio between local and remote TP.
We can think of it as the number of local TP workers-per-remote TP
workers. Local workers will read from the same remote TP worker in
groups of size `tp_ratio`.
"""
assert self.tp_size % remote_tp_size == 0, (
f"Local tensor parallel size {self.tp_size} is not divisible "
f"by remote tensor parallel size {remote_tp_size}."
)
return self.tp_size // remote_tp_size
def block_size_ratio(
self,
remote_block_size: int,
) -> float:
"""
Calculate the block size ratio between local and remote TP.
"""
assert self.block_size % remote_block_size == 0, (
f"Local block size {self.block_size} is not divisible "
f"by remote block size {remote_block_size} or vice versa."
)
return self.block_size // remote_block_size
def tp_ratio_from_engine_id(
self,
remote_engine_id: EngineId,
) -> int:
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.tp_ratio(remote_tp_size)
def block_size_ratio_from_engine_id(
self,
remote_engine_id: EngineId,
) -> float:
remote_block_size = self.remote_block_size[remote_engine_id]
return self.block_size_ratio(remote_block_size)
def is_kv_replicated(self, engine_id: EngineId) -> bool:
"""
Whether the KV cache is replicated across TP workers due to the
number of TP workers being greater than the number of KV heads.
"""
tp_size = self.remote_tp_size[engine_id]
return tp_size // self.total_num_kv_heads >= 1
def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
# MLA is always replicated as the hidden dim can't be split.
return self.is_mla or self.is_kv_replicated(remote_engine_id)
def get_target_remote_rank(
self,
remote_tp_size: int,
) -> int:
"""
Get the remote TP rank (on P) that the current local TP rank
(on D) will read from.
"""
tp_ratio = self.tp_ratio(remote_tp_size)
return self.tp_rank // tp_ratio
def get_target_remote_rank_from_engine_id(
self,
remote_engine_id: EngineId,
) -> int:
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.get_target_remote_rank(remote_tp_size)
def __init__(self, vllm_config: VllmConfig, engine_id: str):
if NixlWrapper is None:
logger.error("NIXL is not available")
@@ -958,7 +836,7 @@ class NixlConnectorWorker:
self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
self.xfer_stats = NixlKVConnectorStats()
self.kv_topo = self.TpKVTopology(
self.kv_topo = TpKVTopology(
tp_rank=self.tp_rank,
engine_id=self.engine_id,
remote_tp_size=self._tp_size, # shared state


@@ -175,6 +175,7 @@ if TYPE_CHECKING:
VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5600
VLLM_MOONCAKE_BOOTSTRAP_PORT: int = 8998
VLLM_ALL2ALL_BACKEND: Literal[
"naive",
"pplx",
@@ -197,6 +198,7 @@ if TYPE_CHECKING:
VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None
VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480
VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT: int = 480
VLLM_USE_CUDNN_PREFILL: bool = False
VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False
VLLM_ENABLE_CUDAGRAPH_GC: bool = False
@@ -1260,6 +1262,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_NIXL_SIDE_CHANNEL_PORT": lambda: int(
os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5600")
),
# Port used for Mooncake handshake between remote agents.
"VLLM_MOONCAKE_BOOTSTRAP_PORT": lambda: int(
os.getenv("VLLM_MOONCAKE_BOOTSTRAP_PORT", "8998")
),
# all2all backend for vllm's expert parallel communication
# Available options:
# - "naive": naive all2all implementation using broadcasts
@@ -1369,6 +1375,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT": lambda: int(
os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "480")
),
# Timeout (in seconds) for MooncakeConnector in PD disaggregated setup.
"VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT": lambda: int(
os.getenv("VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT", "480")
),
# Controls whether or not to use cudnn prefill
"VLLM_USE_CUDNN_PREFILL": lambda: bool(
int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))