[DisaggEverything] Tokens in<>out /generate endpoint (#24261)

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Nicolò Lucchesi
2025-11-14 17:58:01 +01:00
committed by GitHub
parent d54a18a47e
commit 6f1e7f7226
12 changed files with 822 additions and 9 deletions

View File

@@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import httpx
from transformers import AutoTokenizer
GEN_ENDPOINT = "http://localhost:8000/inference/v1/generate"
DUMMY_API_KEY = "empty"
MODEL_NAME = "Qwen/Qwen3-0.6B"
transport = httpx.HTTPTransport()
headers = {"Authorization": f"Bearer {DUMMY_API_KEY}"}
client = httpx.Client(
transport=transport,
base_url=GEN_ENDPOINT,
timeout=600,
headers=headers,
)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "How many countries are in the EU?"},
]
def main(client):
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
token_ids = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
enable_thinking=False,
)
payload = {
"model": MODEL_NAME,
"token_ids": token_ids,
"sampling_params": {"max_tokens": 24, "temperature": 0.2, "detokenize": False},
"stream": False,
}
resp = client.post(GEN_ENDPOINT, json=payload)
resp.raise_for_status()
data = resp.json()
print(data)
print("-" * 50)
print("Token generation results:")
res = tokenizer.decode(data["choices"][0]["token_ids"])
print(res)
print("-" * 50)
if __name__ == "__main__":
main(client)

View File

@@ -10,3 +10,7 @@ mkdocs-minify-plugin
regex
ruff
pydantic
# For generating argparse docs.
# Adding requirements here should only be used as a last resort.
msgspec # Need for multiple inheritance involving msgspec.Struct

View File

@@ -0,0 +1,262 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import httpx
import pytest
import pytest_asyncio
from transformers import AutoTokenizer
from vllm.config import ModelConfig
from vllm.v1.engine.detokenizer import check_stop_strings
from ...utils import RemoteOpenAIServer
MODEL_NAME = "Qwen/Qwen3-0.6B"
GEN_ENDPOINT = "/inference/v1/generate"
def get_vocab_size(model_name):
config = ModelConfig(
model=model_name,
seed=0,
dtype="bfloat16",
)
return config.get_vocab_size()
@pytest.fixture(scope="module")
def tokenizer():
return AutoTokenizer.from_pretrained(MODEL_NAME)
@pytest.fixture(scope="module")
def messages():
return [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "How many countries are in the EU?"},
]
@pytest.fixture(scope="module")
def server(request):
args = [
"--dtype",
"bfloat16",
"--max-model-len",
"1024",
"--enforce-eager",
]
extra_args = getattr(request, "param", None)
if extra_args is not None:
args = args + (
list(extra_args)
if isinstance(extra_args, (list, tuple))
else [str(extra_args)]
)
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client(server: RemoteOpenAIServer):
transport = httpx.AsyncHTTPTransport(uds=server.uds) if server.uds else None
headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"}
async with httpx.AsyncClient(
transport=transport,
base_url=server.url_root,
timeout=600,
headers=headers,
) as c:
yield c
@pytest.mark.asyncio
async def test_generate_endpoint(client):
payload = {
"model": MODEL_NAME,
"token_ids": [1, 2, 3],
"sampling_params": {"max_tokens": 5},
"stream": False,
}
resp = await client.post(GEN_ENDPOINT, json=payload)
resp.raise_for_status()
data = resp.json()
assert "choices" in data
@pytest.mark.asyncio
async def test_same_response_as_chat_completions(client, tokenizer, messages):
token_ids = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
enable_thinking=False, # default with Qwen3
)
for ignore_eos in [True, False]:
payload = {
"model": MODEL_NAME,
"token_ids": token_ids,
"sampling_params": {
"max_tokens": 24,
"temperature": 0.0,
# NOTE coordinator will set this to skip detokenization
"detokenize": False,
"ignore_eos": ignore_eos,
},
"stream": False,
}
generate_resp = await client.post(GEN_ENDPOINT, json=payload)
generate_data = generate_resp.json()
generate_res = tokenizer.decode(
generate_data["choices"][0]["token_ids"], skip_special_tokens=True
)
payload = {
"model": MODEL_NAME,
"messages": messages,
"max_tokens": 24,
"temperature": 0.0,
"stream": False,
"ignore_eos": ignore_eos,
"chat_template_kwargs": dict(enable_thinking=False),
}
completions_resp = await client.post("/v1/chat/completions", json=payload)
completions_data = completions_resp.json()
completions_res = completions_data["choices"][0]["message"]["content"]
assert generate_res == completions_res
@pytest.mark.asyncio
async def test_stop_string_workflow(client, tokenizer, messages):
token_ids = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
enable_thinking=False, # default with Qwen3
)
payload = {
"model": MODEL_NAME,
"token_ids": token_ids,
"sampling_params": {
"max_tokens": 24,
"temperature": 0.0,
"detokenize": False,
# stop strings are only supported when detokenize is True.
"stop": ["27 member"],
},
# TODO stream test is much more interesting
"stream": False,
}
with pytest.raises(httpx.HTTPStatusError):
generate_resp = await client.post(GEN_ENDPOINT, json=payload)
generate_resp.raise_for_status()
payload["sampling_params"]["stop"] = None
generate_resp = await client.post(
GEN_ENDPOINT, json=payload, headers={"X-Request-Id": "42"}
)
generate_data = generate_resp.json()
generate_res = tokenizer.decode(
generate_data["choices"][0]["token_ids"], skip_special_tokens=True
)
# NOTE This is under the responsibility of the coordinator
# stop_checker = StopChecker(
# max_model_len=1024, get_tokenizer_for_seq=lambda _: tokenizer
# )
stop_str, truncate_to = check_stop_strings(
generate_res, len(generate_res), ["27 member"], False
)
assert stop_str == "27 member"
# abort request that hit stop string (requires tokens-only mode)
# res = await client.post("/abort_requests", json={"request_ids": ["generate-tokens-42"]}) # noqa: E501
# res.raise_for_status()
generate_res = generate_res[:truncate_to]
# Get stop_str response from chat completions
payload = {
"model": MODEL_NAME,
"messages": messages,
"max_tokens": 24,
"temperature": 0.0,
"stream": False,
"stop": ["27 member"],
"chat_template_kwargs": dict(enable_thinking=False),
}
completions_resp = await client.post("/v1/chat/completions", json=payload)
completions_data = completions_resp.json()
completions_res = completions_data["choices"][0]["message"]["content"]
assert generate_res == completions_res
@pytest.mark.asyncio
@pytest.mark.parametrize(
"server",
[
[
"--enable-lora",
"--lora-modules",
"Alice=charent/self_cognition_Alice",
"Bob=charent/self_cognition_Bob",
"--max-lora-rank",
"64",
"--max-cpu-loras",
"2",
]
],
indirect=True,
)
async def test_generate_with_lora_adapter(client, tokenizer, messages):
# Verify adapters are listed
models_resp = await client.get("/v1/models")
models_resp.raise_for_status()
models = {m["id"] for m in models_resp.json().get("data", [])}
assert {"Alice", "Bob"}.issubset(models)
# Generate using a LoRA adapter by specifying its name as the model
payload = {
"model": "Alice",
"token_ids": [1, 2, 3],
"sampling_params": {"max_tokens": 5},
"stream": False,
}
resp = await client.post(GEN_ENDPOINT, json=payload)
resp.raise_for_status()
data = resp.json()
assert "choices" in data
token_ids = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
enable_thinking=False, # default with Qwen3
)
payload = {
"model": "Alice",
"token_ids": token_ids,
"sampling_params": {
"max_tokens": 24,
"temperature": 0.0,
"detokenize": False,
},
"stream": False,
}
generate_resp = await client.post(GEN_ENDPOINT, json=payload)
generate_data = generate_resp.json()
generate_res = tokenizer.decode(
generate_data["choices"][0]["token_ids"], skip_special_tokens=True
)
payload = {
"model": "Alice",
"messages": messages,
"max_tokens": 24,
"temperature": 0.0,
"stream": False,
"chat_template_kwargs": dict(enable_thinking=False),
}
completions_resp = await client.post("/v1/chat/completions", json=payload)
completions_data = completions_resp.json()
completions_res = completions_data["choices"][0]["message"]["content"]
assert generate_res == completions_res

View File

@@ -566,6 +566,7 @@ class EngineArgs:
kv_offloading_backend: KVOffloadingBackend | None = (
CacheConfig.kv_offloading_backend
)
tokens_only: bool = False
def __post_init__(self):
# support `EngineArgs(compilation_config={...})`
@@ -1495,6 +1496,10 @@ class EngineArgs:
else ParallelConfig.data_parallel_rpc_port
)
if self.tokens_only and not model_config.skip_tokenizer_init:
model_config.skip_tokenizer_init = True
logger.info("Skipping tokenizer initialization for tokens-only mode.")
# Forward the deprecated CLI args to the EPLB config.
if self.num_redundant_experts is not None:
self.eplb_config.num_redundant_experts = self.num_redundant_experts

View File

@@ -65,6 +65,8 @@ from vllm.entrypoints.openai.protocol import (
EmbeddingResponse,
ErrorInfo,
ErrorResponse,
GenerateRequest,
GenerateResponse,
IOProcessorResponse,
PoolingBytesResponse,
PoolingRequest,
@@ -96,6 +98,7 @@ from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
from vllm.entrypoints.openai.serving_score import ServingScores
from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
from vllm.entrypoints.openai.serving_tokens import ServingTokens
from vllm.entrypoints.openai.serving_transcription import (
OpenAIServingTranscription,
OpenAIServingTranslation,
@@ -357,6 +360,10 @@ def engine_client(request: Request) -> EngineClient:
return request.app.state.engine_client
def generate_tokens(request: Request) -> ServingTokens | None:
return request.app.state.serving_tokens
@router.get("/health", response_class=Response)
async def health(raw_request: Request) -> Response:
"""Health check."""
@@ -1228,6 +1235,41 @@ INVOCATION_VALIDATORS = [
]
@router.post(
"/inference/v1/generate",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
@with_cancellation
@load_aware_call
async def generate(request: GenerateRequest, raw_request: Request):
handler = generate_tokens(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support generate tokens API"
)
try:
generator = await handler.serve_tokens(request, raw_request)
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, GenerateResponse):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
if envs.VLLM_TORCH_PROFILER_DIR:
logger.warning_once(
"Torch Profiler is enabled in the API server. This should ONLY be "
@@ -1629,6 +1671,31 @@ def build_app(args: Namespace) -> FastAPI:
)
app = sagemaker_standards.bootstrap(app)
# Optional endpoints
if args.tokens_only:
@app.post("/abort_requests")
async def abort_requests(raw_request: Request):
"""
Abort one or more requests. To be used in a
Disaggregated Everything setup.
"""
try:
body = await raw_request.json()
except json.JSONDecodeError as e:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value,
detail=f"JSON decode error: {e}",
) from e
request_ids = body.get("request_ids")
if request_ids is None:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value,
detail="Missing 'request_ids' in request body",
)
# Abort requests in background
asyncio.create_task(engine_client(raw_request).abort(request_ids))
return Response(status_code=200)
return app
@@ -1851,6 +1918,20 @@ async def init_app_state(
if "generate" in supported_tasks
else None
)
state.serving_tokens = (
ServingTokens(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
log_error_stack=args.log_error_stack,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_log_outputs=args.enable_log_outputs,
force_no_detokenize=args.tokens_only,
)
if "generate" in supported_tasks
else None
)
state.enable_server_load_tracking = args.enable_server_load_tracking
state.server_load_metrics = 0

View File

@@ -189,6 +189,11 @@ class FrontendArgs:
Helps mitigate header abuse. Default: 256."""
log_error_stack: bool = envs.VLLM_SERVER_DEV_MODE
"""If set to True, log the stack trace of error responses"""
tokens_only: bool = False
"""
If set to True, only enable the Tokens In<>Out endpoint.
This is intended for use in a Disaggregated Everything setup.
"""
@staticmethod
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:

View File

@@ -3220,3 +3220,80 @@ class TranslationResponseVerbose(OpenAIBaseModel):
words: list[TranslationWord] | None = None
"""Extracted words and their corresponding timestamps."""
####### Tokens IN <> Tokens OUT #######
class GenerateRequest(BaseModel):
request_id: str = Field(
default_factory=lambda: f"{random_uuid()}",
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."
),
)
token_ids: list[int]
"""The token ids to generate text from."""
# features: MultiModalFeatureSpec
# TODO (NickLucche): implement once Renderer work is completed
features: str | None = None
"""The processed MM inputs for the model."""
sampling_params: SamplingParams
"""The sampling parameters for the model."""
model: str | None = None
stream: bool | None = False
stream_options: StreamOptions | None = None
cache_salt: str | None = Field(
default=None,
description=(
"If specified, the prefix cache will be salted with the provided "
"string to prevent an attacker to guess prompts in multi-user "
"environments. The salt should be random, protected from "
"access by 3rd parties, and long enough to be "
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
"to 256 bit)."
),
)
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."
),
)
kv_transfer_params: dict[str, Any] | None = Field(
default=None,
description="KVTransfer parameters used for disaggregated serving.",
)
class GenerateResponseChoice(BaseModel):
index: int
logprobs: ChatCompletionLogProbs | None = None
# per OpenAI spec this is the default
finish_reason: str | None = "stop"
token_ids: list[int] | None = None
class GenerateResponse(BaseModel):
request_id: str = Field(
default_factory=lambda: f"{random_uuid()}",
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."
),
)
choices: list[GenerateResponseChoice]
prompt_logprobs: list[dict[int, Logprob] | None] | None = None
kv_transfer_params: dict[str, Any] | None = Field(
default=None,
description="KVTransfer parameters used for disaggregated serving.",
)

View File

@@ -58,6 +58,8 @@ from vllm.entrypoints.openai.protocol import (
ErrorResponse,
FunctionCall,
FunctionDefinition,
GenerateRequest,
GenerateResponse,
IOProcessorRequest,
PoolingResponse,
RerankRequest,
@@ -134,6 +136,7 @@ AnyRequest: TypeAlias = (
| SpeechToTextRequest
| ResponsesRequest
| IOProcessorRequest
| GenerateRequest
)
AnyResponse: TypeAlias = (
@@ -145,6 +148,7 @@ AnyResponse: TypeAlias = (
| PoolingResponse
| ClassificationResponse
| ScoreResponse
| GenerateResponse
)

View File

@@ -0,0 +1,269 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import time
from collections.abc import AsyncGenerator
from collections.abc import Sequence as GenericSequence
from fastapi import Request
# yapf: disable
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
ChatCompletionLogProb,
ChatCompletionLogProbs,
ChatCompletionLogProbsContent,
ErrorResponse,
GenerateRequest,
GenerateResponse,
GenerateResponseChoice,
PromptTokenUsageInfo,
RequestResponseMetadata,
UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.utils.collection_utils import as_list
logger = init_logger(__name__)
class ServingTokens(OpenAIServing):
"""Provides Tokens IN <> Tokens OUT functionality to vLLM API."""
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
force_no_detokenize: bool = False,
return_tokens_as_token_ids: bool = False,
log_error_stack: bool = False,
enable_prompt_tokens_details: bool = False,
enable_log_outputs: bool = False,
):
super().__init__(engine_client=engine_client,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
log_error_stack=log_error_stack)
self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.enable_log_outputs = enable_log_outputs
self.force_no_detokenize = force_no_detokenize
if force_no_detokenize:
logger.info("Tokens-only mode is enabled, skipping detokenization "
"step for incoming requests.")
async def serve_tokens(
self,
request: GenerateRequest,
raw_request: Request | None = None
) -> GenerateResponse | ErrorResponse:
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
logger.error("Error with model %s", error_check_ret)
return error_check_ret
# If the engine is dead, raise the engine's DEAD_ERROR.
# This is required for the streaming case, where we return a
# success status before we actually start generating text :).
if self.engine_client.errored:
raise self.engine_client.dead_error
lora_request = None
lora_request = self._maybe_get_adapters(request,
supports_default_mm_loras=True)
model_name = self.models.model_name(lora_request)
request_id = "generate-tokens-" \
f"{self._base_request_id(raw_request, request.request_id)}"
request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
raw_request.state.request_metadata = request_metadata
# TODO(NickLucche): Change to EngineCoreRequest once Renderer work is
# completed
engine_prompt = EngineTokensPrompt(prompt_token_ids=request.token_ids)
if request.features is not None:
engine_prompt["multi_modal_data"] = None
if hasattr(request, "cache_salt") and request.cache_salt is not None:
engine_prompt["cache_salt"] = request.cache_salt
# Schedule the request and get the result generator.
result_generator: AsyncGenerator[RequestOutput, None] | None = None
try:
sampling_params = request.sampling_params
if self.force_no_detokenize:
sampling_params.detokenize = False
self._log_inputs(request_id,
request.token_ids,
params=sampling_params,
lora_request=lora_request)
trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))
result_generator = self.engine_client.generate(
engine_prompt,
sampling_params,
request_id,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
except ValueError as e:
return self.create_error_response(str(e))
# TODO(NickLucche): Implement streaming response
try:
assert result_generator is not None
return await self.serve_tokens_full_generator(
request, result_generator, request_id, model_name,
request_metadata)
except ValueError as e:
return self.create_error_response(str(e))
async def serve_tokens_full_generator(
self,
request: GenerateRequest,
result_generator: AsyncGenerator[RequestOutput, None],
request_id: str,
model_name: str,
request_metadata: RequestResponseMetadata,
) -> ErrorResponse | GenerateResponse:
created_time = int(time.time())
final_res: RequestOutput | None = None
sampling_params: SamplingParams = request.sampling_params
try:
async for res in result_generator:
final_res = res
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
return self.create_error_response(str(e))
assert final_res is not None
choices: list[GenerateResponseChoice] = []
num_generated_tokens = 0
for output in final_res.outputs:
token_ids = output.token_ids
out_logprobs = output.logprobs
# This is top_logprobs in completions API
if sampling_params.logprobs:
assert out_logprobs is not None, "Did not output logprobs"
logprobs = self._create_tokens_logprobs(
token_ids=token_ids,
top_logprobs=out_logprobs,
num_output_top_logprobs=sampling_params.logprobs,
)
else:
logprobs = None
choice_data = GenerateResponseChoice(
index=output.index,
logprobs=logprobs,
finish_reason=output.finish_reason
if output.finish_reason else "stop",
token_ids=as_list(output.token_ids))
choices.append(choice_data)
num_generated_tokens += len(output.token_ids)
assert final_res.prompt_token_ids is not None
num_prompt_tokens = len(final_res.prompt_token_ids)
if final_res.encoder_prompt_token_ids is not None:
num_prompt_tokens += len(final_res.encoder_prompt_token_ids)
usage = UsageInfo(prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens +
num_generated_tokens)
if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
# This info is not available at the /coordinator level
usage.prompt_tokens_details = PromptTokenUsageInfo(
cached_tokens=final_res.num_cached_tokens)
request_metadata.final_usage_info = usage
response = GenerateResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs),
kv_transfer_params=final_res.kv_transfer_params,
)
# Log complete response if output logging is enabled
if self.enable_log_outputs and self.request_logger:
for choice in choices:
# Get the corresponding output token IDs
output_token_ids = None
if choice.index < len(final_res.outputs):
output_token_ids = final_res.outputs[
choice.index].token_ids
if output_token_ids:
# Log token_ids only.
self.request_logger.log_outputs(
request_id=request_id,
outputs="",
output_token_ids=output_token_ids,
finish_reason=choice.finish_reason,
is_streaming=False,
delta=False,
)
return response
def _create_tokens_logprobs(
self,
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[dict[int, Logprob] | None],
num_output_top_logprobs: int | None = None,
) -> ChatCompletionLogProbs:
"""Create OpenAI-style logprobs."""
logprobs_content: list[ChatCompletionLogProbsContent] = []
for i, token_id in enumerate(token_ids):
token = f"token_id:{token_id}"
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None or step_top_logprobs.get(
token_id) is None:
logprobs_content.append(
ChatCompletionLogProbsContent(token=token, ))
else:
step_token = step_top_logprobs[token_id]
logprobs_content.append(
ChatCompletionLogProbsContent(
token=token,
logprob=max(step_token.logprob, -9999.0),
top_logprobs=[
ChatCompletionLogProb(
token=token,
logprob=max(p[1].logprob, -9999.0),
) for i, p in enumerate(step_top_logprobs.items())
if num_output_top_logprobs
and i < num_output_top_logprobs
]))
return ChatCompletionLogProbs(content=logprobs_content)

View File

@@ -15,6 +15,7 @@ from pydantic.dataclasses import dataclass
from vllm.logger import init_logger
from vllm.logits_process import LogitsProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.v1.serial_utils import PydanticMsgspecMixin
logger = init_logger(__name__)
@@ -122,6 +123,7 @@ class RequestOutputKind(Enum):
class SamplingParams(
PydanticMsgspecMixin,
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
# required for @cached_property.

View File

@@ -15,6 +15,7 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
from vllm.v1.serial_utils import UtilityResult
# These are possible values of RequestOutput.finish_reason,
# so form part of the external API.
@@ -131,13 +132,6 @@ class EngineCoreOutput(
return self.finish_reason is not None
class UtilityResult:
"""Wrapper for special handling when serializing/deserializing."""
def __init__(self, r: Any = None):
self.result = r
class UtilityOutput(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]

View File

@@ -8,7 +8,7 @@ from collections.abc import Callable, Sequence
from functools import partial
from inspect import isclass
from types import FunctionType
from typing import Any, TypeAlias
from typing import Any, TypeAlias, get_type_hints
import cloudpickle
import msgspec
@@ -16,6 +16,8 @@ import numpy as np
import torch
import zmq
from msgspec import msgpack
from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema
from vllm import envs
from vllm.logger import init_logger
@@ -32,7 +34,6 @@ from vllm.multimodal.inputs import (
NestedTensors,
)
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.engine import UtilityResult
from vllm.v1.utils import tensor_data
logger = init_logger(__name__)
@@ -104,6 +105,13 @@ def _decode_type_info_recursive(
return convert_fn(type_info, data)
class UtilityResult:
"""Wrapper for special handling when serializing/deserializing."""
def __init__(self, r: Any = None):
self.result = r
class MsgpackEncoder:
"""Encoder with custom torch tensor and numpy array serialization.
@@ -469,3 +477,56 @@ def run_method(
else:
func = partial(method, obj) # type: ignore
return func(*args, **kwargs)
class PydanticMsgspecMixin:
@classmethod
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
"""
Make msgspec.Struct compatible with Pydantic, respecting defaults.
Handle JSON=>msgspec.Struct. Used when exposing msgspec.Struct to the
API as input or in `/docs`. Note this is cached by Pydantic and not
called on every validation.
"""
msgspec_fields = {f.name: f for f in msgspec.structs.fields(source_type)}
type_hints = get_type_hints(source_type)
# Build the Pydantic typed_dict_field for each msgspec field
fields = {}
for name, hint in type_hints.items():
msgspec_field = msgspec_fields[name]
# typed_dict_field using the handler to get the schema
field_schema = handler(hint)
# Add default value to the schema.
if msgspec_field.default_factory is not msgspec.NODEFAULT:
wrapped_schema = core_schema.with_default_schema(
schema=field_schema,
default_factory=msgspec_field.default_factory,
)
fields[name] = core_schema.typed_dict_field(wrapped_schema)
elif msgspec_field.default is not msgspec.NODEFAULT:
wrapped_schema = core_schema.with_default_schema(
schema=field_schema,
default=msgspec_field.default,
)
fields[name] = core_schema.typed_dict_field(wrapped_schema)
else:
# No default, so Pydantic will treat it as required
fields[name] = core_schema.typed_dict_field(field_schema)
return core_schema.no_info_after_validator_function(
cls._validate_msgspec,
core_schema.typed_dict_schema(fields),
)
@classmethod
def _validate_msgspec(cls, value: Any) -> Any:
"""Validate and convert input to msgspec.Struct instance."""
if isinstance(value, cls):
return value
if isinstance(value, dict):
return cls(**value)
return msgspec.convert(value, type=cls)