[responsesAPI][5] ResponsesParser with tools for full MCP python loop (#29798)
Signed-off-by: Andrew Xia <axia@fb.com>
Signed-off-by: Andrew Xia <axia@meta.com>
Co-authored-by: Andrew Xia <axia@fb.com>
@@ -3,7 +3,7 @@
"""
Set up this example by starting a vLLM OpenAI-compatible server with tool call
options enabled.
Reasoning models can be used through the Responses API as seen here
https://platform.openai.com/docs/api-reference/responses
For example:
vllm serve Qwen/Qwen3-1.7B --reasoning-parser qwen3 \
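For context, the flow this example enables can be exercised from any OpenAI-compatible client. Below is a minimal client-side sketch, not part of the diff: it assumes the server above is running locally on the default port and uses an illustrative get_horoscope tool (the same shape as the tool used in the tests further down).

# Hedged sketch: call the Responses API with a function tool against a local vLLM server.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

tools = [
    {
        "type": "function",
        "name": "get_horoscope",  # illustrative tool definition, not shipped with vLLM
        "description": "Get today's horoscope for an astrological sign.",
        "parameters": {
            "type": "object",
            "properties": {"sign": {"type": "string"}},
            "required": ["sign"],
        },
    }
]

response = client.responses.create(
    model="Qwen/Qwen3-1.7B",
    input="What is the horoscope for Aquarius today?",
    tools=tools,
)

# With a reasoning parser and tool-call parser configured, the output items are
# typically a reasoning item followed by a function_call item.
for item in response.output:
    print(item.type)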
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import importlib
import json

import pytest
import pytest_asyncio
@@ -13,12 +15,27 @@ MODEL_NAME = "Qwen/Qwen3-8B"

@pytest.fixture(scope="module")
def server():
    args = ["--reasoning-parser", "qwen3", "--max_model_len", "5000"]
    assert importlib.util.find_spec("gpt_oss") is not None, (
        "Harmony tests require gpt_oss package to be installed"
    )

    args = [
        "--reasoning-parser",
        "qwen3",
        "--max_model_len",
        "5000",
        "--structured-outputs-config.backend",
        "xgrammar",
        "--enable-auto-tool-choice",
        "--tool-call-parser",
        "hermes",
        "--tool-server",
        "demo",
    ]
    env_dict = dict(
        VLLM_ENABLE_RESPONSES_API_STORE="1",
        VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT="1",
        # uncomment for tool calling
        # PYTHON_EXECUTION_BACKEND="dangerously_use_uv",
        PYTHON_EXECUTION_BACKEND="dangerously_use_uv",
    )

    with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server:
@@ -85,3 +102,79 @@ async def test_reasoning_and_function_items(client: OpenAI, model_name: str):
    assert response.output[0].type == "reasoning"
    assert response.output[1].type == "message"
    assert type(response.output[1].content[0].text) is str


def get_horoscope(sign):
    return f"{sign}: Next Tuesday you will befriend a baby otter."


def call_function(name, args):
    if name == "get_horoscope":
        return get_horoscope(**args)
    else:
        raise ValueError(f"Unknown function: {name}")


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_function_call_first_turn(client: OpenAI, model_name: str):
    tools = [
        {
            "type": "function",
            "name": "get_horoscope",
            "description": "Get today's horoscope for an astrological sign.",
            "parameters": {
                "type": "object",
                "properties": {
                    "sign": {"type": "string"},
                },
                "required": ["sign"],
                "additionalProperties": False,
            },
            "strict": True,
        }
    ]

    response = await client.responses.create(
        model=model_name,
        input="What is the horoscope for Aquarius today?",
        tools=tools,
        temperature=0.0,
    )
    assert response is not None
    assert response.status == "completed"
    assert len(response.output) == 2
    assert response.output[0].type == "reasoning"
    assert response.output[1].type == "function_call"

    function_call = response.output[1]
    assert function_call.name == "get_horoscope"
    assert function_call.call_id is not None

    args = json.loads(function_call.arguments)
    assert "sign" in args

    # the multi turn function call is tested above in
    # test_reasoning_and_function_items
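To illustrate the multi-turn continuation referenced in the comment above, here is a hedged sketch of the second turn. It reuses response, tools, client, model_name, and call_function from the test above and follows the standard OpenAI Responses function-calling shape (echo the prior output items, then append a function_call_output item); the exact input format is the OpenAI convention, not something this diff itself asserts.

    # Hedged sketch of the follow-up turn (not part of the diff).
    function_call = response.output[1]  # function_call item from the first turn
    args = json.loads(function_call.arguments)
    result = call_function(function_call.name, args)

    followup = await client.responses.create(
        model=model_name,
        input=[
            *response.output,  # echo the reasoning + function_call items
            {
                "type": "function_call_output",
                "call_id": function_call.call_id,
                "output": result,
            },
        ],
        tools=tools,
        temperature=0.0,
    )
    # The model should now answer in natural language using the tool result.
    assert followup.output[-1].type == "message"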
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_call(client: OpenAI, model_name: str):
    response = await client.responses.create(
        model=model_name,
        input="What is 13 * 24? Use python to calculate the result.",
        tools=[{"type": "code_interpreter", "container": {"type": "auto"}}],
        temperature=0.0,
    )

    assert response is not None
    assert response.status == "completed"
    assert response.output[0].type == "reasoning"
    assert response.output[1].type == "mcp_call"
    assert type(response.output[1].arguments) is str
    assert type(response.output[1].output) is str
    assert response.output[2].type == "reasoning"
    # make sure the correct math is in the final output
    assert response.output[3].type == "message"
    assert "312" in response.output[3].content[0].text
@@ -9,10 +9,16 @@ from collections.abc import Callable
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Union

from openai.types.responses.response_function_tool_call_output_item import (
    ResponseFunctionToolCallOutputItem,
)
from openai.types.responses.tool import Mcp
from openai_harmony import Author, Message, Role, StreamState, TextContent

from vllm import envs
from vllm.entrypoints.chat_utils import (
    ChatTemplateContentFormatOption,
)
from vllm.entrypoints.harmony_utils import (
    get_encoding,
    get_streamable_parser_for_assistant,
@@ -22,16 +28,20 @@ from vllm.entrypoints.openai.parser.responses_parser import (
    get_responses_parser_for_simple_context,
)
from vllm.entrypoints.openai.protocol import (
    FunctionCall,
    ResponseInputOutputItem,
    ResponseRawMessageAndToken,
    ResponsesRequest,
)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser
from vllm.entrypoints.responses_utils import construct_tool_dicts
from vllm.entrypoints.tool import Tool
from vllm.entrypoints.tool_server import ToolServer
from vllm.outputs import RequestOutput
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.tokenizers.protocol import TokenizerLike
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid

if TYPE_CHECKING:
    from mcp.client import ClientSession
@@ -221,6 +231,10 @@ class ParsableContext(ConversationContext):
        tokenizer: AnyTokenizer,
        reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser] | None,
        request: ResponsesRequest,
        available_tools: list[str] | None,
        tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
    ):
        self.num_prompt_tokens = 0
        self.num_output_tokens = 0
@@ -238,12 +252,19 @@ class ParsableContext(ConversationContext):
            reasoning_parser_cls=reasoning_parser_cls,
            response_messages=response_messages,
            request=request,
            tool_parser_cls=tool_parser_cls,
        )
        self.tool_parser_cls = tool_parser_cls
        self.request = request
        self.tokenizer = tokenizer

        self.available_tools = available_tools or []
        self._tool_sessions: dict[str, ClientSession | Tool] = {}
        self.called_tools: set[str] = set()

        self.tool_dicts = construct_tool_dicts(request.tools, request.tool_choice)
        self.chat_template = chat_template
        self.chat_template_content_format = chat_template_content_format
    def append_output(self, output: RequestOutput) -> None:
        self.num_prompt_tokens = len(output.prompt_token_ids or [])
@@ -252,14 +273,50 @@ class ParsableContext(ConversationContext):
        self.parser.process(output.outputs[0])

    def append_tool_output(self, output: list[ResponseInputOutputItem]) -> None:
        raise NotImplementedError("Should not be called.")
        self.parser.response_messages.extend(output)

    def need_builtin_tool_call(self) -> bool:
        """Return True if the last message is an MCP tool call."""
        last_message = self.parser.response_messages[-1]
        # TODO: figure out which tools are MCP tools
        if ( # noqa: SIM103
            last_message.type == "function_call"
            and last_message.name in ("code_interpreter", "python")
        ):
            return True

        return False

    async def call_python_tool(
        self, tool_session: Union["ClientSession", Tool], last_msg: FunctionCall
    ) -> list[ResponseInputOutputItem]:
        self.called_tools.add("python")
        if isinstance(tool_session, Tool):
            return await tool_session.get_result_parsable_context(self)
        args = json.loads(last_msg.arguments)
        param = {
            "code": args["code"],
        }
        result = await tool_session.call_tool("python", param)
        result_str = result.content[0].text

        message = ResponseFunctionToolCallOutputItem(
            id=f"fco_{random_uuid()}",
            type="function_call_output",
            call_id=f"call_{random_uuid()}",
            output=result_str,
            status="completed",
        )

        return [message]

    async def call_tool(self) -> list[ResponseInputOutputItem]:
        raise NotImplementedError("Should not be called.")
        if not self.parser.response_messages:
            return []
        last_msg = self.parser.response_messages[-1]
        if last_msg.name == "code_interpreter":
            return await self.call_python_tool(self._tool_sessions["python"], last_msg)
        return []

    def render_for_completion(self):
        raise NotImplementedError("Should not be called.")
@@ -271,11 +328,38 @@ class ParsableContext(ConversationContext):
        request_id: str,
        mcp_tools: dict[str, Mcp],
    ):
        pass
        if tool_server:
            for tool_name in self.available_tools:
                if tool_name in self._tool_sessions:
                    continue

                tool_type = _map_tool_name_to_tool_type(tool_name)
                headers = (
                    mcp_tools[tool_type].headers if tool_type in mcp_tools else None
                )
                tool_session = await exit_stack.enter_async_context(
                    tool_server.new_session(tool_name, request_id, headers)
                )
                self._tool_sessions[tool_name] = tool_session
            exit_stack.push_async_exit(self.cleanup_session)

    async def cleanup_session(self, *args, **kwargs) -> None:
        """Can be used as a coroutine in __aexit__."""
        raise NotImplementedError("Should not be called.")

        async def cleanup_tool_session(tool_session):
            if not isinstance(tool_session, Tool):
                logger.info(
                    "Cleaning up tool session for %s", tool_session._client_info
                )
                with contextlib.suppress(Exception):
                    await tool_session.call_tool("cleanup_session", {})

        await asyncio.gather(
            *(
                cleanup_tool_session(self._tool_sessions[tool])
                for tool in self.called_tools
            )
        )


class HarmonyContext(ConversationContext):
@@ -3,6 +3,7 @@
import logging
from collections.abc import Callable

from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
from openai.types.responses.response_output_message import ResponseOutputMessage
from openai.types.responses.response_output_text import ResponseOutputText
from openai.types.responses.response_reasoning_item import (
@@ -11,8 +12,10 @@ from openai.types.responses.response_reasoning_item import (
)

from vllm.entrypoints.openai.protocol import ResponseInputOutputItem, ResponsesRequest
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser
from vllm.outputs import CompletionOutput
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.tokenizers.protocol import TokenizerLike
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid

@@ -29,6 +32,7 @@ class ResponsesParser:
        reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser],
        response_messages: list[ResponseInputOutputItem],
        request: ResponsesRequest,
        tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
    ):
        self.response_messages: list[ResponseInputOutputItem] = (
            # TODO: initial messages may not be properly typed
@@ -39,6 +43,9 @@ class ResponsesParser:
        self.request = request

        self.reasoning_parser_instance = reasoning_parser_cls(tokenizer)
        self.tool_parser_instance = None
        if tool_parser_cls is not None:
            self.tool_parser_instance = tool_parser_cls(tokenizer)

    def process(self, output: CompletionOutput) -> "ResponsesParser":
        reasoning_content, content = self.reasoning_parser_instance.extract_reasoning(
@@ -59,6 +66,29 @@ class ResponsesParser:
                )
            )

        function_calls: list[ResponseFunctionToolCall] = []
        if self.tool_parser_instance is not None:
            tool_call_info = self.tool_parser_instance.extract_tool_calls(
                content if content is not None else "",
                request=self.request, # type: ignore
            )
            if tool_call_info is not None and tool_call_info.tools_called:
                # extract_tool_calls() returns a list of tool calls.
                function_calls.extend(
                    ResponseFunctionToolCall(
                        id=f"fc_{random_uuid()}",
                        call_id=f"call_{random_uuid()}",
                        type="function_call",
                        status="completed",
                        name=tool_call.function.name,
                        arguments=tool_call.function.arguments,
                    )
                    for tool_call in tool_call_info.tool_calls
                )
                content = tool_call_info.content
                if content and content.strip() == "":
                    content = None

        if content:
            self.response_messages.append(
                ResponseOutputMessage(
@@ -76,6 +106,8 @@ class ResponsesParser:
                    ],
                )
            )
        if len(function_calls) > 0:
            self.response_messages.extend(function_calls)

        return self
@@ -86,6 +118,7 @@ def get_responses_parser_for_simple_context(
    reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser],
    response_messages: list[ResponseInputOutputItem],
    request: ResponsesRequest,
    tool_parser_cls,
) -> ResponsesParser:
    """Factory function to create a ResponsesParser with
    optional reasoning parser.
@@ -98,4 +131,5 @@ def get_responses_parser_for_simple_context(
        reasoning_parser_cls=reasoning_parser_cls,
        response_messages=response_messages,
        request=request,
        tool_parser_cls=tool_parser_cls,
    )
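For orientation, here is a hedged sketch of how this factory is meant to be driven; the names and import paths come from this diff, but the wiring shown is illustrative rather than vLLM's exact call site inside ParsableContext.

# Illustrative only: build a parser and feed it one engine turn.
from vllm.entrypoints.openai.parser.responses_parser import (
    get_responses_parser_for_simple_context,
)
from vllm.entrypoints.openai.protocol import ResponsesRequest
from vllm.outputs import RequestOutput


def parse_turn(tokenizer, reasoning_parser_cls, tool_parser_cls,
               request: ResponsesRequest, output: RequestOutput):
    # The factory forwards these arguments to ResponsesParser.
    parser = get_responses_parser_for_simple_context(
        tokenizer=tokenizer,
        reasoning_parser_cls=reasoning_parser_cls,
        response_messages=[],
        request=request,
        tool_parser_cls=tool_parser_cls,
    )
    # process() splits reasoning from content and extracts any tool calls,
    # appending reasoning items, output messages, and function_call items.
    parser.process(output.outputs[0])
    return parser.response_messages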
@@ -18,6 +18,16 @@ from pydantic import ConfigDict, TypeAdapter
from starlette.datastructures import Headers
from typing_extensions import TypeIs

from vllm.entrypoints.context import (
    HarmonyContext,
    ParsableContext,
    StreamingHarmonyContext,
)
from vllm.entrypoints.openai.protocol import (
    FunctionCall,
    ResponseInputOutputItem,
    ResponsesRequest,
)
from vllm.entrypoints.pooling.classify.protocol import (
    ClassificationChatRequest,
    ClassificationCompletionRequest,
@@ -39,6 +49,7 @@ from vllm.entrypoints.pooling.score.protocol import (
    ScoreRequest,
    ScoreResponse,
)
from vllm.transformers_utils.tokenizer import AnyTokenizer

if sys.version_info >= (3, 12):
    from typing import TypedDict
@@ -72,9 +83,7 @@ from vllm.entrypoints.openai.protocol import (
    DetokenizeRequest,
    ErrorInfo,
    ErrorResponse,
    FunctionCall,
    FunctionDefinition,
    ResponsesRequest,
    TokenizeChatRequest,
    TokenizeCompletionRequest,
    TokenizeResponse,
@@ -85,6 +94,9 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
from vllm.entrypoints.responses_utils import (
    construct_input_messages,
)
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs.data import PromptType
@@ -1224,6 +1236,31 @@ class OpenAIServing:
        )
        return engine_request, tokenization_kwargs

    async def _render_next_turn(
        self,
        request: ResponsesRequest,
        tokenizer: AnyTokenizer,
        messages: list[ResponseInputOutputItem],
        tool_dicts: list[dict[str, Any]] | None,
        tool_parser,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
    ):
        new_messages = construct_input_messages(
            request_input=messages,
        )

        _, request_prompts, engine_prompts = await self._preprocess_chat(
            request,
            tokenizer,
            new_messages,
            tool_dicts=tool_dicts,
            tool_parser=tool_parser,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
        )
        return request_prompts, engine_prompts

    async def _generate_with_builtin_tools(
        self,
        request_id: str,
@@ -1286,11 +1323,27 @@ class OpenAIServing:

            # Create inputs for the next turn.
            # Render the next prompt token ids.
            prompt_token_ids = context.render_for_completion()
            engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids)
            request_prompt = prompt_token_ids
            if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
                prompt_token_ids = context.render_for_completion()
                engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids)
                request_prompt = prompt_token_ids
            elif isinstance(context, ParsableContext):
                request_prompts, engine_prompts = await self._render_next_turn(
                    context.request,
                    context.tokenizer,
                    context.parser.response_messages,
                    context.tool_dicts,
                    context.tool_parser_cls,
                    context.chat_template,
                    context.chat_template_content_format,
                )
                engine_prompt = engine_prompts[0]
                request_prompt = request_prompts[0]

            # Update the sampling params.
            sampling_params.max_tokens = self.max_model_len - len(prompt_token_ids)
            sampling_params.max_tokens = self.max_model_len - len(
                engine_prompt["prompt_token_ids"]
            )
            # OPTIMIZATION
            priority = orig_priority - 1
            sub_request += 1
@@ -375,7 +375,7 @@ class OpenAIServingResponses(OpenAIServing):
        generators: list[AsyncGenerator[ConversationContext, None]] = []

        builtin_tool_list: list[str] = []
        if self.use_harmony and self.tool_server is not None:
        if self.tool_server is not None:
            if self.tool_server.has_tool("browser"):
                builtin_tool_list.append("browser")
            if self.tool_server.has_tool("python"):
@@ -423,6 +423,10 @@ class OpenAIServingResponses(OpenAIServing):
                tokenizer=tokenizer,
                reasoning_parser_cls=self.reasoning_parser,
                request=request,
                tool_parser_cls=self.tool_parser,
                available_tools=available_tools,
                chat_template=self.chat_template,
                chat_template_content_format=self.chat_template_content_format,
            )
        else:
            context = SimpleContext()
@@ -16,6 +16,7 @@ from openai.types.responses.response import ToolChoice
from openai.types.responses.response_function_tool_call_output_item import (
    ResponseFunctionToolCallOutputItem,
)
from openai.types.responses.response_output_item import McpCall
from openai.types.responses.response_output_message import ResponseOutputMessage
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
from openai.types.responses.tool import Tool
@@ -25,6 +26,7 @@ from vllm.entrypoints.openai.protocol import (
    ChatCompletionMessageParam,
    ResponseInputOutputItem,
)
from vllm.utils import random_uuid


def make_response_output_items_from_parsable_context(
@@ -36,7 +38,24 @@ def make_response_output_items_from_parsable_context(
        if not isinstance(message, ResponseFunctionToolCallOutputItem):
            output_messages.append(message)
        else:
            raise NotImplementedError("tool calls not supported for response context")
            if len(output_messages) == 0:
                raise ValueError(
                    "Cannot have a FunctionToolCallOutput before FunctionToolCall."
                )
            if isinstance(output_messages[-1], ResponseFunctionToolCall):
                mcp_message = McpCall(
                    id=f"mcp_{random_uuid()}",
                    arguments=output_messages[-1].arguments,
                    name=output_messages[-1].name,
                    server_label=output_messages[
                        -1
                    ].name, # TODO: store the server label
                    type="mcp_call",
                    status="completed",
                    output=message.output,
                    # TODO: support error output
                )
                output_messages[-1] = mcp_message

    return output_messages
@@ -1,12 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import os
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any

from openai.types.responses.response_function_tool_call_output_item import (
    ResponseFunctionToolCallOutputItem,
)
from openai_harmony import Author, Message, Role, TextContent

from vllm.logger import init_logger
from vllm.utils import random_uuid

if TYPE_CHECKING:
    # Avoid circular import.
@@ -46,6 +51,10 @@ class Tool(ABC):
    async def get_result(self, context: "ConversationContext") -> Any:
        pass

    @abstractmethod
    async def get_result_parsable_context(self, context: "ConversationContext") -> Any:
        pass


class HarmonyBrowserTool(Tool):
    def __init__(self):
@@ -81,6 +90,9 @@ class HarmonyBrowserTool(Tool):
            tool_output_msgs.append(msg)
        return tool_output_msgs

    async def get_result_parsable_context(self, context: "ConversationContext") -> Any:
        raise NotImplementedError("Not implemented yet")

    @property
    def tool_config(self) -> Any:
        return self.browser_tool.tool_config
@@ -138,6 +150,38 @@ class HarmonyPythonTool(Tool):
            tool_output_msgs.append(msg)
        return tool_output_msgs

    async def get_result_parsable_context(self, context: "ConversationContext") -> Any:
        """
        This function converts parsable context types to harmony and
        back so we can use the GPTOSS demo python tool.
        """
        from vllm.entrypoints.context import ParsableContext

        assert isinstance(context, ParsableContext)

        last_msg = context.parser.response_messages[-1]
        args = json.loads(last_msg.arguments)

        last_msg_harmony = Message(
            author=Author(role="assistant", name=None),
            content=[TextContent(text=args["code"])],
            channel="analysis",
            recipient="python",
            content_type="code",
        )

        tool_output_msgs = []
        async for msg in self.python_tool.process(last_msg_harmony):
            processed = ResponseFunctionToolCallOutputItem(
                id=f"fco_{random_uuid()}",
                type="function_call_output",
                call_id=f"call_{random_uuid()}",
                output=msg.content[0].text,
                status="completed",
            )
            tool_output_msgs.append(processed)
        return tool_output_msgs

    @property
    def tool_config(self) -> Any:
        return self.python_tool.tool_config