[DisaggEverything] Tokens in<>out /generate endpoint (#24261)
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
examples/online_serving/token_generation_client.py (new file, 49 lines)
@@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import httpx
from transformers import AutoTokenizer

GEN_ENDPOINT = "http://localhost:8000/inference/v1/generate"
DUMMY_API_KEY = "empty"
MODEL_NAME = "Qwen/Qwen3-0.6B"

transport = httpx.HTTPTransport()
headers = {"Authorization": f"Bearer {DUMMY_API_KEY}"}
client = httpx.Client(
    transport=transport,
    base_url=GEN_ENDPOINT,
    timeout=600,
    headers=headers,
)
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "How many countries are in the EU?"},
]


def main(client):
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    token_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    payload = {
        "model": MODEL_NAME,
        "token_ids": token_ids,
        "sampling_params": {"max_tokens": 24, "temperature": 0.2, "detokenize": False},
        "stream": False,
    }
    resp = client.post(GEN_ENDPOINT, json=payload)
    resp.raise_for_status()
    data = resp.json()
    print(data)
    print("-" * 50)
    print("Token generation results:")
    res = tokenizer.decode(data["choices"][0]["token_ids"])
    print(res)
    print("-" * 50)


if __name__ == "__main__":
    main(client)
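Since the example sets "detokenize": False and decodes the returned token ids itself, stop-string handling also has to happen on the caller's side. A minimal sketch of that step, reusing check_stop_strings the same way tests/entrypoints/openai/test_serving_tokens.py below does (the helper name is illustrative, not part of this commit):

# Client-side stop-string handling: with detokenize=False the server rejects
# "stop" in sampling_params, so the caller decodes and truncates itself.
from vllm.v1.engine.detokenizer import check_stop_strings


def apply_stop_strings(tokenizer, token_ids, stop_strings):
    text = tokenizer.decode(token_ids, skip_special_tokens=True)
    # Returns (stop_str, truncate_to) when a stop string was hit, else None.
    hit = check_stop_strings(text, len(text), stop_strings, False)
    if hit is None:
        return text, None
    stop_str, truncate_to = hit
    return text[:truncate_to], stop_str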
@@ -10,3 +10,7 @@ mkdocs-minify-plugin
regex
ruff
pydantic

# For generating argparse docs.
# Adding requirements here should only be used as a last resort.
msgspec # Need for multiple inheritance involving msgspec.Struct
tests/entrypoints/openai/test_serving_tokens.py (new file, 262 lines)
@@ -0,0 +1,262 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import httpx
import pytest
import pytest_asyncio
from transformers import AutoTokenizer

from vllm.config import ModelConfig
from vllm.v1.engine.detokenizer import check_stop_strings

from ...utils import RemoteOpenAIServer

MODEL_NAME = "Qwen/Qwen3-0.6B"
GEN_ENDPOINT = "/inference/v1/generate"


def get_vocab_size(model_name):
    config = ModelConfig(
        model=model_name,
        seed=0,
        dtype="bfloat16",
    )
    return config.get_vocab_size()


@pytest.fixture(scope="module")
def tokenizer():
    return AutoTokenizer.from_pretrained(MODEL_NAME)


@pytest.fixture(scope="module")
def messages():
    return [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "How many countries are in the EU?"},
    ]


@pytest.fixture(scope="module")
def server(request):
    args = [
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "1024",
        "--enforce-eager",
    ]

    extra_args = getattr(request, "param", None)
    if extra_args is not None:
        args = args + (
            list(extra_args)
            if isinstance(extra_args, (list, tuple))
            else [str(extra_args)]
        )

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server: RemoteOpenAIServer):
    transport = httpx.AsyncHTTPTransport(uds=server.uds) if server.uds else None
    headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"}
    async with httpx.AsyncClient(
        transport=transport,
        base_url=server.url_root,
        timeout=600,
        headers=headers,
    ) as c:
        yield c


@pytest.mark.asyncio
async def test_generate_endpoint(client):
    payload = {
        "model": MODEL_NAME,
        "token_ids": [1, 2, 3],
        "sampling_params": {"max_tokens": 5},
        "stream": False,
    }
    resp = await client.post(GEN_ENDPOINT, json=payload)
    resp.raise_for_status()
    data = resp.json()
    assert "choices" in data


@pytest.mark.asyncio
async def test_same_response_as_chat_completions(client, tokenizer, messages):
    token_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        enable_thinking=False,  # default with Qwen3
    )
    for ignore_eos in [True, False]:
        payload = {
            "model": MODEL_NAME,
            "token_ids": token_ids,
            "sampling_params": {
                "max_tokens": 24,
                "temperature": 0.0,
                # NOTE coordinator will set this to skip detokenization
                "detokenize": False,
                "ignore_eos": ignore_eos,
            },
            "stream": False,
        }
        generate_resp = await client.post(GEN_ENDPOINT, json=payload)
        generate_data = generate_resp.json()
        generate_res = tokenizer.decode(
            generate_data["choices"][0]["token_ids"], skip_special_tokens=True
        )

        payload = {
            "model": MODEL_NAME,
            "messages": messages,
            "max_tokens": 24,
            "temperature": 0.0,
            "stream": False,
            "ignore_eos": ignore_eos,
            "chat_template_kwargs": dict(enable_thinking=False),
        }
        completions_resp = await client.post("/v1/chat/completions", json=payload)
        completions_data = completions_resp.json()
        completions_res = completions_data["choices"][0]["message"]["content"]

        assert generate_res == completions_res


@pytest.mark.asyncio
async def test_stop_string_workflow(client, tokenizer, messages):
    token_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        enable_thinking=False,  # default with Qwen3
    )
    payload = {
        "model": MODEL_NAME,
        "token_ids": token_ids,
        "sampling_params": {
            "max_tokens": 24,
            "temperature": 0.0,
            "detokenize": False,
            # stop strings are only supported when detokenize is True.
            "stop": ["27 member"],
        },
        # TODO stream test is much more interesting
        "stream": False,
    }
    with pytest.raises(httpx.HTTPStatusError):
        generate_resp = await client.post(GEN_ENDPOINT, json=payload)
        generate_resp.raise_for_status()

    payload["sampling_params"]["stop"] = None
    generate_resp = await client.post(
        GEN_ENDPOINT, json=payload, headers={"X-Request-Id": "42"}
    )
    generate_data = generate_resp.json()
    generate_res = tokenizer.decode(
        generate_data["choices"][0]["token_ids"], skip_special_tokens=True
    )

    # NOTE This is under the responsibility of the coordinator
    # stop_checker = StopChecker(
    #     max_model_len=1024, get_tokenizer_for_seq=lambda _: tokenizer
    # )
    stop_str, truncate_to = check_stop_strings(
        generate_res, len(generate_res), ["27 member"], False
    )
    assert stop_str == "27 member"
    # abort request that hit stop string (requires tokens-only mode)
    # res = await client.post("/abort_requests", json={"request_ids": ["generate-tokens-42"]})  # noqa: E501
    # res.raise_for_status()
    generate_res = generate_res[:truncate_to]

    # Get stop_str response from chat completions
    payload = {
        "model": MODEL_NAME,
        "messages": messages,
        "max_tokens": 24,
        "temperature": 0.0,
        "stream": False,
        "stop": ["27 member"],
        "chat_template_kwargs": dict(enable_thinking=False),
    }
    completions_resp = await client.post("/v1/chat/completions", json=payload)
    completions_data = completions_resp.json()
    completions_res = completions_data["choices"][0]["message"]["content"]
    assert generate_res == completions_res


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "server",
    [
        [
            "--enable-lora",
            "--lora-modules",
            "Alice=charent/self_cognition_Alice",
            "Bob=charent/self_cognition_Bob",
            "--max-lora-rank",
            "64",
            "--max-cpu-loras",
            "2",
        ]
    ],
    indirect=True,
)
async def test_generate_with_lora_adapter(client, tokenizer, messages):
    # Verify adapters are listed
    models_resp = await client.get("/v1/models")
    models_resp.raise_for_status()
    models = {m["id"] for m in models_resp.json().get("data", [])}
    assert {"Alice", "Bob"}.issubset(models)

    # Generate using a LoRA adapter by specifying its name as the model
    payload = {
        "model": "Alice",
        "token_ids": [1, 2, 3],
        "sampling_params": {"max_tokens": 5},
        "stream": False,
    }
    resp = await client.post(GEN_ENDPOINT, json=payload)
    resp.raise_for_status()
    data = resp.json()
    assert "choices" in data

    token_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        enable_thinking=False,  # default with Qwen3
    )
    payload = {
        "model": "Alice",
        "token_ids": token_ids,
        "sampling_params": {
            "max_tokens": 24,
            "temperature": 0.0,
            "detokenize": False,
        },
        "stream": False,
    }
    generate_resp = await client.post(GEN_ENDPOINT, json=payload)
    generate_data = generate_resp.json()
    generate_res = tokenizer.decode(
        generate_data["choices"][0]["token_ids"], skip_special_tokens=True
    )

    payload = {
        "model": "Alice",
        "messages": messages,
        "max_tokens": 24,
        "temperature": 0.0,
        "stream": False,
        "chat_template_kwargs": dict(enable_thinking=False),
    }
    completions_resp = await client.post("/v1/chat/completions", json=payload)
    completions_data = completions_resp.json()
    completions_res = completions_data["choices"][0]["message"]["content"]

    assert generate_res == completions_res
@@ -566,6 +566,7 @@ class EngineArgs:
    kv_offloading_backend: KVOffloadingBackend | None = (
        CacheConfig.kv_offloading_backend
    )
    tokens_only: bool = False

    def __post_init__(self):
        # support `EngineArgs(compilation_config={...})`
@@ -1495,6 +1496,10 @@ class EngineArgs:
            else ParallelConfig.data_parallel_rpc_port
        )

        if self.tokens_only and not model_config.skip_tokenizer_init:
            model_config.skip_tokenizer_init = True
            logger.info("Skipping tokenizer initialization for tokens-only mode.")

        # Forward the deprecated CLI args to the EPLB config.
        if self.num_redundant_experts is not None:
            self.eplb_config.num_redundant_experts = self.num_redundant_experts
@@ -65,6 +65,8 @@ from vllm.entrypoints.openai.protocol import (
    EmbeddingResponse,
    ErrorInfo,
    ErrorResponse,
    GenerateRequest,
    GenerateResponse,
    IOProcessorResponse,
    PoolingBytesResponse,
    PoolingRequest,
@@ -96,6 +98,7 @@ from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
from vllm.entrypoints.openai.serving_score import ServingScores
from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
from vllm.entrypoints.openai.serving_tokens import ServingTokens
from vllm.entrypoints.openai.serving_transcription import (
    OpenAIServingTranscription,
    OpenAIServingTranslation,
@@ -357,6 +360,10 @@ def engine_client(request: Request) -> EngineClient:
    return request.app.state.engine_client


def generate_tokens(request: Request) -> ServingTokens | None:
    return request.app.state.serving_tokens


@router.get("/health", response_class=Response)
async def health(raw_request: Request) -> Response:
    """Health check."""
@@ -1228,6 +1235,41 @@ INVOCATION_VALIDATORS = [
]


@router.post(
    "/inference/v1/generate",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def generate(request: GenerateRequest, raw_request: Request):
    handler = generate_tokens(raw_request)
    if handler is None:
        return base(raw_request).create_error_response(
            message="The model does not support generate tokens API"
        )
    try:
        generator = await handler.serve_tokens(request, raw_request)
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e
    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )

    elif isinstance(generator, GenerateResponse):
        return JSONResponse(content=generator.model_dump())

    return StreamingResponse(content=generator, media_type="text/event-stream")


if envs.VLLM_TORCH_PROFILER_DIR:
    logger.warning_once(
        "Torch Profiler is enabled in the API server. This should ONLY be "
@@ -1629,6 +1671,31 @@ def build_app(args: Namespace) -> FastAPI:
    )

    app = sagemaker_standards.bootstrap(app)
    # Optional endpoints
    if args.tokens_only:

        @app.post("/abort_requests")
        async def abort_requests(raw_request: Request):
            """
            Abort one or more requests. To be used in a
            Disaggregated Everything setup.
            """
            try:
                body = await raw_request.json()
            except json.JSONDecodeError as e:
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST.value,
                    detail=f"JSON decode error: {e}",
                ) from e
            request_ids = body.get("request_ids")
            if request_ids is None:
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST.value,
                    detail="Missing 'request_ids' in request body",
                )
            # Abort requests in background
            asyncio.create_task(engine_client(raw_request).abort(request_ids))
            return Response(status_code=200)

    return app
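The abort path is only exercised as a commented-out call in test_stop_string_workflow; a client call might look like the sketch below, where the request id assumes the "generate-tokens-" prefix added by ServingTokens plus the X-Request-Id header value used in that test:

# Sketch: asking a tokens-only server to abort an in-flight request.
import httpx

with httpx.Client(base_url="http://localhost:8000") as client:
    resp = client.post(
        "/abort_requests", json={"request_ids": ["generate-tokens-42"]}
    )
    resp.raise_for_status()  # returns 200; the abort runs as a background task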
@@ -1851,6 +1918,20 @@ async def init_app_state(
        if "generate" in supported_tasks
        else None
    )
    state.serving_tokens = (
        ServingTokens(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            log_error_stack=args.log_error_stack,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_log_outputs=args.enable_log_outputs,
            force_no_detokenize=args.tokens_only,
        )
        if "generate" in supported_tasks
        else None
    )

    state.enable_server_load_tracking = args.enable_server_load_tracking
    state.server_load_metrics = 0
@@ -189,6 +189,11 @@ class FrontendArgs:
    Helps mitigate header abuse. Default: 256."""
    log_error_stack: bool = envs.VLLM_SERVER_DEV_MODE
    """If set to True, log the stack trace of error responses"""
    tokens_only: bool = False
    """
    If set to True, only enable the Tokens In<>Out endpoint.
    This is intended for use in a Disaggregated Everything setup.
    """

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
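For reference, a tokens-only server can be started the same way the tests above do it, assuming the new field surfaces on the CLI as --tokens-only (derived from the field name by FlexibleArgumentParser) and that the RemoteOpenAIServer test helper is importable from the repo root:

# Sketch: launching a server with tokens-only mode enabled (assumed CLI flag).
from tests.utils import RemoteOpenAIServer

args = ["--dtype", "bfloat16", "--max-model-len", "1024", "--enforce-eager",
        "--tokens-only"]
with RemoteOpenAIServer("Qwen/Qwen3-0.6B", args) as server:
    # /inference/v1/generate and /abort_requests are served, and the
    # tokenizer is not initialized (skip_tokenizer_init is forced on).
    print(server.url_root)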
@@ -3220,3 +3220,80 @@ class TranslationResponseVerbose(OpenAIBaseModel):
    words: list[TranslationWord] | None = None
    """Extracted words and their corresponding timestamps."""


####### Tokens IN <> Tokens OUT #######
class GenerateRequest(BaseModel):
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "throughout the inference process and returned in the response."
        ),
    )
    token_ids: list[int]
    """The token ids to generate text from."""

    # features: MultiModalFeatureSpec
    # TODO (NickLucche): implement once Renderer work is completed
    features: str | None = None
    """The processed MM inputs for the model."""

    sampling_params: SamplingParams
    """The sampling parameters for the model."""

    model: str | None = None

    stream: bool | None = False
    stream_options: StreamOptions | None = None
    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit)."
        ),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )


class GenerateResponseChoice(BaseModel):
    index: int
    logprobs: ChatCompletionLogProbs | None = None
    # per OpenAI spec this is the default
    finish_reason: str | None = "stop"
    token_ids: list[int] | None = None


class GenerateResponse(BaseModel):
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "throughout the inference process and returned in the response."
        ),
    )
    choices: list[GenerateResponseChoice]

    prompt_logprobs: list[dict[int, Logprob] | None] | None = None

    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )
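Because SamplingParams gains PydanticMsgspecMixin in this commit (see the sampling_params.py and serial_utils.py hunks below), the sampling_params field above validates straight from a plain mapping; a small sketch:

# Sketch: building a GenerateRequest from JSON-like data.
from vllm.entrypoints.openai.protocol import GenerateRequest

req = GenerateRequest.model_validate(
    {
        "model": "Qwen/Qwen3-0.6B",
        "token_ids": [1, 2, 3],
        "sampling_params": {"max_tokens": 5, "detokenize": False},
    }
)
assert req.sampling_params.max_tokens == 5  # dict converted to SamplingParams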
@@ -58,6 +58,8 @@ from vllm.entrypoints.openai.protocol import (
    ErrorResponse,
    FunctionCall,
    FunctionDefinition,
    GenerateRequest,
    GenerateResponse,
    IOProcessorRequest,
    PoolingResponse,
    RerankRequest,
@@ -134,6 +136,7 @@ AnyRequest: TypeAlias = (
    | SpeechToTextRequest
    | ResponsesRequest
    | IOProcessorRequest
    | GenerateRequest
)

AnyResponse: TypeAlias = (
@@ -145,6 +148,7 @@ AnyResponse: TypeAlias = (
    | PoolingResponse
    | ClassificationResponse
    | ScoreResponse
    | GenerateResponse
)
vllm/entrypoints/openai/serving_tokens.py (new file, 269 lines)
@@ -0,0 +1,269 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import time
from collections.abc import AsyncGenerator
from collections.abc import Sequence as GenericSequence

from fastapi import Request

# yapf: disable
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
    ChatCompletionLogProb,
    ChatCompletionLogProbs,
    ChatCompletionLogProbsContent,
    ErrorResponse,
    GenerateRequest,
    GenerateResponse,
    GenerateResponseChoice,
    PromptTokenUsageInfo,
    RequestResponseMetadata,
    UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.utils.collection_utils import as_list

logger = init_logger(__name__)


class ServingTokens(OpenAIServing):
    """Provides Tokens IN <> Tokens OUT functionality to vLLM API."""

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        force_no_detokenize: bool = False,
        return_tokens_as_token_ids: bool = False,
        log_error_stack: bool = False,
        enable_prompt_tokens_details: bool = False,
        enable_log_outputs: bool = False,
    ):
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            return_tokens_as_token_ids=return_tokens_as_token_ids,
            log_error_stack=log_error_stack,
        )
        self.enable_prompt_tokens_details = enable_prompt_tokens_details
        self.enable_log_outputs = enable_log_outputs
        self.force_no_detokenize = force_no_detokenize
        if force_no_detokenize:
            logger.info(
                "Tokens-only mode is enabled, skipping detokenization "
                "step for incoming requests."
            )

    async def serve_tokens(
        self,
        request: GenerateRequest,
        raw_request: Request | None = None,
    ) -> GenerateResponse | ErrorResponse:
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            logger.error("Error with model %s", error_check_ret)
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        lora_request = None
        lora_request = self._maybe_get_adapters(
            request, supports_default_mm_loras=True
        )

        model_name = self.models.model_name(lora_request)

        request_id = (
            "generate-tokens-"
            f"{self._base_request_id(raw_request, request.request_id)}"
        )

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        # TODO(NickLucche): Change to EngineCoreRequest once Renderer work is
        # completed
        engine_prompt = EngineTokensPrompt(prompt_token_ids=request.token_ids)
        if request.features is not None:
            engine_prompt["multi_modal_data"] = None

        if hasattr(request, "cache_salt") and request.cache_salt is not None:
            engine_prompt["cache_salt"] = request.cache_salt

        # Schedule the request and get the result generator.
        result_generator: AsyncGenerator[RequestOutput, None] | None = None
        try:
            sampling_params = request.sampling_params
            if self.force_no_detokenize:
                sampling_params.detokenize = False

            self._log_inputs(
                request_id,
                request.token_ids,
                params=sampling_params,
                lora_request=lora_request,
            )

            trace_headers = (
                None
                if raw_request is None
                else await self._get_trace_headers(raw_request.headers)
            )

            result_generator = self.engine_client.generate(
                engine_prompt,
                sampling_params,
                request_id,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=request.priority,
            )

        except ValueError as e:
            return self.create_error_response(str(e))

        # TODO(NickLucche): Implement streaming response

        try:
            assert result_generator is not None
            return await self.serve_tokens_full_generator(
                request, result_generator, request_id, model_name, request_metadata
            )
        except ValueError as e:
            return self.create_error_response(str(e))

    async def serve_tokens_full_generator(
        self,
        request: GenerateRequest,
        result_generator: AsyncGenerator[RequestOutput, None],
        request_id: str,
        model_name: str,
        request_metadata: RequestResponseMetadata,
    ) -> ErrorResponse | GenerateResponse:
        created_time = int(time.time())
        final_res: RequestOutput | None = None
        sampling_params: SamplingParams = request.sampling_params

        try:
            async for res in result_generator:
                final_res = res
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            return self.create_error_response(str(e))

        assert final_res is not None

        choices: list[GenerateResponseChoice] = []
        num_generated_tokens = 0
        for output in final_res.outputs:
            token_ids = output.token_ids
            out_logprobs = output.logprobs

            # This is top_logprobs in completions API
            if sampling_params.logprobs:
                assert out_logprobs is not None, "Did not output logprobs"
                logprobs = self._create_tokens_logprobs(
                    token_ids=token_ids,
                    top_logprobs=out_logprobs,
                    num_output_top_logprobs=sampling_params.logprobs,
                )
            else:
                logprobs = None

            choice_data = GenerateResponseChoice(
                index=output.index,
                logprobs=logprobs,
                finish_reason=output.finish_reason if output.finish_reason else "stop",
                token_ids=as_list(output.token_ids),
            )

            choices.append(choice_data)
            num_generated_tokens += len(output.token_ids)

        assert final_res.prompt_token_ids is not None
        num_prompt_tokens = len(final_res.prompt_token_ids)
        if final_res.encoder_prompt_token_ids is not None:
            num_prompt_tokens += len(final_res.encoder_prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )
        if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
            # This info is not available at the /coordinator level
            usage.prompt_tokens_details = PromptTokenUsageInfo(
                cached_tokens=final_res.num_cached_tokens
            )

        request_metadata.final_usage_info = usage

        response = GenerateResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs),
            kv_transfer_params=final_res.kv_transfer_params,
        )

        # Log complete response if output logging is enabled
        if self.enable_log_outputs and self.request_logger:
            for choice in choices:
                # Get the corresponding output token IDs
                output_token_ids = None
                if choice.index < len(final_res.outputs):
                    output_token_ids = final_res.outputs[choice.index].token_ids

                if output_token_ids:
                    # Log token_ids only.
                    self.request_logger.log_outputs(
                        request_id=request_id,
                        outputs="",
                        output_token_ids=output_token_ids,
                        finish_reason=choice.finish_reason,
                        is_streaming=False,
                        delta=False,
                    )

        return response

    def _create_tokens_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[dict[int, Logprob] | None],
        num_output_top_logprobs: int | None = None,
    ) -> ChatCompletionLogProbs:
        """Create OpenAI-style logprobs."""
        logprobs_content: list[ChatCompletionLogProbsContent] = []

        for i, token_id in enumerate(token_ids):
            token = f"token_id:{token_id}"
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None or step_top_logprobs.get(token_id) is None:
                logprobs_content.append(ChatCompletionLogProbsContent(token=token))
            else:
                step_token = step_top_logprobs[token_id]

                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=token,
                        logprob=max(step_token.logprob, -9999.0),
                        top_logprobs=[
                            ChatCompletionLogProb(
                                token=token,
                                logprob=max(p[1].logprob, -9999.0),
                            )
                            for i, p in enumerate(step_top_logprobs.items())
                            if num_output_top_logprobs and i < num_output_top_logprobs
                        ],
                    )
                )

        return ChatCompletionLogProbs(content=logprobs_content)
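When logprobs are requested, _create_tokens_logprobs above encodes each token as the string "token_id:<id>" rather than decoded text; a client can recover the ids with something like this sketch (the entry values are illustrative):

# Sketch: parsing a logprob entry returned by the tokens endpoint.
entry = {"token": "token_id:151645", "logprob": -0.01}  # illustrative values
token_id = int(entry["token"].removeprefix("token_id:"))
assert token_id == 151645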
@@ -15,6 +15,7 @@ from pydantic.dataclasses import dataclass
from vllm.logger import init_logger
from vllm.logits_process import LogitsProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.v1.serial_utils import PydanticMsgspecMixin

logger = init_logger(__name__)
@@ -122,6 +123,7 @@ class RequestOutputKind(Enum):
class SamplingParams(
    PydanticMsgspecMixin,
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    # required for @cached_property.
@@ -15,6 +15,7 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
from vllm.v1.serial_utils import UtilityResult

# These are possible values of RequestOutput.finish_reason,
# so form part of the external API.
@@ -131,13 +132,6 @@ class EngineCoreOutput(
        return self.finish_reason is not None


class UtilityResult:
    """Wrapper for special handling when serializing/deserializing."""

    def __init__(self, r: Any = None):
        self.result = r


class UtilityOutput(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
@@ -8,7 +8,7 @@ from collections.abc import Callable, Sequence
from functools import partial
from inspect import isclass
from types import FunctionType
from typing import Any, TypeAlias
from typing import Any, TypeAlias, get_type_hints

import cloudpickle
import msgspec
@@ -16,6 +16,8 @@ import numpy as np
import torch
import zmq
from msgspec import msgpack
from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema

from vllm import envs
from vllm.logger import init_logger
@@ -32,7 +34,6 @@ from vllm.multimodal.inputs import (
    NestedTensors,
)
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.engine import UtilityResult
from vllm.v1.utils import tensor_data

logger = init_logger(__name__)
@@ -104,6 +105,13 @@ def _decode_type_info_recursive(
    return convert_fn(type_info, data)


class UtilityResult:
    """Wrapper for special handling when serializing/deserializing."""

    def __init__(self, r: Any = None):
        self.result = r


class MsgpackEncoder:
    """Encoder with custom torch tensor and numpy array serialization.
@@ -469,3 +477,56 @@ def run_method(
    else:
        func = partial(method, obj)  # type: ignore
    return func(*args, **kwargs)


class PydanticMsgspecMixin:
    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        """
        Make msgspec.Struct compatible with Pydantic, respecting defaults.
        Handle JSON=>msgspec.Struct. Used when exposing msgspec.Struct to the
        API as input or in `/docs`. Note this is cached by Pydantic and not
        called on every validation.
        """
        msgspec_fields = {f.name: f for f in msgspec.structs.fields(source_type)}
        type_hints = get_type_hints(source_type)

        # Build the Pydantic typed_dict_field for each msgspec field
        fields = {}
        for name, hint in type_hints.items():
            msgspec_field = msgspec_fields[name]

            # typed_dict_field using the handler to get the schema
            field_schema = handler(hint)

            # Add default value to the schema.
            if msgspec_field.default_factory is not msgspec.NODEFAULT:
                wrapped_schema = core_schema.with_default_schema(
                    schema=field_schema,
                    default_factory=msgspec_field.default_factory,
                )
                fields[name] = core_schema.typed_dict_field(wrapped_schema)
            elif msgspec_field.default is not msgspec.NODEFAULT:
                wrapped_schema = core_schema.with_default_schema(
                    schema=field_schema,
                    default=msgspec_field.default,
                )
                fields[name] = core_schema.typed_dict_field(wrapped_schema)
            else:
                # No default, so Pydantic will treat it as required
                fields[name] = core_schema.typed_dict_field(field_schema)
        return core_schema.no_info_after_validator_function(
            cls._validate_msgspec,
            core_schema.typed_dict_schema(fields),
        )

    @classmethod
    def _validate_msgspec(cls, value: Any) -> Any:
        """Validate and convert input to msgspec.Struct instance."""
        if isinstance(value, cls):
            return value
        if isinstance(value, dict):
            return cls(**value)
        return msgspec.convert(value, type=cls)
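The mixin lets any msgspec.Struct that inherits it be validated by Pydantic with its msgspec defaults intact; a minimal sketch against SamplingParams, the only user in this commit:

# Sketch: validating a msgspec.Struct through Pydantic via the mixin.
from pydantic import TypeAdapter

from vllm.sampling_params import SamplingParams

params = TypeAdapter(SamplingParams).validate_python({"temperature": 0.0})
assert params.n == 1  # msgspec default carried into the Pydantic schema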