# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import time
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable
from functools import cached_property
from typing import Any

from typing_extensions import TypeVar

import vllm.platforms
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest
from vllm.tasks import SupportedTask
from vllm.utils import make_async
from vllm.v1.outputs import SamplerOutput
from vllm.v1.worker.worker_base import WorkerBase

logger = init_logger(__name__)

_R = TypeVar("_R", default=Any)


class ExecutorBase(ABC):
    """Base class for all executors.

    An executor is responsible for executing the model on a single device,
    or it can be a distributed executor that executes the model across
    multiple devices.
    """

    uses_ray: bool  # whether the executor uses Ray for orchestration.
    supports_pp: bool = False  # whether the executor supports pipeline parallelism

    def __init__(
        self,
        vllm_config: VllmConfig,
    ) -> None:
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
        self.observability_config = vllm_config.observability_config
        self._init_executor()
        self.is_sleeping = False
        self.sleeping_tags: set[str] = set()
        self.kv_output_aggregator: KVOutputAggregator | None = None

    @abstractmethod
    def _init_executor(self) -> None:
        raise NotImplementedError
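    # NOTE: A concrete executor only has to provide `_init_executor`,
    # `collective_rpc`, and `check_health`; the rest of this base class is
    # built on top of `collective_rpc`. Minimal single-process sketch
    # (illustrative only; `create_worker` is a hypothetical helper):
    #
    #     class InlineExecutor(ExecutorBase):
    #         uses_ray = False
    #
    #         def _init_executor(self) -> None:
    #             self.worker = create_worker(self.vllm_config)
    #
    #         def collective_rpc(self, method, timeout=None, args=(), kwargs=None):
    #             kwargs = kwargs or {}
    #             if isinstance(method, str):
    #                 return [getattr(self.worker, method)(*args, **kwargs)]
    #             return [method(self.worker, *args, **kwargs)]
    #
    #         def check_health(self) -> None:
    #             return  # single in-process worker, nothing to check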

    @abstractmethod
    def collective_rpc(
        self,
        method: str | Callable[[WorkerBase], _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """
        Execute an RPC call on all workers.

        Args:
            method: Name of the worker method to execute, or a callable that
                is serialized and sent to all workers to execute.

                If the method is a callable, it should accept an additional
                `self` argument, in addition to the arguments passed in `args`
                and `kwargs`. The `self` argument will be the worker object.
            timeout: Maximum time in seconds to wait for execution. Raises a
                [`TimeoutError`][] on timeout. `None` means wait indefinitely.
            args: Positional arguments to pass to the worker method.
            kwargs: Keyword arguments to pass to the worker method.

        Returns:
            A list containing the results from each worker.

        Note:
            It is recommended to use this API to only pass control messages,
            and set up data-plane communication to pass data.
        """
        raise NotImplementedError
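    # Example (illustrative): `executor` stands for any concrete subclass
    # instance. The target can be addressed by worker method name, or as a
    # callable that receives the worker object as its first argument.
    #
    #     # by name: calls worker.get_supported_tasks() on every worker
    #     tasks_per_worker = executor.collective_rpc("get_supported_tasks")
    #
    #     # by callable: `worker` is the worker object on each worker
    #     # (the attribute accessed here is only an example)
    #     def get_rank(worker):
    #         return worker.rank
    #     ranks = executor.collective_rpc(get_rank)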

    def determine_num_available_blocks(self) -> tuple[int, int]:
        """Determine the number of available blocks for the GPU KV cache and
        swappable CPU KV cache.

        Normally, this should simply delegate to the underlying Worker. Some
        ExecutorBase may require modification of the result, e.g. to ensure the
        selected cache sizes are compatible with all workers.

        Returns a tuple `(num_gpu_blocks, num_cpu_blocks)`, where
        `num_gpu_blocks` are blocks that are "active" on the device and can be
        appended to.
        `num_cpu_blocks` refers to "swapped" blocks in CPU memory and cannot be
        appended to.
        """
        results = self.collective_rpc("determine_num_available_blocks")
        # Take the minimum across workers so the selected cache size fits on
        # every worker.
        num_gpu_blocks = min(r[0] for r in results)
        num_cpu_blocks = min(r[1] for r in results)
        return num_gpu_blocks, num_cpu_blocks
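    # Example (illustrative): if two workers report (800, 2048) and (768, 2048)
    # blocks respectively, the reduction above yields (768, 2048), i.e. the
    # minimum in each position, so the cache fits the most constrained worker.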

    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        """Initialize the KV cache by invoking the underlying worker."""
        # NOTE: This is logged in the executor because there can be >1 workers.
        logger.info(
            "# %s blocks: %d, # CPU blocks: %d",
            vllm.platforms.current_platform.device_name,
            num_gpu_blocks,
            num_cpu_blocks,
        )
        max_concurrency = (
            num_gpu_blocks
            * self.cache_config.block_size
            / self.model_config.max_model_len
        )
        logger.info(
            "Maximum concurrency for %s tokens per request: %.2fx",
            self.model_config.max_model_len,
            max_concurrency,
        )

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))
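    # Example (illustrative arithmetic): with 8192 GPU blocks, a block size of
    # 16 tokens, and max_model_len = 4096, max_concurrency = 8192 * 16 / 4096
    # = 32.0, i.e. roughly 32 maximum-length requests can have their KV cache
    # resident at once.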

    @cached_property  # Avoid unnecessary RPC calls
    def supported_tasks(self) -> tuple[SupportedTask, ...]:
        output = self.collective_rpc("get_supported_tasks")
        return output[0]

    def execute_model(
        self, execute_model_req: ExecuteModelRequest
    ) -> list[SamplerOutput]:
        output = self.collective_rpc("execute_model", args=(execute_model_req,))
        assert output[0] is not None
        return output[0]

    def stop_remote_worker_execution_loop(self) -> None:
        """Releases parallel workers from model loop."""
        return

    def add_lora(self, lora_request: LoRARequest) -> bool:
        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
        return all(self.collective_rpc("add_lora", args=(lora_request,)))

    def remove_lora(self, lora_id: int) -> bool:
        assert lora_id > 0, "lora_id must be greater than 0."
        return all(self.collective_rpc("remove_lora", args=(lora_id,)))

    def pin_lora(self, lora_id: int) -> bool:
        assert lora_id > 0, "lora_id must be greater than 0."
        return all(self.collective_rpc("pin_lora", args=(lora_id,)))

    def list_loras(self) -> set[int]:
        sets = self.collective_rpc("list_loras")
        for s in sets:
            assert s == sets[0], "All workers should have the same LoRAs."
        return sets[0]
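    # Example (illustrative): each LoRA call fans out to every worker via
    # `collective_rpc` and reports success only if all workers succeed.
    # The request fields shown are only an example.
    #
    #     req = LoRARequest(lora_name="my-adapter", lora_int_id=1,
    #                       lora_path="/path/to/adapter")
    #     assert executor.add_lora(req)
    #     assert 1 in executor.list_loras()
    #     executor.pin_lora(1)     # keep the adapter resident
    #     executor.remove_lora(1)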

    def reset_mm_cache(self) -> None:
        """Reset the multi-modal cache in each worker."""
        self.collective_rpc("reset_mm_cache")

    def start_profile(self) -> None:
        self.collective_rpc("start_profile")

    def stop_profile(self) -> None:
        self.collective_rpc("stop_profile")

    def sleep(self, level: int = 1):
        if self.is_sleeping:
            logger.warning("Executor is already sleeping.")
            return
        time_before_sleep = time.perf_counter()
        self.collective_rpc("sleep", kwargs=dict(level=level))
        time_after_sleep = time.perf_counter()
        self.sleeping_tags = {"weights", "kv_cache"}
        self.is_sleeping = True
        logger.info(
            "It took %.6f seconds to fall asleep.", time_after_sleep - time_before_sleep
        )

    def wake_up(self, tags: list[str] | None = None):
        if not self.is_sleeping:
            logger.warning("Executor is not sleeping.")
            return
        if tags:
            for tag in tags:
                if tag not in self.sleeping_tags:
                    logger.warning(
                        "Tag %s is not in sleeping tags %s", tag, self.sleeping_tags
                    )
                    return
        time_before_wakeup = time.perf_counter()
        self.collective_rpc("wake_up", kwargs=dict(tags=tags))
        time_after_wakeup = time.perf_counter()
        logger.info(
            "It took %.6f seconds to wake up tags %s.",
            time_after_wakeup - time_before_wakeup,
            tags if tags is not None else self.sleeping_tags,
        )
        if tags:
            for tag in tags:
                self.sleeping_tags.remove(tag)
        else:
            self.sleeping_tags.clear()
        if not self.sleeping_tags:
            self.is_sleeping = False
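    # Example (illustrative): `sleep` releases device state for the "weights"
    # and "kv_cache" tags; `wake_up` restores tags selectively or all at once.
    #
    #     executor.sleep(level=1)
    #     executor.wake_up(tags=["weights"])  # only the weights tag is restored
    #     executor.wake_up()                  # restore all remaining sleeping tags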

    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        self.collective_rpc(
            "save_sharded_state",
            kwargs=dict(path=path, pattern=pattern, max_size=max_size),
        )

    @abstractmethod
    def check_health(self) -> None:
        """Checks if the executor is healthy. If not, it should raise an
        exception."""
        raise NotImplementedError

    def shutdown(self) -> None:
        """Shutdown the executor."""
        self.collective_rpc("shutdown")

    async def execute_model_async(
        self, execute_model_req: ExecuteModelRequest
    ) -> list[SamplerOutput]:
        """Executes one model step on the given sequences."""
        output = await make_async(self.execute_model)(execute_model_req)
        return output

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Releases parallel workers from model loop."""
        return

    async def check_health_async(self) -> None:
        """Checks if the executor is healthy. If not, it should raise an
        exception."""
        self.check_health()

    def init_kv_output_aggregator(self, finished_count: int | None) -> None:
        """Initialize the KVOutputAggregator."""
        self.kv_output_aggregator = KVOutputAggregator(
            finished_count or self.parallel_config.world_size
        )


class DistributedExecutorBase(ExecutorBase):
    """Abstract superclass of distributed executor implementations."""

    def __init__(self, *args, **kwargs):
        # This is non-None when the execute model loop is running
        # in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
        self.parallel_worker_tasks: Any | Awaitable[Any] | None = None

        super().__init__(*args, **kwargs)

    def execute_model(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> list[SamplerOutput]:
        # TODO: unify into collective_rpc
        if self.parallel_worker_tasks is None:
            self.parallel_worker_tasks = self._run_workers(
                "start_worker_execution_loop",
                async_run_tensor_parallel_workers_only=True,
            )

        # Only the driver worker returns the sampling results.
        driver_outputs = self._driver_execute_model(execute_model_req)
        assert driver_outputs is not None
        return driver_outputs

    def stop_remote_worker_execution_loop(self) -> None:
        if self.parallel_worker_tasks is None:
            return

        self._driver_execute_model(execute_model_req=None)
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        self._wait_for_tasks_completion(parallel_worker_tasks)

    @abstractmethod
    def _driver_execute_model(
        self, execute_model_req: ExecuteModelRequest | None
    ) -> list[SamplerOutput] | None:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution loop
        running in each of the remote workers. In this case, this method
        returns None. Otherwise, this method returns the model output.
        """
        raise NotImplementedError

    def collective_rpc(
        self,
        method: str | Callable,
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[Any]:
        return self._run_workers(method, *args, **(kwargs or {}))

    @abstractmethod
    def _run_workers(
        self,
        method: str | Callable,
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        max_concurrent_workers: int | None = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers.

        Args:
            async_run_tensor_parallel_workers_only: If True, the method will be
                run only in the remote TP workers, not the driver worker.
                It will also be run asynchronously and return a list of futures
                rather than blocking on the results.

        # TODO: simplify and merge with collective_rpc
        """
        raise NotImplementedError
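    # NOTE: a minimal, synchronous sketch of how a subclass might implement
    # `_run_workers` over an in-process worker list (illustrative only;
    # `self.workers` is hypothetical, and real subclasses dispatch via Ray or
    # multiprocessing and return futures when
    # async_run_tensor_parallel_workers_only=True):
    #
    #     def _run_workers(self, method, *args,
    #                      async_run_tensor_parallel_workers_only=False,
    #                      max_concurrent_workers=None, **kwargs):
    #         results = []
    #         for worker in self.workers:
    #             if isinstance(method, str):
    #                 results.append(getattr(worker, method)(*args, **kwargs))
    #             else:
    #                 results.append(method(worker, *args, **kwargs))
    #         return results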

    @abstractmethod
    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_tensor_parallel_workers_only=True to complete."""
        raise NotImplementedError

    async def execute_model_async(
        self, execute_model_req: ExecuteModelRequest
    ) -> list[SamplerOutput]:
        if self.parallel_worker_tasks is None:
            # Start model execution loop running in the parallel workers
            self.parallel_worker_tasks = asyncio.create_task(
                self._start_worker_execution_loop()
            )

        # Only the driver worker returns the sampling results.
        return await self._driver_execute_model_async(execute_model_req)

    async def stop_remote_worker_execution_loop_async(self) -> None:
        if self.parallel_worker_tasks is None:
            return

        await self._driver_execute_model_async()
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        await parallel_worker_tasks

    @abstractmethod
    async def _driver_execute_model_async(
        self,
        execute_model_req: ExecuteModelRequest | None = None,
    ) -> list[SamplerOutput]:
        """Execute the model asynchronously in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        raise NotImplementedError

    @abstractmethod
    async def _start_worker_execution_loop(self):
        """Run the execution loop on all workers. It guarantees that either all
        workers are running the loop or none of them are. The loop can be
        stopped by `stop_remote_worker_execution_loop`.
        The API is idempotent (it guarantees that only one loop runs at any
        moment)."""
        raise NotImplementedError