[responsesAPI][5] ResponsesParser with tools for full MCP python loop (#29798)

Signed-off-by: Andrew Xia <axia@fb.com>
Signed-off-by: Andrew Xia <axia@meta.com>
Co-authored-by: Andrew Xia <axia@fb.com>
Andrew Xia
2025-12-05 08:11:50 -08:00
committed by GitHub
parent 949a6a19d2
commit da7bc54ea8
8 changed files with 347 additions and 16 deletions

View File

@@ -3,7 +3,7 @@
"""
Set up this example by starting a vLLM OpenAI-compatible server with tool call
options enabled.
Reasoning models can be used through the Responses API as seen here
https://platform.openai.com/docs/api-reference/responses
For example:
vllm serve Qwen/Qwen3-1.7B --reasoning-parser qwen3 \

View File

@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
import json
import pytest
import pytest_asyncio
@@ -13,12 +15,27 @@ MODEL_NAME = "Qwen/Qwen3-8B"
@pytest.fixture(scope="module")
def server():
args = ["--reasoning-parser", "qwen3", "--max_model_len", "5000"]
assert importlib.util.find_spec("gpt_oss") is not None, (
"Harmony tests require gpt_oss package to be installed"
)
args = [
"--reasoning-parser",
"qwen3",
"--max_model_len",
"5000",
"--structured-outputs-config.backend",
"xgrammar",
"--enable-auto-tool-choice",
"--tool-call-parser",
"hermes",
"--tool-server",
"demo",
]
env_dict = dict(
VLLM_ENABLE_RESPONSES_API_STORE="1",
VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT="1",
# uncomment for tool calling
# PYTHON_EXECUTION_BACKEND="dangerously_use_uv",
PYTHON_EXECUTION_BACKEND="dangerously_use_uv",
)
with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server:
@@ -85,3 +102,79 @@ async def test_reasoning_and_function_items(client: OpenAI, model_name: str):
assert response.output[0].type == "reasoning"
assert response.output[1].type == "message"
assert type(response.output[1].content[0].text) is str
def get_horoscope(sign):
return f"{sign}: Next Tuesday you will befriend a baby otter."
def call_function(name, args):
if name == "get_horoscope":
return get_horoscope(**args)
else:
raise ValueError(f"Unknown function: {name}")
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_function_call_first_turn(client: OpenAI, model_name: str):
tools = [
{
"type": "function",
"name": "get_horoscope",
"description": "Get today's horoscope for an astrological sign.",
"parameters": {
"type": "object",
"properties": {
"sign": {"type": "string"},
},
"required": ["sign"],
"additionalProperties": False,
},
"strict": True,
}
]
response = await client.responses.create(
model=model_name,
input="What is the horoscope for Aquarius today?",
tools=tools,
temperature=0.0,
)
assert response is not None
assert response.status == "completed"
assert len(response.output) == 2
assert response.output[0].type == "reasoning"
assert response.output[1].type == "function_call"
function_call = response.output[1]
assert function_call.name == "get_horoscope"
assert function_call.call_id is not None
args = json.loads(function_call.arguments)
assert "sign" in args
# the multi-turn function call is tested above in
# test_reasoning_and_function_items
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_call(client: OpenAI, model_name: str):
response = await client.responses.create(
model=model_name,
input="What is 13 * 24? Use python to calculate the result.",
tools=[{"type": "code_interpreter", "container": {"type": "auto"}}],
temperature=0.0,
)
assert response is not None
assert response.status == "completed"
assert response.output[0].type == "reasoning"
assert response.output[1].type == "mcp_call"
assert type(response.output[1].arguments) is str
assert type(response.output[1].output) is str
assert response.output[2].type == "reasoning"
# make sure the correct math is in the final output
assert response.output[3].type == "message"
assert "312" in response.output[3].content[0].text

View File

@@ -9,10 +9,16 @@ from collections.abc import Callable
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Union
from openai.types.responses.response_function_tool_call_output_item import (
ResponseFunctionToolCallOutputItem,
)
from openai.types.responses.tool import Mcp
from openai_harmony import Author, Message, Role, StreamState, TextContent
from vllm import envs
from vllm.entrypoints.chat_utils import (
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.harmony_utils import (
get_encoding,
get_streamable_parser_for_assistant,
@@ -22,16 +28,20 @@ from vllm.entrypoints.openai.parser.responses_parser import (
get_responses_parser_for_simple_context,
)
from vllm.entrypoints.openai.protocol import (
FunctionCall,
ResponseInputOutputItem,
ResponseRawMessageAndToken,
ResponsesRequest,
)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser
from vllm.entrypoints.responses_utils import construct_tool_dicts
from vllm.entrypoints.tool import Tool
from vllm.entrypoints.tool_server import ToolServer
from vllm.outputs import RequestOutput
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.tokenizers.protocol import TokenizerLike
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid
if TYPE_CHECKING:
from mcp.client import ClientSession
@@ -221,6 +231,10 @@ class ParsableContext(ConversationContext):
tokenizer: AnyTokenizer,
reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser] | None,
request: ResponsesRequest,
available_tools: list[str] | None,
tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
):
self.num_prompt_tokens = 0
self.num_output_tokens = 0
@@ -238,12 +252,19 @@ class ParsableContext(ConversationContext):
reasoning_parser_cls=reasoning_parser_cls,
response_messages=response_messages,
request=request,
tool_parser_cls=tool_parser_cls,
)
self.tool_parser_cls = tool_parser_cls
self.request = request
self.tokenizer = tokenizer
self.available_tools = available_tools or []
self._tool_sessions: dict[str, ClientSession | Tool] = {}
self.called_tools: set[str] = set()
self.tool_dicts = construct_tool_dicts(request.tools, request.tool_choice)
self.chat_template = chat_template
self.chat_template_content_format = chat_template_content_format
def append_output(self, output: RequestOutput) -> None:
self.num_prompt_tokens = len(output.prompt_token_ids or [])
@@ -252,14 +273,50 @@ class ParsableContext(ConversationContext):
self.parser.process(output.outputs[0])
def append_tool_output(self, output: list[ResponseInputOutputItem]) -> None:
raise NotImplementedError("Should not be called.")
self.parser.response_messages.extend(output)
def need_builtin_tool_call(self) -> bool:
"""Return true if the last message is a MCP tool call"""
last_message = self.parser.response_messages[-1]
# TODO: figure out which tools are MCP tools
if ( # noqa: SIM103
last_message.type == "function_call"
and last_message.name in ("code_interpreter", "python")
):
return True
return False
async def call_python_tool(
self, tool_session: Union["ClientSession", Tool], last_msg: FunctionCall
) -> list[ResponseInputOutputItem]:
self.called_tools.add("python")
if isinstance(tool_session, Tool):
return await tool_session.get_result_parsable_context(self)
args = json.loads(last_msg.arguments)
param = {
"code": args["code"],
}
result = await tool_session.call_tool("python", param)
result_str = result.content[0].text
message = ResponseFunctionToolCallOutputItem(
id=f"fco_{random_uuid()}",
type="function_call_output",
call_id=f"call_{random_uuid()}",
output=result_str,
status="completed",
)
return [message]
async def call_tool(self) -> list[ResponseInputOutputItem]:
raise NotImplementedError("Should not be called.")
if not self.parser.response_messages:
return []
last_msg = self.parser.response_messages[-1]
if last_msg.name == "code_interpreter":
return await self.call_python_tool(self._tool_sessions["python"], last_msg)
return []
def render_for_completion(self):
raise NotImplementedError("Should not be called.")
@@ -271,11 +328,38 @@ class ParsableContext(ConversationContext):
request_id: str,
mcp_tools: dict[str, Mcp],
):
pass
if tool_server:
for tool_name in self.available_tools:
if tool_name in self._tool_sessions:
continue
tool_type = _map_tool_name_to_tool_type(tool_name)
headers = (
mcp_tools[tool_type].headers if tool_type in mcp_tools else None
)
tool_session = await exit_stack.enter_async_context(
tool_server.new_session(tool_name, request_id, headers)
)
self._tool_sessions[tool_name] = tool_session
exit_stack.push_async_exit(self.cleanup_session)
async def cleanup_session(self, *args, **kwargs) -> None:
"""Can be used as coro to used in __aexit__"""
raise NotImplementedError("Should not be called.")
async def cleanup_tool_session(tool_session):
if not isinstance(tool_session, Tool):
logger.info(
"Cleaning up tool session for %s", tool_session._client_info
)
with contextlib.suppress(Exception):
await tool_session.call_tool("cleanup_session", {})
await asyncio.gather(
*(
cleanup_tool_session(self._tool_sessions[tool])
for tool in self.called_tools
)
)
class HarmonyContext(ConversationContext):
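Taken together, these ParsableContext hooks are what the builtin-tool loop in serving_engine drives each turn. A rough sketch of that loop, with generate_one_turn standing in for the engine call (the real driver is OpenAIServing._generate_with_builtin_tools below; names here are placeholders):

async def run_mcp_loop(context, generate_one_turn, max_turns=8):
    # Sketch only; max_turns and generate_one_turn are illustrative.
    for _ in range(max_turns):
        output = await generate_one_turn(context)  # one engine generation
        context.append_output(output)  # parse reasoning, messages, tool calls
        if not context.need_builtin_tool_call():
            break  # last item is a final message, we are done
        tool_items = await context.call_tool()  # e.g. run the python MCP tool
        context.append_tool_output(tool_items)  # results feed the next prompt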

View File

@@ -3,6 +3,7 @@
import logging
from collections.abc import Callable
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
from openai.types.responses.response_output_message import ResponseOutputMessage
from openai.types.responses.response_output_text import ResponseOutputText
from openai.types.responses.response_reasoning_item import (
@@ -11,8 +12,10 @@ from openai.types.responses.response_reasoning_item import (
)
from vllm.entrypoints.openai.protocol import ResponseInputOutputItem, ResponsesRequest
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser
from vllm.outputs import CompletionOutput
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.tokenizers.protocol import TokenizerLike
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid
@@ -29,6 +32,7 @@ class ResponsesParser:
reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser],
response_messages: list[ResponseInputOutputItem],
request: ResponsesRequest,
tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
):
self.response_messages: list[ResponseInputOutputItem] = (
# TODO: initial messages may not be properly typed
@@ -39,6 +43,9 @@ class ResponsesParser:
self.request = request
self.reasoning_parser_instance = reasoning_parser_cls(tokenizer)
self.tool_parser_instance = None
if tool_parser_cls is not None:
self.tool_parser_instance = tool_parser_cls(tokenizer)
def process(self, output: CompletionOutput) -> "ResponsesParser":
reasoning_content, content = self.reasoning_parser_instance.extract_reasoning(
@@ -59,6 +66,29 @@ class ResponsesParser:
)
)
function_calls: list[ResponseFunctionToolCall] = []
if self.tool_parser_instance is not None:
tool_call_info = self.tool_parser_instance.extract_tool_calls(
content if content is not None else "",
request=self.request, # type: ignore
)
if tool_call_info is not None and tool_call_info.tools_called:
# extract_tool_calls() returns a list of tool calls.
function_calls.extend(
ResponseFunctionToolCall(
id=f"fc_{random_uuid()}",
call_id=f"call_{random_uuid()}",
type="function_call",
status="completed",
name=tool_call.function.name,
arguments=tool_call.function.arguments,
)
for tool_call in tool_call_info.tool_calls
)
content = tool_call_info.content
if content and content.strip() == "":
content = None
if content:
self.response_messages.append(
ResponseOutputMessage(
@@ -76,6 +106,8 @@ class ResponsesParser:
],
)
)
if len(function_calls) > 0:
self.response_messages.extend(function_calls)
return self
@@ -86,6 +118,7 @@ def get_responses_parser_for_simple_context(
reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser],
response_messages: list[ResponseInputOutputItem],
request: ResponsesRequest,
tool_parser_cls,
) -> ResponsesParser:
"""Factory function to create a ResponsesParser with
optional reasoning parser.
@@ -98,4 +131,5 @@ def get_responses_parser_for_simple_context(
reasoning_parser_cls=reasoning_parser_cls,
response_messages=response_messages,
request=request,
tool_parser_cls=tool_parser_cls,
)
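For orientation, a sketch of what the new tool-parser branch is expected to produce, assuming the hermes-style tool-call text format used by the updated test (the model text in the comment is an assumption; the item construction mirrors the code above with made-up values):

from openai.types.responses.response_function_tool_call import (
    ResponseFunctionToolCall,
)

from vllm.utils import random_uuid

# Hermes-style text the tool parser scans for (assumed format):
#   <tool_call>
#   {"name": "get_horoscope", "arguments": {"sign": "Aquarius"}}
#   </tool_call>
# Each detected call becomes one function_call item appended after the
# reasoning/message items:
example_call = ResponseFunctionToolCall(
    id=f"fc_{random_uuid()}",
    call_id=f"call_{random_uuid()}",
    type="function_call",
    status="completed",
    name="get_horoscope",
    arguments='{"sign": "Aquarius"}',
)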

View File

@@ -18,6 +18,16 @@ from pydantic import ConfigDict, TypeAdapter
from starlette.datastructures import Headers
from typing_extensions import TypeIs
from vllm.entrypoints.context import (
HarmonyContext,
ParsableContext,
StreamingHarmonyContext,
)
from vllm.entrypoints.openai.protocol import (
FunctionCall,
ResponseInputOutputItem,
ResponsesRequest,
)
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
@@ -39,6 +49,7 @@ from vllm.entrypoints.pooling.score.protocol import (
ScoreRequest,
ScoreResponse,
)
from vllm.transformers_utils.tokenizer import AnyTokenizer
if sys.version_info >= (3, 12):
from typing import TypedDict
@@ -72,9 +83,7 @@ from vllm.entrypoints.openai.protocol import (
DetokenizeRequest,
ErrorInfo,
ErrorResponse,
FunctionCall,
FunctionDefinition,
ResponsesRequest,
TokenizeChatRequest,
TokenizeCompletionRequest,
TokenizeResponse,
@@ -85,6 +94,9 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
from vllm.entrypoints.responses_utils import (
construct_input_messages,
)
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs.data import PromptType
@@ -1224,6 +1236,31 @@ class OpenAIServing:
)
return engine_request, tokenization_kwargs
async def _render_next_turn(
self,
request: ResponsesRequest,
tokenizer: AnyTokenizer,
messages: list[ResponseInputOutputItem],
tool_dicts: list[dict[str, Any]] | None,
tool_parser,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
):
new_messages = construct_input_messages(
request_input=messages,
)
_, request_prompts, engine_prompts = await self._preprocess_chat(
request,
tokenizer,
new_messages,
tool_dicts=tool_dicts,
tool_parser=tool_parser,
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
)
return request_prompts, engine_prompts
async def _generate_with_builtin_tools(
self,
request_id: str,
@@ -1286,11 +1323,27 @@ class OpenAIServing:
# Create inputs for the next turn.
# Render the next prompt token ids.
prompt_token_ids = context.render_for_completion()
engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids)
request_prompt = prompt_token_ids
if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
prompt_token_ids = context.render_for_completion()
engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids)
request_prompt = prompt_token_ids
elif isinstance(context, ParsableContext):
request_prompts, engine_prompts = await self._render_next_turn(
context.request,
context.tokenizer,
context.parser.response_messages,
context.tool_dicts,
context.tool_parser_cls,
context.chat_template,
context.chat_template_content_format,
)
engine_prompt = engine_prompts[0]
request_prompt = request_prompts[0]
# Update the sampling params.
sampling_params.max_tokens = self.max_model_len - len(prompt_token_ids)
sampling_params.max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"]
)
# OPTIMIZATION
priority = orig_priority - 1
sub_request += 1

View File

@@ -375,7 +375,7 @@ class OpenAIServingResponses(OpenAIServing):
generators: list[AsyncGenerator[ConversationContext, None]] = []
builtin_tool_list: list[str] = []
if self.use_harmony and self.tool_server is not None:
if self.tool_server is not None:
if self.tool_server.has_tool("browser"):
builtin_tool_list.append("browser")
if self.tool_server.has_tool("python"):
@@ -423,6 +423,10 @@ class OpenAIServingResponses(OpenAIServing):
tokenizer=tokenizer,
reasoning_parser_cls=self.reasoning_parser,
request=request,
tool_parser_cls=self.tool_parser,
available_tools=available_tools,
chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format,
)
else:
context = SimpleContext()

View File

@@ -16,6 +16,7 @@ from openai.types.responses.response import ToolChoice
from openai.types.responses.response_function_tool_call_output_item import (
ResponseFunctionToolCallOutputItem,
)
from openai.types.responses.response_output_item import McpCall
from openai.types.responses.response_output_message import ResponseOutputMessage
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
from openai.types.responses.tool import Tool
@@ -25,6 +26,7 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionMessageParam,
ResponseInputOutputItem,
)
from vllm.utils import random_uuid
def make_response_output_items_from_parsable_context(
@@ -36,7 +38,24 @@ def make_response_output_items_from_parsable_context(
if not isinstance(message, ResponseFunctionToolCallOutputItem):
output_messages.append(message)
else:
raise NotImplementedError("tool calls not supported for response context")
if len(output_messages) == 0:
raise ValueError(
"Cannot have a FunctionToolCallOutput before FunctionToolCall."
)
if isinstance(output_messages[-1], ResponseFunctionToolCall):
mcp_message = McpCall(
id=f"mcp_{random_uuid()}",
arguments=output_messages[-1].arguments,
name=output_messages[-1].name,
server_label=output_messages[
-1
].name, # TODO: store the server label
type="mcp_call",
status="completed",
output=message.output,
# TODO: support error output
)
output_messages[-1] = mcp_message
return output_messages
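With this change, a trailing function_call plus its function_call_output collapse into a single mcp_call item in the response. An illustration of the resulting item with made-up values, matching what test_mcp_tool_call asserts on (arguments comes from the model's call, output from the MCP python tool):

from openai.types.responses.response_output_item import McpCall

from vllm.utils import random_uuid

example_mcp_call = McpCall(
    id=f"mcp_{random_uuid()}",
    type="mcp_call",
    name="code_interpreter",
    server_label="code_interpreter",  # TODO in this PR: store the real server label
    arguments='{"code": "print(13 * 24)"}',  # illustrative
    output="312\n",  # illustrative
    status="completed",
)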

View File

@@ -1,12 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import os
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any
from openai.types.responses.response_function_tool_call_output_item import (
ResponseFunctionToolCallOutputItem,
)
from openai_harmony import Author, Message, Role, TextContent
from vllm.logger import init_logger
from vllm.utils import random_uuid
if TYPE_CHECKING:
# Avoid circular import.
@@ -46,6 +51,10 @@ class Tool(ABC):
async def get_result(self, context: "ConversationContext") -> Any:
pass
@abstractmethod
async def get_result_parsable_context(self, context: "ConversationContext") -> Any:
pass
class HarmonyBrowserTool(Tool):
def __init__(self):
@@ -81,6 +90,9 @@ class HarmonyBrowserTool(Tool):
tool_output_msgs.append(msg)
return tool_output_msgs
async def get_result_parsable_context(self, context: "ConversationContext") -> Any:
raise NotImplementedError("Not implemented yet")
@property
def tool_config(self) -> Any:
return self.browser_tool.tool_config
@@ -138,6 +150,38 @@ class HarmonyPythonTool(Tool):
tool_output_msgs.append(msg)
return tool_output_msgs
async def get_result_parsable_context(self, context: "ConversationContext") -> Any:
"""
This function converts parsable context types to harmony and
back so we can use the gpt-oss demo python tool
"""
from vllm.entrypoints.context import ParsableContext
assert isinstance(context, ParsableContext)
last_msg = context.parser.response_messages[-1]
args = json.loads(last_msg.arguments)
last_msg_harmony = Message(
author=Author(role="assistant", name=None),
content=[TextContent(text=args["code"])],
channel="analysis",
recipient="python",
content_type="code",
)
tool_output_msgs = []
async for msg in self.python_tool.process(last_msg_harmony):
processed = ResponseFunctionToolCallOutputItem(
id=f"fco_{random_uuid()}",
type="function_call_output",
call_id=f"call_{random_uuid()}",
output=msg.content[0].text,
status="completed",
)
tool_output_msgs.append(processed)
return tool_output_msgs
@property
def tool_config(self) -> Any:
return self.python_tool.tool_config