[P/D] Introduce Mooncake Transfer Engine as kv_connector (#24718)
Signed-off-by: Tianchen Ding <dtcccc@linux.alibaba.com>
Signed-off-by: dtc <dtcccc@linux.alibaba.com>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
docs/features/mooncake_connector_usage.md (new file, 58 lines)
@@ -0,0 +1,58 @@
# MooncakeConnector Usage Guide

## About Mooncake

Mooncake aims to enhance the inference efficiency of large language models (LLMs), especially in slow object storage environments, by constructing a multi-level caching pool on high-speed interconnected DRAM/SSD resources. Compared to traditional caching systems, Mooncake utilizes (GPUDirect) RDMA technology to transfer data directly in a zero-copy manner, while maximizing the use of multi-NIC resources on a single machine.

For more details about Mooncake, please refer to the [Mooncake project](https://github.com/kvcache-ai/Mooncake) and [Mooncake documents](https://kvcache-ai.github.io/Mooncake/).

## Prerequisites

### Installation

Install Mooncake through pip: `uv pip install mooncake-transfer-engine`.

Refer to the [Mooncake official repository](https://github.com/kvcache-ai/Mooncake) for more installation instructions.
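A quick sanity check of the installation (a minimal sketch; the connector itself imports `TransferEngine` from `mooncake.engine`, so if this import fails the connector cannot be used):

```python
# Verify that the Mooncake transfer engine bindings are importable.
from mooncake.engine import TransferEngine

print(TransferEngine)  # raises ImportError if mooncake is not installed correctly
```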
## Usage

### Prefiller Node (192.168.0.2)

```bash
vllm serve Qwen/Qwen2.5-7B-Instruct --port 8010 --kv-transfer-config '{"kv_connector":"MooncakeConnector","kv_role":"kv_producer"}'
```

### Decoder Node (192.168.0.3)

```bash
vllm serve Qwen/Qwen2.5-7B-Instruct --port 8020 --kv-transfer-config '{"kv_connector":"MooncakeConnector","kv_role":"kv_consumer"}'
```

### Proxy

```bash
python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --prefiller-host 192.168.0.2 --prefiller-port 8010 --decoder-host 192.168.0.3 --decoder-port 8020
```

> NOTE: The Mooncake Connector currently uses the proxy from nixl_integration. This will be replaced with a self-developed proxy in the future.

Now you can send requests to the proxy server through port 8000.
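For example, a minimal request sketch (this assumes the toy proxy exposes the usual OpenAI-compatible `/v1/completions` route on port 8000; adjust host, port, and model name to your deployment):

```python
# Send one completion request to the proxy; the producer node runs the
# prefill, and the consumer node pulls the KV cache via Mooncake and decodes.
import json
import urllib.request

payload = {
    "model": "Qwen/Qwen2.5-7B-Instruct",
    "prompt": "Explain KV-cache transfer in one sentence.",
    "max_tokens": 64,
}
req = urllib.request.Request(
    "http://localhost:8000/v1/completions",
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read())["choices"][0]["text"])
```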
## Environment Variables

- `VLLM_MOONCAKE_BOOTSTRAP_PORT`: Port for the Mooncake bootstrap server
    - Default: 8998
    - Required only for prefiller instances
    - Each vLLM worker needs a unique port on its host; using the same port number across different hosts is fine
    - For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank * tp_size + tp_rank (see the sketch after this list)
    - Used by the decoder to notify the prefiller

- `VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT`: Timeout (in seconds) for automatically releasing the prefiller’s KV cache for a particular request. (Optional)
    - Default: 480
    - If a request is aborted and the decoder has not yet notified the prefiller, the prefill instance will release its KV-cache blocks after this timeout to avoid holding them indefinitely.
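A small sketch of the per-worker port computation described above (the helper name is illustrative, not part of the vLLM API):

```python
# Each worker on a node derives its bootstrap port from the base port plus
# its data-parallel and tensor-parallel ranks, so every worker gets a
# unique port on that host.
def worker_bootstrap_port(base_port: int, dp_rank: int, tp_size: int, tp_rank: int) -> int:
    return base_port + dp_rank * tp_size + tp_rank

# Example: default base port 8998 with DP=2 x TP=2 on one node.
for dp_rank in range(2):
    for tp_rank in range(2):
        print(dp_rank, tp_rank, worker_bootstrap_port(8998, dp_rank, 2, tp_rank))
# -> ports 8998, 8999, 9000, 9001 (one unique port per worker)
```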
## KV Role Options

- **kv_producer**: For prefiller instances that generate KV caches
- **kv_consumer**: For decoder instances that consume KV caches from the prefiller
- **kv_both**: Enables symmetric functionality where the connector can act as both producer and consumer. This provides flexibility for experimental setups and scenarios where the role distinction is not predetermined.
@@ -190,3 +190,8 @@ KVConnectorFactory.register_connector(
    "vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector",
    "DecodeBenchConnector",
)
KVConnectorFactory.register_connector(
    "MooncakeConnector",
    "vllm.distributed.kv_transfer.kv_connector.v1.mooncake_connector",
    "MooncakeConnector",
)
@@ -4,10 +4,13 @@
KV cache helper for store.
"""

from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal

import torch

from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import get_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.logger import init_logger
@@ -181,3 +184,124 @@ def copy_kv_blocks(
        src_tensor = src_kv_caches[layer_name]
        dst_tensor = dst_kv_caches[layer_name]
        copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)


@dataclass
class TpKVTopology:
    """
    Helper class for tensor parallel and KV topology information for
    mapping between local and remote TP workers.
    """

    tp_rank: int
    remote_tp_size: dict[str, int]
    is_mla: bool
    total_num_kv_heads: int
    attn_backend: type[AttentionBackend]
    engine_id: str
    remote_block_size: dict[str, int]

    def __post_init__(self):
        # Figure out whether the first dimension of the cache is K/V
        # or num_blocks. This is used to register the memory regions correctly.
        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
        )
        # Non-MLA backend caches have 5 dims [2, num_blocks, H, N, D];
        # we just mock num_blocks to 1 for the dimension check below.
        self._is_kv_layout_blocks_first = (
            len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
        )

        attn_backend = AttentionBackendEnum[self.attn_backend.get_name()]
        self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS

    @property
    def is_kv_layout_blocks_first(self) -> bool:
        return self._is_kv_layout_blocks_first

    @property
    def split_k_and_v(self) -> bool:
        # Whether to register regions for K and V separately (when present).
        return not (self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first)

    @property
    def tp_size(self) -> int:
        return self.remote_tp_size[self.engine_id]

    @property
    def block_size(self) -> int:
        return self.remote_block_size[self.engine_id]

    def tp_ratio(
        self,
        remote_tp_size: int,
    ) -> int:
        """
        Calculate the tensor parallel ratio between local and remote TP.
        We can think of it as the number of local TP workers per remote TP
        worker. Local workers will read from the same remote TP worker in
        groups of size `tp_ratio`.
        """
        assert self.tp_size % remote_tp_size == 0, (
            f"Local tensor parallel size {self.tp_size} is not divisible "
            f"by remote tensor parallel size {remote_tp_size}."
        )
        return self.tp_size // remote_tp_size

    def block_size_ratio(
        self,
        remote_block_size: int,
    ) -> float:
        """
        Calculate the ratio between local and remote block size.
        """
        assert self.block_size % remote_block_size == 0, (
            f"Local block size {self.block_size} is not divisible "
            f"by remote block size {remote_block_size} or vice versa."
        )
        return self.block_size // remote_block_size

    def tp_ratio_from_engine_id(
        self,
        remote_engine_id: str,
    ) -> int:
        remote_tp_size = self.remote_tp_size[remote_engine_id]
        return self.tp_ratio(remote_tp_size)

    def block_size_ratio_from_engine_id(
        self,
        remote_engine_id: str,
    ) -> float:
        remote_block_size = self.remote_block_size[remote_engine_id]
        return self.block_size_ratio(remote_block_size)

    def is_kv_replicated(self, engine_id: str) -> bool:
        """
        Whether the KV cache is replicated across TP workers due to the
        number of TP workers being greater than the number of KV heads.
        """
        tp_size = self.remote_tp_size[engine_id]
        return tp_size // self.total_num_kv_heads >= 1

    def replicates_kv_cache(self, remote_engine_id: str) -> bool:
        # MLA is always replicated as the hidden dim can't be split.
        return self.is_mla or self.is_kv_replicated(remote_engine_id)

    def get_target_remote_rank(
        self,
        remote_tp_size: int,
    ) -> int:
        """
        Get the remote TP rank (on P) that the current local TP rank
        (on D) will read from.
        """
        tp_ratio = self.tp_ratio(remote_tp_size)
        return self.tp_rank // tp_ratio

    def get_target_remote_rank_from_engine_id(
        self,
        remote_engine_id: str,
    ) -> int:
        remote_tp_size = self.remote_tp_size[remote_engine_id]
        return self.get_target_remote_rank(remote_tp_size)
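The snippet below is not part of the diff; it is a small illustration of how `tp_ratio()` and `get_target_remote_rank()` pair decode-side TP ranks with prefill-side TP ranks when the two instances use different tensor parallel sizes:

```python
# Example: decode instance with TP=4 pulling from a prefill instance with TP=2.
local_tp_size = 4   # D instance
remote_tp_size = 2  # P instance

tp_ratio = local_tp_size // remote_tp_size  # 2 local workers per remote worker
for local_rank in range(local_tp_size):
    print(f"D rank {local_rank} reads from P rank {local_rank // tp_ratio}")
# D ranks 0 and 1 read from P rank 0; D ranks 2 and 3 read from P rank 1.
```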
@@ -0,0 +1,914 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import msgspec
|
||||
import numpy as np
|
||||
import torch
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tp_group,
|
||||
)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
|
||||
from vllm.v1.attention.backends.utils import get_kv_cache_layout
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.request import RequestStatus
|
||||
|
||||
try:
|
||||
from mooncake.engine import TransferEngine
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Please install mooncake by following the instructions at "
|
||||
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
|
||||
"to run VLLM with MooncakeTransferEngine."
|
||||
) from e
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
EngineId = str
|
||||
ReqId = str
|
||||
|
||||
TRANS_DONE = b"trans_done"
|
||||
TRANS_ERROR = b"trans_error"
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MooncakeAgentMetadata(
|
||||
msgspec.Struct,
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
# required for @cached_property.
|
||||
dict=True,
|
||||
):
|
||||
remote_hostname: str
|
||||
remote_port: int
|
||||
request_ids: list[ReqId]
|
||||
kv_caches_base_addr: list[int]
|
||||
block_ids: list[list[int]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class RecvReqMeta:
|
||||
local_block_ids: list[int]
|
||||
remote_host: str
|
||||
remote_port: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class SendBlockMeta:
|
||||
local_block_ids: list[int]
|
||||
ready: threading.Event
|
||||
expire_time: float = float("inf")
|
||||
|
||||
|
||||
@dataclass
|
||||
class SendReqMeta:
|
||||
reqs: dict[ReqId, SendBlockMeta]
|
||||
lock: threading.Lock
|
||||
|
||||
|
||||
@dataclass
|
||||
class FinishedSendReqSet:
|
||||
set: set[ReqId]
|
||||
lock: threading.Lock
|
||||
|
||||
|
||||
@dataclass
|
||||
class FinishedReceiveReqSet:
|
||||
set: set[ReqId]
|
||||
lock: asyncio.Lock
|
||||
|
||||
|
||||
class MooncakeConnectorMetadata(KVConnectorMetadata):
|
||||
def __init__(self):
|
||||
self.reqs_to_recv: dict[ReqId, RecvReqMeta] = {}
|
||||
self.reqs_to_send: dict[ReqId, list[int]] = {}
|
||||
|
||||
def add_new_req(
|
||||
self,
|
||||
request_id: ReqId,
|
||||
local_block_ids: list[int],
|
||||
kv_transfer_params: dict[str, Any],
|
||||
load_remote_cache: bool = True,
|
||||
):
|
||||
if load_remote_cache:
|
||||
self.reqs_to_recv[request_id] = RecvReqMeta(
|
||||
local_block_ids=local_block_ids,
|
||||
remote_host=kv_transfer_params["remote_host"],
|
||||
remote_port=kv_transfer_params["remote_port"],
|
||||
)
|
||||
else:
|
||||
self.reqs_to_send[request_id] = local_block_ids
|
||||
|
||||
|
||||
class MooncakeConnector(KVConnectorBase_V1):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: Optional["KVCacheConfig"] = None,
|
||||
):
|
||||
super().__init__(vllm_config, role, kv_cache_config)
|
||||
|
||||
assert vllm_config.kv_transfer_config is not None
|
||||
assert vllm_config.kv_transfer_config.engine_id is not None
|
||||
self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
|
||||
|
||||
if role == KVConnectorRole.SCHEDULER:
|
||||
self.connector_scheduler: MooncakeConnectorScheduler | None = (
|
||||
MooncakeConnectorScheduler(vllm_config, self.engine_id)
|
||||
)
|
||||
self.connector_worker: MooncakeConnectorWorker | None = None
|
||||
elif role == KVConnectorRole.WORKER:
|
||||
self.connector_scheduler = None
|
||||
self.connector_worker = MooncakeConnectorWorker(vllm_config, self.engine_id)
|
||||
|
||||
############################################################
|
||||
# Scheduler Side Methods
|
||||
############################################################
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request", num_computed_tokens: int
|
||||
) -> tuple[int, bool]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.get_num_new_matched_tokens(
|
||||
request, num_computed_tokens
|
||||
)
|
||||
|
||||
def update_state_after_alloc(
|
||||
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
|
||||
):
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.update_state_after_alloc(
|
||||
request, blocks, num_external_tokens
|
||||
)
|
||||
|
||||
def build_connector_meta(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> KVConnectorMetadata:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.build_connector_meta(scheduler_output)
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.request_finished(request, block_ids)
|
||||
|
||||
############################################################
|
||||
# Worker Side Methods
|
||||
############################################################
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
assert self.connector_worker is not None
|
||||
self.connector_worker.register_kv_caches(kv_caches)
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str]
|
||||
) -> tuple[set[str] | None, set[str] | None]:
|
||||
"""Get the finished recving and sending requests."""
|
||||
assert self.connector_worker is not None
|
||||
return self.connector_worker.get_finished()
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
|
||||
assert self.connector_worker is not None
|
||||
assert isinstance(self._connector_metadata, MooncakeConnectorMetadata)
|
||||
self.connector_worker.start_load_kv(self._connector_metadata)
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
"""MooncakeConnector does not do layerwise saving."""
|
||||
pass
|
||||
|
||||
def save_kv_layer(
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""MooncakeConnector does not save explicitly."""
|
||||
pass
|
||||
|
||||
def wait_for_save(self):
|
||||
pass
|
||||
|
||||
|
||||
class MooncakeConnectorScheduler:
|
||||
"""Implementation of Scheduler side methods"""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, engine_id: str):
|
||||
self.vllm_config = vllm_config
|
||||
self.engine_id: EngineId = engine_id
|
||||
self.side_channel_host = get_ip()
|
||||
self.side_channel_port = get_mooncake_side_channel_port(vllm_config)
|
||||
|
||||
assert vllm_config.kv_transfer_config
|
||||
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
||||
logger.info("Initializing Mooncake Transfer Engine Scheduler %s", engine_id)
|
||||
|
||||
# Requests that need to start recv/send.
|
||||
# New requests are added by update_state_after_alloc in
|
||||
# the scheduler. Used to make metadata passed to Worker.
|
||||
self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
|
||||
self._reqs_need_send: dict[ReqId, list[int]] = {}
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request", num_computed_tokens: int
|
||||
) -> tuple[int, bool]:
|
||||
"""
|
||||
For remote prefill, pull all prompt blocks from remote
|
||||
asynchronously relative to engine execution.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
Returns:
|
||||
* the number of tokens that can be loaded from the
|
||||
external KV cache beyond what is already computed.
|
||||
* true if the external KV cache tokens will be loaded
|
||||
asynchronously (between scheduler steps).
|
||||
"""
|
||||
|
||||
params = request.kv_transfer_params
|
||||
logger.debug(
|
||||
"MooncakeConnector get_num_new_matched_tokens: "
|
||||
"num_computed_tokens=%s, kv_transfer_params=%s",
|
||||
num_computed_tokens,
|
||||
params,
|
||||
)
|
||||
|
||||
if params is not None and params.get("do_remote_prefill"):
|
||||
# Remote prefill: get all prompt blocks from remote.
|
||||
token_ids = request.prompt_token_ids or []
|
||||
count = len(token_ids) - num_computed_tokens
|
||||
if count > 0:
|
||||
return count, True
|
||||
|
||||
# No remote prefill for this request.
|
||||
return 0, False
|
||||
|
||||
def update_state_after_alloc(
|
||||
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
|
||||
):
|
||||
params = request.kv_transfer_params
|
||||
logger.debug(
|
||||
"MooncakeConnector update_state_after_alloc: "
|
||||
"num_external_tokens=%s, kv_transfer_params=%s",
|
||||
num_external_tokens,
|
||||
params,
|
||||
)
|
||||
|
||||
if not params:
|
||||
return
|
||||
|
||||
if params.get("do_remote_prefill"):
|
||||
assert self.kv_role != "kv_producer"
|
||||
if all(p in params for p in ("remote_host", "remote_port")):
|
||||
# If remote_blocks and num_external_tokens = 0, we have
|
||||
# a full prefix cache hit on the D worker. We need to call
|
||||
# send_notif in _read_blocks to free the memory on the P.
|
||||
local_block_ids = (
|
||||
blocks.get_unhashed_block_ids() if num_external_tokens > 0 else []
|
||||
)
|
||||
# Get unhashed blocks to pull from remote.
|
||||
self._reqs_need_recv[request.request_id] = (request, local_block_ids)
|
||||
else:
|
||||
logger.warning(
|
||||
"Got invalid KVTransferParams: %s. This "
|
||||
"request will not utilize KVTransfer",
|
||||
params,
|
||||
)
|
||||
# Only trigger 1 KV transfer per request.
|
||||
params["do_remote_prefill"] = False
|
||||
|
||||
elif params.get("do_remote_decode"):
|
||||
# Add an empty list to worker to create event.
|
||||
self._reqs_need_send[request.request_id] = []
|
||||
|
||||
def build_connector_meta(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> KVConnectorMetadata:
|
||||
meta = MooncakeConnectorMetadata()
|
||||
|
||||
# Loop through scheduled reqs and convert to RecvReqMeta.
|
||||
if self.kv_role != "kv_producer":
|
||||
for req_id, (req, block_ids) in self._reqs_need_recv.items():
|
||||
assert req.kv_transfer_params is not None
|
||||
meta.add_new_req(
|
||||
request_id=req_id,
|
||||
local_block_ids=block_ids,
|
||||
kv_transfer_params=req.kv_transfer_params,
|
||||
)
|
||||
self._reqs_need_recv.clear()
|
||||
|
||||
if self.kv_role != "kv_consumer":
|
||||
for req_id, block_ids in self._reqs_need_send.items():
|
||||
meta.add_new_req(
|
||||
request_id=req_id,
|
||||
local_block_ids=block_ids,
|
||||
kv_transfer_params={},
|
||||
load_remote_cache=False,
|
||||
)
|
||||
self._reqs_need_send.clear()
|
||||
|
||||
return meta
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
"""
|
||||
Once a request is finished, determine whether request blocks
|
||||
should be freed now or will be sent asynchronously and freed later.
|
||||
"""
|
||||
|
||||
params = request.kv_transfer_params
|
||||
logger.debug(
|
||||
"MooncakeConnector request_finished, request_status=%s, "
|
||||
"kv_transfer_params=%s",
|
||||
request.status,
|
||||
params,
|
||||
)
|
||||
if not params:
|
||||
return False, None
|
||||
|
||||
if params.get("do_remote_prefill"):
|
||||
# If do_remote_prefill is still True when the request is finished,
|
||||
# update_state_after_alloc must not have been called (the request
|
||||
# must have been aborted before it was scheduled).
|
||||
# To avoid stranding the prefill blocks in the prefill instance,
|
||||
# we must add empty block_ids to _reqs_need_recv so that our
|
||||
# worker side will notify and free blocks in the prefill instance.
|
||||
assert self.kv_role != "kv_producer"
|
||||
self._reqs_need_recv[request.request_id] = (request, [])
|
||||
params["do_remote_prefill"] = False
|
||||
return False, None
|
||||
|
||||
if (
|
||||
not params.get("do_remote_decode")
|
||||
or request.status != RequestStatus.FINISHED_LENGTH_CAPPED
|
||||
):
|
||||
return False, None
|
||||
|
||||
assert self.kv_role != "kv_consumer"
|
||||
|
||||
# TODO: check whether block_ids can actually ever be empty. If not, we could
|
||||
# remove the conditional below
|
||||
delay_free_blocks = len(block_ids) > 0
|
||||
|
||||
if delay_free_blocks:
|
||||
self._reqs_need_send[request.request_id] = block_ids
|
||||
|
||||
return delay_free_blocks, dict(
|
||||
do_remote_prefill=True,
|
||||
do_remote_decode=False,
|
||||
remote_host=self.side_channel_host,
|
||||
remote_port=self.side_channel_port,
|
||||
)
|
||||
|
||||
|
||||
class MooncakeConnectorWorker:
|
||||
"""Implementation of Worker side methods"""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, engine_id: str):
|
||||
logger.info("Initializing Mooncake Transfer Engine worker %s", engine_id)
|
||||
|
||||
self.vllm_config = vllm_config
|
||||
|
||||
self.engine = TransferEngine()
|
||||
self.hostname = get_ip()
|
||||
ret_value = self.engine.initialize(self.hostname, "P2PHANDSHAKE", "rdma", "")
|
||||
if ret_value != 0:
|
||||
raise RuntimeError("Mooncake Transfer Engine initialization failed.")
|
||||
|
||||
self.rpc_port = self.engine.get_rpc_port()
|
||||
|
||||
logger.debug(
|
||||
"Mooncake Transfer Engine initialized at %s:%d",
|
||||
self.hostname,
|
||||
self.rpc_port,
|
||||
)
|
||||
|
||||
# Mooncake handshake port.
|
||||
self.side_channel_port: int = get_mooncake_side_channel_port(vllm_config)
|
||||
|
||||
self.engine_id: EngineId = engine_id
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.world_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_group = get_tp_group()
|
||||
self.num_blocks = 0
|
||||
|
||||
assert vllm_config.kv_transfer_config
|
||||
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
||||
self.num_workers = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"num_workers", 10
|
||||
)
|
||||
|
||||
self.kv_caches_base_addr: list[int] = []
|
||||
self.device_kv_caches: dict[str, torch.Tensor] = {}
|
||||
self.reqs_need_send: SendReqMeta = SendReqMeta(reqs={}, lock=threading.Lock())
|
||||
|
||||
# For kv_both, we will act as both prefiller and decoder.
|
||||
if self.kv_role != "kv_consumer":
|
||||
# Background thread for sending kvcaches to D.
|
||||
self._mooncake_sender_t: threading.Thread | None = None
|
||||
# Background thread for processing new sending requests.
|
||||
self._sender_executor = ThreadPoolExecutor(
|
||||
max_workers=self.num_workers, thread_name_prefix="vllm-mooncake-sender"
|
||||
)
|
||||
logger.debug(
|
||||
"Mooncake Prefiller: use %d workers to send kvcaches", self.num_workers
|
||||
)
|
||||
if self.kv_role != "kv_producer":
|
||||
self.receiver_loop = asyncio.new_event_loop()
|
||||
self._mooncake_receiver_t = threading.Thread(
|
||||
target=self._receiver_loop, args=(self.receiver_loop,), daemon=True
|
||||
)
|
||||
self._mooncake_receiver_t.start()
|
||||
logger.debug("Mooncake Decoder: start receiver thread")
|
||||
|
||||
self.finished_sending_reqs: FinishedSendReqSet = FinishedSendReqSet(
|
||||
set(), threading.Lock()
|
||||
)
|
||||
self.finished_recving_reqs: FinishedReceiveReqSet = FinishedReceiveReqSet(
|
||||
set(), asyncio.Lock()
|
||||
)
|
||||
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.use_mla = self.model_config.use_mla
|
||||
|
||||
backend = get_attn_backend(
|
||||
self.model_config.get_head_size(),
|
||||
self.model_config.dtype,
|
||||
self.cache_config.cache_dtype,
|
||||
self.block_size,
|
||||
use_mla=self.use_mla,
|
||||
)
|
||||
self.backend_name = backend.get_name()
|
||||
self.kv_cache_layout = get_kv_cache_layout()
|
||||
logger.debug("Detected attention backend %s", self.backend_name)
|
||||
logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
|
||||
|
||||
self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
|
||||
self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size}
|
||||
self.kv_topo = TpKVTopology(
|
||||
tp_rank=self.tp_rank,
|
||||
engine_id=self.engine_id,
|
||||
remote_tp_size=self._tp_size, # shared state
|
||||
remote_block_size=self._block_size, # shared state
|
||||
is_mla=self.use_mla,
|
||||
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
|
||||
attn_backend=backend,
|
||||
)
|
||||
self._use_pallas = self.kv_topo._use_pallas
|
||||
|
||||
self.zmq_ctx = zmq.Context()
|
||||
self.async_zmq_ctx = zmq.asyncio.Context()
|
||||
self._encoder = msgspec.msgpack.Encoder()
|
||||
self._decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata)
|
||||
|
||||
def __del__(self):
|
||||
self.shutdown()
|
||||
|
||||
def shutdown(self):
|
||||
"""Cleanup background threads on destruction."""
|
||||
self.zmq_ctx.term()
|
||||
self.async_zmq_ctx.term()
|
||||
if self.kv_role != "kv_consumer":
|
||||
self._sender_executor.shutdown(wait=False)
|
||||
if self._mooncake_sender_t:
|
||||
self._mooncake_sender_t.join()
|
||||
if self.kv_role != "kv_producer" and self.receiver_loop.is_running():
|
||||
self.receiver_loop.call_soon_threadsafe(self.receiver_loop.stop)
|
||||
self._mooncake_receiver_t.join()
|
||||
|
||||
def _receiver_loop(self, loop: asyncio.AbstractEventLoop):
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_forever()
|
||||
|
||||
def _mooncake_sender(
|
||||
self, ready_event: threading.Event, base_port: int, tp_rank: int
|
||||
):
|
||||
"""
|
||||
Background thread that listens for Mooncake requests, dispatches them
|
||||
to a thread pool, and sends acknowledgments upon completion.
|
||||
"""
|
||||
|
||||
frontend_path = make_zmq_path("tcp", self.hostname, base_port + tp_rank)
|
||||
frontend = make_zmq_socket(self.zmq_ctx, frontend_path, zmq.ROUTER)
|
||||
logger.debug("Mooncake sender starting listening on path: %s", frontend_path)
|
||||
|
||||
backend_path = make_zmq_path("inproc", str(uuid.uuid4()))
|
||||
backend = make_zmq_socket(self.zmq_ctx, backend_path, zmq.PULL)
|
||||
|
||||
poller = zmq.Poller()
|
||||
poller.register(frontend, zmq.POLLIN)
|
||||
poller.register(backend, zmq.POLLIN)
|
||||
|
||||
ready_event.set()
|
||||
|
||||
try:
|
||||
while True:
|
||||
sockets = dict(poller.poll())
|
||||
|
||||
if frontend in sockets:
|
||||
identity, _, metadata_bytes = frontend.recv_multipart()
|
||||
self._sender_executor.submit(
|
||||
self._sender_worker,
|
||||
identity,
|
||||
metadata_bytes,
|
||||
backend_path,
|
||||
)
|
||||
|
||||
if backend in sockets:
|
||||
identity, status = backend.recv_multipart()
|
||||
frontend.send_multipart((identity, b"", status))
|
||||
|
||||
except zmq.ContextTerminated:
|
||||
logger.debug("ZMQ context terminated, exiting Mooncake sender thread.")
|
||||
except Exception as e:
|
||||
logger.error("Error in Mooncake sender thread: %s. Exiting thread.", str(e))
|
||||
finally:
|
||||
frontend.close()
|
||||
backend.close()
|
||||
|
||||
def _sender_worker(
|
||||
self, identity: bytes, metadata_bytes: bytes, worker_channel_path: str
|
||||
):
|
||||
status = TRANS_ERROR
|
||||
|
||||
try:
|
||||
metadata = self._decoder.decode(metadata_bytes)
|
||||
self.send_kv_to_decode(metadata)
|
||||
status = TRANS_DONE
|
||||
except Exception as e:
|
||||
logger.error("Error processing Mooncake handshake: %s", e)
|
||||
finally:
|
||||
pusher = make_zmq_socket(self.zmq_ctx, worker_channel_path, zmq.PUSH)
|
||||
try:
|
||||
pusher.send_multipart((identity, status))
|
||||
except zmq.ZMQError as e:
|
||||
logger.warning(
|
||||
"Internal error, maybe the server is shutting down. Error: %s",
|
||||
e,
|
||||
)
|
||||
finally:
|
||||
pusher.close()
|
||||
|
||||
def send_kv_to_decode(self, meta: MooncakeAgentMetadata):
|
||||
send_reqs: list[tuple[ReqId, SendBlockMeta]] = []
|
||||
with self.reqs_need_send.lock:
|
||||
for req_id in meta.request_ids:
|
||||
send_meta = self.reqs_need_send.reqs.get(req_id)
|
||||
if send_meta is None:
|
||||
logger.warning("Request %s not found in reqs_need_send", req_id)
|
||||
return
|
||||
# Mark it as not expired. We will send it now.
|
||||
send_meta.expire_time = float("inf")
|
||||
send_reqs.append((req_id, send_meta))
|
||||
|
||||
self._send_blocks(send_reqs, meta)
|
||||
|
||||
with self.reqs_need_send.lock:
|
||||
for req_id in meta.request_ids:
|
||||
del self.reqs_need_send.reqs[req_id]
|
||||
|
||||
with self.finished_sending_reqs.lock:
|
||||
self.finished_sending_reqs.set.update(meta.request_ids)
|
||||
|
||||
def _send_blocks(
|
||||
self,
|
||||
send_reqs: list[tuple[ReqId, SendBlockMeta]],
|
||||
agent_meta: MooncakeAgentMetadata,
|
||||
):
|
||||
src_ptrs = []
|
||||
dst_ptrs = []
|
||||
lengths = []
|
||||
local_base_addr = self.kv_caches_base_addr
|
||||
remote_base_addr = agent_meta.kv_caches_base_addr
|
||||
block_len = self.block_len
|
||||
remote_session = f"{agent_meta.remote_hostname}:{agent_meta.remote_port}"
|
||||
|
||||
assert len(send_reqs) == len(agent_meta.block_ids)
|
||||
for (req_id, send_meta), remote_block_ids in zip(
|
||||
send_reqs, agent_meta.block_ids
|
||||
):
|
||||
send_meta.ready.wait()
|
||||
|
||||
num_remote_blocks = len(remote_block_ids)
|
||||
if num_remote_blocks == 0:
|
||||
continue
|
||||
|
||||
local_block_ids = send_meta.local_block_ids
|
||||
# Partial prefix cache hit: just read uncomputed blocks.
|
||||
num_local_blocks = len(local_block_ids)
|
||||
assert num_local_blocks >= num_remote_blocks
|
||||
if num_local_blocks > num_remote_blocks:
|
||||
local_block_ids = local_block_ids[-num_remote_blocks:]
|
||||
|
||||
# Group by indices
|
||||
group_local_block_ids, group_remote_block_ids = group_concurrent_contiguous(
|
||||
local_block_ids, remote_block_ids
|
||||
)
|
||||
|
||||
for local_layer_addr, remote_layer_addr in zip(
|
||||
local_base_addr, remote_base_addr
|
||||
):
|
||||
for group_local_block_id, group_remote_block_id in zip(
|
||||
group_local_block_ids, group_remote_block_ids
|
||||
):
|
||||
src_ptrs.append(
|
||||
local_layer_addr + group_local_block_id[0] * block_len
|
||||
)
|
||||
dst_ptrs.append(
|
||||
remote_layer_addr + group_remote_block_id[0] * block_len
|
||||
)
|
||||
lengths.append(block_len * len(group_local_block_id))
|
||||
|
||||
logger.debug(
|
||||
"Sending kv_caches for request %s (%d blocks) to %s",
|
||||
req_id,
|
||||
num_remote_blocks,
|
||||
remote_session,
|
||||
)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
ret_value = self.engine.batch_transfer_sync_write(
|
||||
remote_session, src_ptrs, dst_ptrs, lengths
|
||||
)
|
||||
if ret_value != 0:
|
||||
raise RuntimeError(f"Error in batch_transfer_sync_write: {ret_value}")
|
||||
|
||||
logger.debug(
|
||||
"Sending to %s done, took %s",
|
||||
remote_session,
|
||||
time.perf_counter() - start_time,
|
||||
)
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
"""Register the KV Cache data in mooncake."""
|
||||
|
||||
logger.info("Registering KV_Caches. use_mla: %s", self.use_mla)
|
||||
|
||||
kv_data_ptrs = []
|
||||
kv_data_lens = []
|
||||
seen_base_addresses = []
|
||||
|
||||
split_k_and_v = self.kv_topo.split_k_and_v
|
||||
tensor_size_bytes = None
|
||||
for layer_name, cache_or_caches in kv_caches.items():
|
||||
logger.debug(
|
||||
"registering layer %s with shape %s", layer_name, cache_or_caches.shape
|
||||
)
|
||||
cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
|
||||
|
||||
for cache in cache_list:
|
||||
base_addr = cache.data_ptr()
|
||||
if base_addr in seen_base_addresses:
|
||||
continue
|
||||
|
||||
seen_base_addresses.append(base_addr)
|
||||
curr_tensor_size_bytes = cache.nbytes
|
||||
|
||||
if tensor_size_bytes is None:
|
||||
tensor_size_bytes = curr_tensor_size_bytes
|
||||
self.num_blocks = cache.shape[0]
|
||||
|
||||
assert tensor_size_bytes == curr_tensor_size_bytes, (
|
||||
"All kv cache tensors must have the same size"
|
||||
)
|
||||
kernel_block_size = cache.shape[-2 if self.use_mla else -3]
|
||||
assert self.block_size == kernel_block_size
|
||||
kv_data_ptrs.append(base_addr)
|
||||
kv_data_lens.append(tensor_size_bytes)
|
||||
|
||||
self.kv_caches_base_addr = seen_base_addresses
|
||||
|
||||
ret_value = self.engine.batch_register_memory(kv_data_ptrs, kv_data_lens)
|
||||
if ret_value != 0:
|
||||
raise RuntimeError("Mooncake batch memory registration failed.")
|
||||
|
||||
assert tensor_size_bytes is not None
|
||||
assert self.num_blocks != 0
|
||||
assert tensor_size_bytes % self.num_blocks == 0
|
||||
self.block_len = tensor_size_bytes // self.num_blocks
|
||||
self.device_kv_caches = kv_caches
|
||||
logger.debug(
|
||||
"registered num_blocks=%d block_len=%d", self.num_blocks, self.block_len
|
||||
)
|
||||
|
||||
# No need to launch server for D node.
|
||||
if self.kv_role == "kv_consumer":
|
||||
return
|
||||
|
||||
ready_event = threading.Event()
|
||||
self._mooncake_sender_t = threading.Thread(
|
||||
target=self._mooncake_sender,
|
||||
args=(ready_event, self.side_channel_port, self.tp_rank),
|
||||
daemon=True,
|
||||
name="mooncake_sender",
|
||||
)
|
||||
self._mooncake_sender_t.start()
|
||||
ready_event.wait() # Wait for listener ZMQ socket to be ready.
|
||||
|
||||
async def fetch_finished_recving_reqs(self) -> set[ReqId]:
|
||||
async with self.finished_recving_reqs.lock:
|
||||
finished_recving_reqs = self.finished_recving_reqs.set
|
||||
self.finished_recving_reqs.set = set()
|
||||
return finished_recving_reqs
|
||||
|
||||
def get_finished(self) -> tuple[set[str] | None, set[str] | None]:
|
||||
"""
|
||||
Get requests that are done sending or recving on this specific worker.
|
||||
The scheduler process (via the MultiprocExecutor) will use this output
|
||||
to track which workers are done.
|
||||
"""
|
||||
fut = None
|
||||
if self.kv_role != "kv_producer":
|
||||
fut = asyncio.run_coroutine_threadsafe(
|
||||
self.fetch_finished_recving_reqs(), self.receiver_loop
|
||||
)
|
||||
|
||||
if self.kv_role != "kv_consumer":
|
||||
with self.finished_sending_reqs.lock:
|
||||
finished_sending_reqs = self.finished_sending_reqs.set
|
||||
self.finished_sending_reqs.set = set()
|
||||
else:
|
||||
finished_sending_reqs = set()
|
||||
|
||||
finished_recving_reqs = fut.result() if fut else set()
|
||||
|
||||
if finished_sending_reqs or finished_recving_reqs:
|
||||
logger.debug(
|
||||
"Rank %s, get_finished: %s requests done sending "
|
||||
"and %s requests done recving",
|
||||
self.tp_rank,
|
||||
len(finished_sending_reqs),
|
||||
len(finished_recving_reqs),
|
||||
)
|
||||
|
||||
# Handle timeout to avoid stranding blocks on remote.
|
||||
now = time.perf_counter()
|
||||
with self.reqs_need_send.lock:
|
||||
expired_reqs = [
|
||||
req_id
|
||||
for req_id, send_meta in self.reqs_need_send.reqs.items()
|
||||
if send_meta.expire_time < now
|
||||
]
|
||||
for req_id in expired_reqs:
|
||||
logger.warning(
|
||||
"Request %s timed out after %d seconds without "
|
||||
"being sent. Freeing its blocks on the producer side.",
|
||||
req_id,
|
||||
envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT,
|
||||
)
|
||||
del self.reqs_need_send.reqs[req_id]
|
||||
if expired_reqs:
|
||||
finished_sending_reqs.update(expired_reqs)
|
||||
|
||||
return finished_sending_reqs or None, finished_recving_reqs or None
|
||||
|
||||
async def receive_kv(self, path: str, req_blocks: list[tuple[str, list[int]]]):
|
||||
req_ids, block_ids = map(list, zip(*req_blocks))
|
||||
metadata = MooncakeAgentMetadata(
|
||||
remote_hostname=self.hostname,
|
||||
remote_port=self.rpc_port,
|
||||
request_ids=req_ids,
|
||||
kv_caches_base_addr=self.kv_caches_base_addr,
|
||||
block_ids=block_ids,
|
||||
)
|
||||
|
||||
encoded_data = self._encoder.encode(metadata)
|
||||
logger.debug(
|
||||
"Size of encoded MooncakeAgentMetadata: %d bytes", len(encoded_data)
|
||||
)
|
||||
logger.debug("Sending kv transfer request for %s on path: %s", req_ids, path)
|
||||
|
||||
# Send query for the request.
|
||||
sock: zmq.asyncio.Socket = make_zmq_socket(
|
||||
self.async_zmq_ctx, path, zmq.REQ, bind=False, linger=0
|
||||
)
|
||||
sock.setsockopt(zmq.RCVTIMEO, 60000)
|
||||
try:
|
||||
await sock.send(encoded_data)
|
||||
ret_msg = await sock.recv()
|
||||
if ret_msg != TRANS_DONE:
|
||||
logger.error(
|
||||
"Error happens during tranfering kvcache for %s, see logs in prefiller.", # noqa: E501
|
||||
req_ids,
|
||||
)
|
||||
return
|
||||
except zmq.ContextTerminated:
|
||||
logger.debug("ZMQ context terminated, exiting Mooncake receiver thread.")
|
||||
except Exception as e:
|
||||
logger.error("MooncakeAgentMetadata transfer failed for %s: %s", req_ids, e)
|
||||
return
|
||||
finally:
|
||||
sock.close()
|
||||
|
||||
async with self.finished_recving_reqs.lock:
|
||||
self.finished_recving_reqs.set.update(req_ids)
|
||||
|
||||
logger.debug("pulling kv_caches for %s finished", req_ids)
|
||||
|
||||
def group_kv_pull(self, metadata: MooncakeConnectorMetadata):
|
||||
kv_pulls = defaultdict(list)
|
||||
for req_id, meta in metadata.reqs_to_recv.items():
|
||||
logger.debug(
|
||||
"start_load_kv for request %s from remote engine. "
|
||||
"Num local_block_ids: %s.",
|
||||
req_id,
|
||||
len(meta.local_block_ids),
|
||||
)
|
||||
path = make_zmq_path(
|
||||
"tcp", meta.remote_host, meta.remote_port + self.tp_rank
|
||||
)
|
||||
kv_pulls[path].append((req_id, meta.local_block_ids))
|
||||
|
||||
return kv_pulls
|
||||
|
||||
def start_load_kv(self, metadata: MooncakeConnectorMetadata):
|
||||
if self.kv_role != "kv_producer":
|
||||
kv_pulls = self.group_kv_pull(metadata)
|
||||
for path, req_blocks in kv_pulls.items():
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.receive_kv(path, req_blocks), self.receiver_loop
|
||||
)
|
||||
|
||||
if self.kv_role != "kv_consumer":
|
||||
with self.reqs_need_send.lock:
|
||||
for req_id, block_ids in metadata.reqs_to_send.items():
|
||||
if block_ids:
|
||||
# Already gone through request_finished()
|
||||
send_meta = self.reqs_need_send.reqs[req_id]
|
||||
send_meta.local_block_ids = block_ids
|
||||
send_meta.ready.set()
|
||||
send_meta.expire_time = (
|
||||
time.perf_counter()
|
||||
+ envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT
|
||||
)
|
||||
else:
|
||||
# From update_state_after_alloc(),
|
||||
# but not reach request_finished() yet
|
||||
self.reqs_need_send.reqs[req_id] = SendBlockMeta(
|
||||
local_block_ids=[], ready=threading.Event()
|
||||
)
|
||||
|
||||
|
||||
def group_concurrent_contiguous(
|
||||
src_indices: list[int], dst_indices: list[int]
|
||||
) -> tuple[list[list[int]], list[list[int]]]:
|
||||
"""Vectorised NumPy implementation."""
|
||||
if len(src_indices) == 0:
|
||||
return [], []
|
||||
|
||||
brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
|
||||
src_groups = np.split(src_indices, brk)
|
||||
dst_groups = np.split(dst_indices, brk)
|
||||
|
||||
src_groups = [g.tolist() for g in src_groups]
|
||||
dst_groups = [g.tolist() for g in dst_groups]
|
||||
|
||||
return src_groups, dst_groups
|
||||
|
||||
|
||||
def get_mooncake_side_channel_port(vllm_config: VllmConfig) -> int:
|
||||
# This logic is now centralized
|
||||
return (
|
||||
envs.VLLM_MOONCAKE_BOOTSTRAP_PORT
|
||||
+ vllm_config.parallel_config.data_parallel_rank
|
||||
* vllm_config.parallel_config.tensor_parallel_size
|
||||
)
|
||||
@@ -20,10 +20,10 @@ import torch
|
||||
import zmq
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
CopyBlocksOp,
|
||||
KVConnectorBase_V1,
|
||||
@@ -668,128 +668,6 @@ class NixlConnectorScheduler:
|
||||
class NixlConnectorWorker:
|
||||
"""Implementation of Worker side methods"""
|
||||
|
||||
@dataclass
|
||||
class TpKVTopology:
|
||||
"""
|
||||
Helper class for tensor parallel and KV topology information for
|
||||
mapping between local and remote TP workers.
|
||||
"""
|
||||
|
||||
tp_rank: int
|
||||
remote_tp_size: dict[EngineId, int]
|
||||
is_mla: bool
|
||||
total_num_kv_heads: int
|
||||
attn_backend: type[AttentionBackend]
|
||||
engine_id: EngineId
|
||||
remote_block_size: dict[EngineId, int]
|
||||
|
||||
def __post_init__(self):
|
||||
# Figure out whether the first dimension of the cache is K/V
|
||||
# or num_blocks. This is used to register the memory regions correctly.
|
||||
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
|
||||
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
|
||||
)
|
||||
# Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D],
|
||||
# we just mock num_blocks to 1 for the dimension check below.
|
||||
self._is_kv_layout_blocks_first = (
|
||||
len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
|
||||
)
|
||||
|
||||
attn_backend = AttentionBackendEnum[self.attn_backend.get_name()]
|
||||
self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS
|
||||
|
||||
@property
|
||||
def is_kv_layout_blocks_first(self) -> bool:
|
||||
return self._is_kv_layout_blocks_first
|
||||
|
||||
@property
|
||||
def split_k_and_v(self) -> bool:
|
||||
# Whether to register regions for K and V separately (when present).
|
||||
return not (
|
||||
self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first
|
||||
)
|
||||
|
||||
@property
|
||||
def tp_size(self) -> int:
|
||||
return self.remote_tp_size[self.engine_id]
|
||||
|
||||
@property
|
||||
def block_size(self) -> int:
|
||||
return self.remote_block_size[self.engine_id]
|
||||
|
||||
def tp_ratio(
|
||||
self,
|
||||
remote_tp_size: int,
|
||||
) -> int:
|
||||
"""
|
||||
Calculate the tensor parallel ratio between local and remote TP.
|
||||
We can think of it as the number of local TP workers-per-remote TP
|
||||
workers. Local workers will read from the same remote TP worker in
|
||||
groups of size `tp_ratio`.
|
||||
"""
|
||||
assert self.tp_size % remote_tp_size == 0, (
|
||||
f"Local tensor parallel size {self.tp_size} is not divisible "
|
||||
f"by remote tensor parallel size {remote_tp_size}."
|
||||
)
|
||||
return self.tp_size // remote_tp_size
|
||||
|
||||
def block_size_ratio(
|
||||
self,
|
||||
remote_block_size: int,
|
||||
) -> float:
|
||||
"""
|
||||
Calculate the block size ratio between local and remote TP.
|
||||
"""
|
||||
assert self.block_size % remote_block_size == 0, (
|
||||
f"Local block size {self.block_size} is not divisible "
|
||||
f"by remote block size {remote_block_size} or vice versa."
|
||||
)
|
||||
return self.block_size // remote_block_size
|
||||
|
||||
def tp_ratio_from_engine_id(
|
||||
self,
|
||||
remote_engine_id: EngineId,
|
||||
) -> int:
|
||||
remote_tp_size = self.remote_tp_size[remote_engine_id]
|
||||
return self.tp_ratio(remote_tp_size)
|
||||
|
||||
def block_size_ratio_from_engine_id(
|
||||
self,
|
||||
remote_engine_id: EngineId,
|
||||
) -> float:
|
||||
remote_block_size = self.remote_block_size[remote_engine_id]
|
||||
return self.block_size_ratio(remote_block_size)
|
||||
|
||||
def is_kv_replicated(self, engine_id: EngineId) -> bool:
|
||||
"""
|
||||
Whether the KV cache is replicated across TP workers due to the
|
||||
number of TP workers being greater than the number of KV heads.
|
||||
"""
|
||||
tp_size = self.remote_tp_size[engine_id]
|
||||
return tp_size // self.total_num_kv_heads >= 1
|
||||
|
||||
def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
|
||||
# MLA is always replicated as the hidden dim can't be split.
|
||||
return self.is_mla or self.is_kv_replicated(remote_engine_id)
|
||||
|
||||
def get_target_remote_rank(
|
||||
self,
|
||||
remote_tp_size: int,
|
||||
) -> int:
|
||||
"""
|
||||
Get the remote TP rank (on P) that the current local TP rank
|
||||
(on D) will read from.
|
||||
"""
|
||||
tp_ratio = self.tp_ratio(remote_tp_size)
|
||||
return self.tp_rank // tp_ratio
|
||||
|
||||
def get_target_remote_rank_from_engine_id(
|
||||
self,
|
||||
remote_engine_id: EngineId,
|
||||
) -> int:
|
||||
remote_tp_size = self.remote_tp_size[remote_engine_id]
|
||||
return self.get_target_remote_rank(remote_tp_size)
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, engine_id: str):
|
||||
if NixlWrapper is None:
|
||||
logger.error("NIXL is not available")
|
||||
@@ -958,7 +836,7 @@ class NixlConnectorWorker:
|
||||
self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
|
||||
self.xfer_stats = NixlKVConnectorStats()
|
||||
|
||||
self.kv_topo = self.TpKVTopology(
|
||||
self.kv_topo = TpKVTopology(
|
||||
tp_rank=self.tp_rank,
|
||||
engine_id=self.engine_id,
|
||||
remote_tp_size=self._tp_size, # shared state
|
||||
|
||||
vllm/envs.py (10 lines)
@@ -175,6 +175,7 @@ if TYPE_CHECKING:
    VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
    VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
    VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5600
    VLLM_MOONCAKE_BOOTSTRAP_PORT: int = 8998
    VLLM_ALL2ALL_BACKEND: Literal[
        "naive",
        "pplx",
@@ -197,6 +198,7 @@ if TYPE_CHECKING:
    VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
    VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None
    VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480
    VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT: int = 480
    VLLM_USE_CUDNN_PREFILL: bool = False
    VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False
    VLLM_ENABLE_CUDAGRAPH_GC: bool = False
@@ -1260,6 +1262,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_NIXL_SIDE_CHANNEL_PORT": lambda: int(
        os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5600")
    ),
    # Port used for Mooncake handshake between remote agents.
    "VLLM_MOONCAKE_BOOTSTRAP_PORT": lambda: int(
        os.getenv("VLLM_MOONCAKE_BOOTSTRAP_PORT", "8998")
    ),
    # all2all backend for vllm's expert parallel communication
    # Available options:
    # - "naive": naive all2all implementation using broadcasts
@@ -1369,6 +1375,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": lambda: int(
        os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "480")
    ),
    # Timeout (in seconds) for MooncakeConnector in PD disaggregated setup.
    "VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT": lambda: int(
        os.getenv("VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT", "480")
    ),
    # Controls whether or not to use cudnn prefill
    "VLLM_USE_CUDNN_PREFILL": lambda: bool(
        int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))