Mirror of https://github.com/vllm-project/vllm.git
[Bugfix] Mistral tool parser streaming update (#19425)
Signed-off-by: avigny <47987522+avigny@users.noreply.github.com>
Signed-off-by: Chauncey <chaunceyjiang@gmail.com>
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
Co-authored-by: Jeff Cook <jeff@jeffcook.io>
Co-authored-by: sfbemerk <benjaminmerkel@mail.de>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
@@ -46,6 +46,7 @@ scipy # Required for phi-4-multimodal-instruct
ninja # Required for xgrammar, rocm, tpu, xpu
pybase64 # fast base64 implementation
cbor2 # Required for cross-language serialization of hashable objects
ijson # Required for mistral streaming tool parser
setproctitle # Used to set process names for better debugging and monitoring
openai-harmony >= 0.0.3 # Required for gpt-oss
anthropic == 0.71.0
tests/tool_use/test_mistral_tool_parser.py (new file, 847 lines)
@@ -0,0 +1,847 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
from collections.abc import Generator

import partial_json_parser
import pytest
from mistral_common.protocol.instruct.messages import AssistantMessage
from mistral_common.protocol.instruct.request import InstructRequest
from mistral_common.protocol.instruct.tool_calls import FunctionCall, ToolCall
from partial_json_parser.core.options import Allow

from vllm.entrypoints.openai.protocol import DeltaMessage, DeltaToolCall
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolParser
from vllm.tokenizers import (
    MistralTokenizer,
    TokenizerLike,
    get_tokenizer,
)
from vllm.tokenizers.detokenizer_utils import detokenize_incrementally


@pytest.fixture(scope="module")
def mistral_pre_v11_tokenizer():
    MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
    return get_tokenizer(tokenizer_name=MODEL)


@pytest.fixture(scope="module")
def mistral_tokenizer():
    MODEL = "mistralai/Mistral-Small-3.2-24B-Instruct-2506"
    return get_tokenizer(tokenizer_name=MODEL, tokenizer_mode="mistral")


@pytest.fixture
def mistral_pre_v11_tool_parser(mistral_pre_v11_tokenizer):
    return MistralToolParser(mistral_pre_v11_tokenizer)


@pytest.fixture
def mistral_tool_parser(mistral_tokenizer):
    return MistralToolParser(mistral_tokenizer)


def assert_tool_calls(
    actual_tool_calls: list[ToolCall] | list[DeltaToolCall],
    expected_tool_calls: list[ToolCall],
):
    assert len(actual_tool_calls) == len(expected_tool_calls)

    for actual_tool_call, expected_tool_call in zip(
        actual_tool_calls, expected_tool_calls
    ):
        assert isinstance(actual_tool_call.id, str)
        assert len(actual_tool_call.id) == 9

        if isinstance(actual_tool_call, ToolCall):
            assert actual_tool_call.type == "function"
        elif isinstance(actual_tool_call, DeltaToolCall):
            assert actual_tool_call.function is not None
            assert actual_tool_call.function.name is not None
            assert actual_tool_call.function.arguments is not None
        assert actual_tool_call.function is not None
        assert actual_tool_call.function.name == expected_tool_call.function.name, (
            f"got wrong function name:${actual_tool_call.function.name}"
        )
        assert (
            actual_tool_call.function.arguments == expected_tool_call.function.arguments
        ), f"got wrong function argument:${actual_tool_call.function.arguments}"


def fix_tool_call_tokenization(
    tokens: list[int],
    mistral_tool_parser: MistralToolParser,
    mistral_tokenizer: TokenizerLike,
):
    """
    Replaces the textual token sequence for [TOOL_CALLS]
    with its single special token ID.
    """
    textual_tool_call_token_ids = mistral_tokenizer.encode(
        text=mistral_tool_parser.bot_token,
        add_special_tokens=False,
    )
    # textual_tool_call_token_ids must not contain special tokens like bos, eos etc
    special_tool_call_token_ids = [mistral_tool_parser.bot_token_id]

    # If the input is too short to contain the sequence, no replacement is possible
    if not tokens or len(tokens) < len(textual_tool_call_token_ids):
        return tokens

    result_tokens = []
    i = 0
    target_len = len(textual_tool_call_token_ids)

    while i < len(tokens):
        # Check if the slice from the current position matches the target sequence
        if tokens[i : i + target_len] == textual_tool_call_token_ids:
            # If it matches, add the replacement and jump the index forward
            result_tokens.extend(special_tool_call_token_ids)
            i += target_len
        else:
            # Otherwise, just add the current token and move to the next one
            result_tokens.append(tokens[i])
            i += 1

    return result_tokens

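# Hypothetical illustration (token IDs invented for this sketch): if the text
# "[TOOL_CALLS]" tokenized textually to [1049, 19602] while the tokenizer also
# defines a single special ID 5 for it, then
#   fix_tool_call_tokenization([7, 1049, 19602, 8], parser, tokenizer)
# would return [7, 5, 8], the sequence the model itself would emit.
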
def stream_delta_message_generator(
    mistral_tool_parser: MistralToolParser,
    mistral_tokenizer: TokenizerLike,
    model_output: str | None,
    tools: list[tuple[str, str]] | None,
) -> Generator[DeltaMessage, None, None]:
    if (
        isinstance(mistral_tokenizer, MistralTokenizer)
        and mistral_tokenizer.version >= 11
    ):
        # With the newer versions of the tokenizer,
        # we cannot tokenize free text,
        # so we need to create a list of messages to get tokenized
        assert tools is not None
        assistant_msg = AssistantMessage(
            tool_calls=[
                ToolCall(
                    function=FunctionCall(
                        name=name,
                        arguments=arg,
                    )
                )
                for (name, arg) in tools
            ],
        )
        request = InstructRequest(
            messages=[assistant_msg],
        )
        all_token_ids = mistral_tokenizer.instruct.encode_instruct(request).tokens
    else:
        # Older versions of the tokenizer are able to encode
        # the model's output (free text) directly into tokens
        assert model_output is not None
        all_token_ids = mistral_tokenizer.encode(model_output, add_special_tokens=False)

    all_token_ids = fix_tool_call_tokenization(
        all_token_ids, mistral_tool_parser, mistral_tokenizer
    )

    previous_text = ""
    previous_tokens = None
    prefix_offset = 0
    read_offset = 0
    for i, delta_token in enumerate(all_token_ids):
        delta_token_ids = [delta_token]
        previous_token_ids = all_token_ids[:i]
        current_token_ids = all_token_ids[: i + 1]

        (new_tokens, delta_text, new_prefix_offset, new_read_offset) = (
            detokenize_incrementally(
                tokenizer=mistral_tokenizer,
                all_input_ids=current_token_ids,
                prev_tokens=previous_tokens,
                prefix_offset=prefix_offset,
                read_offset=read_offset,
                skip_special_tokens=isinstance(mistral_tokenizer, MistralTokenizer),
                spaces_between_special_tokens=True,
            )
        )

        current_text = previous_text + delta_text

        delta_message = mistral_tool_parser.extract_tool_calls_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
            request=None,  # type: ignore[arg-type]
        )
        if delta_message:
            yield delta_message

        previous_text = current_text
        previous_tokens = (
            previous_tokens + new_tokens if previous_tokens else new_tokens
        )
        prefix_offset = new_prefix_offset
        read_offset = new_read_offset

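# Note on the generator above: it replays generation one token per step, so
# the parser sees the same (previous_text, current_text, delta_text) sequence
# it would see in a live streaming request; detokenize_incrementally returns
# the newly decoded fragment plus updated offsets, which are threaded back in
# on the next iteration.
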
def test_extract_tool_calls_no_tools(mistral_pre_v11_tool_parser):
    model_output = "This is a test"
    extracted_tool_calls = mistral_pre_v11_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert not extracted_tool_calls.tools_called
    assert extracted_tool_calls.tool_calls == []
    assert extracted_tool_calls.content == model_output


@pytest.mark.parametrize(
    ids=[
        "single_tool_add",
        "single_tool_weather",
        "argument_before_name",
        "argument_before_name_and_name_in_argument",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3.5, "b": 4})
                    )
                )
            ],
            None,
        ),
        (
            """[TOOL_CALLS] [{"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            None,
        ),
        (
            """[TOOL_CALLS] [{"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            None,
        ),
        (
            """[TOOL_CALLS] [{"arguments":{"name": "John Doe"}, "name": "get_age"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_age",
                        arguments=json.dumps(
                            {
                                "name": "John Doe",
                            }
                        ),
                    )
                )
            ],
            None,
        ),
    ],
)
def test_extract_tool_calls_pre_v11_tokenizer(
    mistral_pre_v11_tool_parser, model_output, expected_tool_calls, expected_content
):
    extracted_tool_calls = mistral_pre_v11_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert extracted_tool_calls.tools_called

    assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)

    assert extracted_tool_calls.content == expected_content


@pytest.mark.parametrize(
    ids=[
        "single_tool_add",
        "single_tool_weather",
        "multiple_tool_calls",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add_this_and_that",
                        arguments=json.dumps({"a": 3.5, "b": 4}),
                    )
                )
            ],
            None,
        ),
        (
            """[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            None,
        ),
        (
            """[TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]multiply{"a": 3, "b": 6}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3.5, "b": 4})
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="multiply", arguments=json.dumps({"a": 3, "b": 6})
                    )
                ),
            ],
            None,
        ),
    ],
)
def test_extract_tool_calls(
    mistral_tool_parser, model_output, expected_tool_calls, expected_content
):
    extracted_tool_calls = mistral_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert extracted_tool_calls.tools_called

    assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)

    assert extracted_tool_calls.content == expected_content


def _test_extract_tool_calls_streaming(
    tool_parser, tokenizer, model_output, tools, expected_tool_calls, expected_content
):
    other_content: str = ""
    function_names: list[str] = []
    function_args_strs: list[str] = []
    tool_call_idx: int = -1
    tool_call_ids: list[str | None] = []

    for delta_message in stream_delta_message_generator(
        tool_parser, tokenizer, model_output, tools
    ):
        # role should never be streamed from tool parser
        assert not delta_message.role

        if delta_message.content:
            other_content += delta_message.content

        streamed_tool_calls = delta_message.tool_calls

        if streamed_tool_calls and len(streamed_tool_calls) > 0:
            # make sure only one diff is present - correct even for parallel
            assert len(streamed_tool_calls) == 1
            tool_call = streamed_tool_calls[0]

            assert len(tool_parser.prev_tool_call_arr) > 0

            # if a new tool is being called, set up empty arguments
            if tool_call.index != tool_call_idx:
                tool_call_idx = tool_call.index
                function_args_strs.append("")
                tool_call_ids.append(None)

            # if a tool call ID is streamed, make sure one hasn't been already
            if tool_call.id and not tool_call_ids[tool_call.index]:
                tool_call_ids[tool_call.index] = tool_call.id

            # if parts of the function start being streamed
            if tool_call.function:
                # if the function name is defined, set it. it should be streamed
                # IN ENTIRETY, exactly one time.
                if tool_call.function.name:
                    assert isinstance(tool_call.function.name, str)
                    function_names.append(tool_call.function.name)

                if tool_call.function.arguments:
                    # make sure they're a string and then add them to the list
                    assert isinstance(tool_call.function.arguments, str)

                    function_args_strs[tool_call.index] += tool_call.function.arguments

    assert other_content == expected_content

    actual_tool_calls = [
        ToolCall(
            id=tool_call_id,
            function=FunctionCall(
                name=function_name,
                arguments=partial_json_parser.ensure_json(
                    function_args_str, Allow.OBJ | Allow.STR
                ),
            ),
        )
        for tool_call_id, function_name, function_args_str in zip(
            tool_call_ids, function_names, function_args_strs
        )
    ]
    assert_tool_calls(actual_tool_calls, expected_tool_calls)

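# The streamed argument fragments are concatenated verbatim above and then
# completed with partial_json_parser.ensure_json (e.g. a truncated '{"a": 3'
# is closed into valid JSON) so they can be compared against the expected
# json.dumps output.
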
@pytest.mark.parametrize(
    ids=[
        "no_tools",
        "single_tool_add",
        "single_tool_add_strings",
        "single_tool_weather",
        "argument_before_name",
        "argument_before_name_and_name_in_argument",
        "multiple_tools",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        ("""This is a test""", [], """This is a test"""),
        (
            """[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3, "b": 4})
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": "3", "b": "4"})
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_age",
                        arguments=json.dumps(
                            {
                                "name": "John Doe",
                            }
                        ),
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS] [{"name": "add", "arguments": {"a": 3.5, "b": 4}}, {"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3.5, "b": 4})
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                ),
            ],
            "",
        ),
    ],
)
def test_extract_tool_calls_streaming_pre_v11_tokenizer(
    mistral_pre_v11_tool_parser,
    mistral_pre_v11_tokenizer,
    model_output,
    expected_tool_calls,
    expected_content,
):
    _test_extract_tool_calls_streaming(
        mistral_pre_v11_tool_parser,
        mistral_pre_v11_tokenizer,
        model_output,
        None,
        expected_tool_calls,
        expected_content,
    )


@pytest.mark.parametrize(
    ids=[
        "single_tool_add",
        "single_tool_add_strings",
        "multiple_tools",
    ],
    argnames=["tools", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            [("add", '{"a": 3, "b": 4}')],
            # [TOOL_CALLS]add{"a": 3, "b": 4}
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3, "b": 4})
                    )
                )
            ],
            "",
        ),
        (
            [("add_two_strings", '{"a": "3", "b": "4"}')],
            # [TOOL_CALLS]add_two_strings{"a": "3", "b": "4"}
            [
                ToolCall(
                    function=FunctionCall(
                        name="add_two_strings",
                        arguments=json.dumps({"a": "3", "b": "4"}),
                    )
                )
            ],
            "",
        ),
        (
            [
                ("add", '{"a": 3.5, "b": 4}'),
                (
                    "get_current_weather",
                    '{"city": "San Francisco", "state": "CA", "unit": "celsius"}',  # noqa: E501
                ),
            ],
            # [TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3.5, "b": 4})
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                ),
            ],
            "",
        ),
    ],
)
def test_extract_tool_calls_streaming(
    mistral_tool_parser,
    mistral_tokenizer,
    tools,
    expected_tool_calls,
    expected_content,
):
    _test_extract_tool_calls_streaming(
        mistral_tool_parser,
        mistral_tokenizer,
        None,
        tools,
        expected_tool_calls,
        expected_content,
    )


@pytest.mark.parametrize(
    ids=[
        "single_tool_add",
        "single_tool_weather",
        "multiple_tool_calls",
        "content_before_tool",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add_this_and_that",
                        arguments=json.dumps({"a": 3.5, "b": 4}),
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]multiply{"a": 3, "b": 6}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3.5, "b": 4})
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="multiply", arguments=json.dumps({"a": 3, "b": 6})
                    )
                ),
            ],
            "",
        ),
        (
            # Additional content should not be after the tool calls
            """bla[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add_this_and_that",
                        arguments=json.dumps({"a": 3.5, "b": 4}),
                    )
                )
            ],
            "bla",
        ),
    ],
)
def test_extract_tool_calls_streaming_one_chunk(
    mistral_tool_parser,
    mistral_tokenizer,
    model_output,
    expected_tool_calls,
    expected_content,
):
    if isinstance(mistral_tokenizer, MistralTokenizer):
        all_token_ids = mistral_tokenizer.encode(model_output)
    else:
        all_token_ids = mistral_tokenizer.encode(model_output, add_special_tokens=False)
    all_token_ids = fix_tool_call_tokenization(
        all_token_ids, mistral_tool_parser, mistral_tokenizer
    )

    delta_message = mistral_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text=model_output,
        delta_text=model_output,
        previous_token_ids=[],
        current_token_ids=all_token_ids,
        delta_token_ids=all_token_ids,
        request=None,
    )  # type: ignore[arg-type]
    assert isinstance(delta_message, DeltaMessage)
    assert len(delta_message.tool_calls) == len(expected_tool_calls)

    assert_tool_calls(delta_message.tool_calls, expected_tool_calls)

    if delta_message.content is None:
        assert expected_content == ""
    else:
        assert delta_message.content == expected_content


@pytest.mark.parametrize(
    ids=[
        "no_tools",
        "single_tool_add",
        "single_tool_add_strings",
        "single_tool_weather",
        "argument_before_name",
        "argument_before_name_and_name_in_argument",
        "multiple_tools",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        ("""This is a test""", [], """This is a test"""),
        (
            """[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3, "b": 4})
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": "3", "b": "4"})
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_age",
                        arguments=json.dumps(
                            {
                                "name": "John Doe",
                            }
                        ),
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS] [{"arguments": {"a": 3.5, "b": 4}, "name": "add"}, {"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3.5, "b": 4})
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                ),
            ],
            "",
        ),
    ],
)
def test_extract_tool_calls_streaming_pre_v11_tokenizer_one_chunk(
    mistral_pre_v11_tool_parser,
    mistral_pre_v11_tokenizer,
    model_output,
    expected_tool_calls,
    expected_content,
):
    if isinstance(mistral_pre_v11_tokenizer, MistralTokenizer):
        all_token_ids = mistral_pre_v11_tokenizer.encode(model_output)
    else:
        all_token_ids = mistral_pre_v11_tokenizer.encode(
            model_output, add_special_tokens=False
        )
    all_token_ids = fix_tool_call_tokenization(
        all_token_ids, mistral_pre_v11_tool_parser, mistral_pre_v11_tokenizer
    )

    delta_message = mistral_pre_v11_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text=model_output,
        delta_text=model_output,
        previous_token_ids=[],
        current_token_ids=all_token_ids,
        delta_token_ids=all_token_ids,
        request=None,
    )  # type: ignore[arg-type]
    assert isinstance(delta_message, DeltaMessage)
    assert len(delta_message.tool_calls) == len(expected_tool_calls)

    assert_tool_calls(delta_message.tool_calls, expected_tool_calls)

    if delta_message.content is None:
        assert expected_content == ""
    else:
        assert delta_message.content == expected_content

@@ -123,7 +123,7 @@ CONFIGS: dict[str, ServerConfig] = {
        "supports_parallel": True,
        "extended": True,
    },
    "mistral": {
    "mistral-7b": {
        "model": "mistralai/Mistral-7B-Instruct-v0.3",
        "arguments": [
            "--enforce-eager",
@@ -145,6 +145,32 @@ CONFIGS: dict[str, ServerConfig] = {
        "call the tool. Otherwise, answer the user's query directly "
        "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
        "to the user's question - just respond to it normally.",
        "supports_parallel": True,
    },
    "mistral-small-3.2": {
        "model": "mistralai/Mistral-Small-3.2-24B-Instruct-2506",
        "arguments": [
            "--enforce-eager",
            "--no-enable-prefix-caching",
            "--tool-call-parser",
            "mistral",
            "--tokenizer-mode",
            "mistral",
            "--config-format",
            "mistral",
            "--load-format",
            "mistral",
            "--tensor-parallel-size",
            "4",
            '--ignore-patterns="consolidated.safetensors"',
        ],
        "system_prompt": "You are a helpful assistant with access to tools. If a tool"
        " that you have would be helpful to answer a user query, "
        "call the tool. Otherwise, answer the user's query directly "
        "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
        "to the user's question - just respond to it normally.",
        "supports_parallel": True,
        "extended": True,
    },
    # FIXME: This test currently fails, need to debug why.
    # "granite20b": {

@@ -3,12 +3,12 @@

import json
from collections.abc import Sequence
from enum import Enum, auto
from random import choices
from string import ascii_letters, digits

import partial_json_parser
import ijson
import regex as re
from partial_json_parser.core.options import Allow
from pydantic import Field

from vllm.entrypoints.openai.protocol import (
@@ -23,7 +23,6 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
    ToolParser,
)
from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff
from vllm.logger import init_logger
from vllm.tokenizers import MistralTokenizer, TokenizerLike

@@ -32,6 +31,22 @@ logger = init_logger(__name__)

ALPHANUMERIC = ascii_letters + digits


class StreamingState(Enum):
    """Enum for tracking the current streaming parsing state."""

    WAITING_FOR_TOOL_START = auto()
    WAITING_FOR_TOOL_KEY = (
        auto()
    )  # waiting for the "name" or "arguments" key to be complete
    PARSING_NAME = auto()
    PARSING_NAME_COMPLETED = auto()
    WAITING_FOR_ARGUMENTS_START = auto()
    PARSING_ARGUMENTS = auto()
    PARSING_ARGUMENTS_COMPLETED = auto()
    TOOL_COMPLETE = auto()
    ALL_TOOLS_COMPLETE = auto()

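# Illustrative state walk (a sketch): for v11-style output
# `[TOOL_CALLS]add{"a": 3}` the parser moves roughly
#   WAITING_FOR_TOOL_START -> PARSING_NAME     (on the [TOOL_CALLS] token)
#   PARSING_NAME -> PARSING_ARGUMENTS          (on the first "{")
#   PARSING_ARGUMENTS -> TOOL_COMPLETE         (on the next [TOOL_CALLS])
# while the pre-v11 JSON-array format drives the same enum through the ijson
# events wired up in update_stream_state_pre_v11_tokenizer below.
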
class MistralToolCall(ToolCall):
    id: str = Field(default_factory=lambda: MistralToolCall.generate_random_id())

@@ -46,8 +61,8 @@ class MistralToolCall(ToolCall):
        return id.isalnum() and len(id) == 9


def _is_fn_name_regex_support(model_tokenizer: TokenizerLike) -> bool:
    return (
def _is_pre_v11_tokeniser(model_tokenizer: TokenizerLike) -> bool:
    return not (
        isinstance(model_tokenizer, MistralTokenizer) and model_tokenizer.version >= 11
    )

@@ -69,16 +84,22 @@ class MistralToolParser(ToolParser):

        # initialize properties used for state when parsing tool calls in
        # streaming mode
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.current_tool_name_sent: bool = False
        self.streamed_args_for_tool: list[
            str
        ] = []  # map what has been streamed for each tool so far to a list
        self.streaming_state: StreamingState = StreamingState.WAITING_FOR_TOOL_START

        # For streaming pre v11 tokenizer tool calls
        self.current_tool_name: str | None = None
        self.current_tool_mistral_id: str | None = None
        self.starting_new_tool = False
        if _is_pre_v11_tokeniser(self.model_tokenizer):
            self.parse_coro = ijson.parse_coro(
                self.update_stream_state_pre_v11_tokenizer()
            )

        self.bot_token = "[TOOL_CALLS]"
        self.bot_token_id = self.vocab.get(self.bot_token)
        self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
        if _is_fn_name_regex_support(self.model_tokenizer):
        if not _is_pre_v11_tokeniser(self.model_tokenizer):
            self.fn_name_regex = re.compile(
                r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\}+)", re.DOTALL
            )
@@ -131,18 +152,19 @@ class MistralToolParser(ToolParser):
        # jsons is difficult
        try:
            if self.fn_name_regex:
                matches = self.fn_name_regex.findall(tool_content)

                function_call_arr = []
                for match in matches:
                    fn_name = match[0]
                    args = match[1]
                for single_tool_content in model_output.split(self.bot_token):
                    matches = self.fn_name_regex.findall(single_tool_content)

                    # fn_name is encoded outside serialized json dump
                    # only arguments are serialized
                    function_call_arr.append(
                        {"name": fn_name, "arguments": json.loads(args)}
                    )
                    for match in matches:
                        fn_name = match[0]
                        args = match[1]

                        # fn_name is encoded outside serialized json dump
                        # only arguments are serialized
                        function_call_arr.append(
                            {"name": fn_name, "arguments": json.loads(args)}
                        )
            else:
                function_call_arr = json.loads(tool_content)
        except json.JSONDecodeError:
@@ -193,198 +215,372 @@ class MistralToolParser(ToolParser):
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        # if the tool call token is not in the tokens generated so far, append
        # output to contents since it's not a tool
        if self.bot_token not in current_text:
        if self.bot_token_id not in current_token_ids:
            # if the tool call token is not in the tokens generated so far,
            # append output to contents since it's not a tool
            return DeltaMessage(content=delta_text)

        # if the tool call token ID IS in the tokens generated so far, that
        # if the tool call token IS in the tokens generated so far, that
        # means we're parsing as tool calls now

        # handle if we detected the BOT token which means the start of tool
        # calling
        if self.bot_token_id in delta_token_ids and len(delta_token_ids) == 1:
            # if it's the only token, return None, so we don't send a chat
            # completion and don't send a control token
            return None

        # bit mask flags for partial JSON parsing. If the name hasn't been
        # sent yet, don't allow sending
        # an incomplete string since OpenAI only ever (as far as I have
        # seen) allows sending the entire tool/ function name at once.
        flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
        try:
            # replace BOT token with empty string, and convert single quotes
            # to double to allow parsing as JSON since mistral uses single
            # quotes instead of double for tool calls
            parsable_arr = current_text.split(self.bot_token)[-1]

            # tool calls are generated in an array, so do partial JSON
            # parsing on the entire array
            try:
                tool_call_arr: list[dict] = partial_json_parser.loads(
                    parsable_arr, flags
        if _is_pre_v11_tokeniser(self.model_tokenizer):
            return self._extract_tool_calls_streaming_pre_v11_tokenizer(
                delta_text=delta_text,
                delta_token_ids=delta_token_ids,
            )
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug("not enough tokens to parse into JSON yet")
                return None

            # select as the current tool call the one we're on the state at

            current_tool_call: dict = (
                tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
            )

            # case -- if no tokens have been streamed for the tool, e.g.
            # only the array brackets, stream nothing
            if len(tool_call_arr) == 0:
                return None

            # case: we are starting a new tool in the array
            # -> array has > 0 length AND length has moved past cursor
            elif (
                len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
            ):
                # if we're moving on to a new call, first make sure we
                # haven't missed anything in the previous one that was
                # auto-generated due to JSON completions, but wasn't
                # streamed to the client yet.
                if self.current_tool_id >= 0:
                    diff: str | None = current_tool_call.get("arguments")

                    if diff:
                        diff = json.dumps(diff, ensure_ascii=False).replace(
                            self.streamed_args_for_tool[self.current_tool_id], ""
                        )
                        delta = DeltaMessage(
                            tool_calls=[
                                DeltaToolCall(
                                    index=self.current_tool_id,
                                    function=DeltaFunctionCall(
                                        arguments=diff
                                    ).model_dump(exclude_none=True),
                                )
                            ]
                        )
                        self.streamed_args_for_tool[self.current_tool_id] += diff
                    else:
                        delta = None
                else:
                    delta = None
                # re-set stuff pertaining to progress in the current tool
                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                logger.debug("starting on new tool %d", self.current_tool_id)
                return delta

            # case: update an existing tool - this is handled below

            # if the current tool name hasn't been sent, send if available
            # - otherwise send nothing
            if not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")
                if function_name:
                    delta = DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_id,
                                type="function",
                                id=MistralToolCall.generate_random_id(),
                                function=DeltaFunctionCall(
                                    name=function_name
                                ).model_dump(exclude_none=True),
                            )
                        ]
                    )
                    self.current_tool_name_sent = True
                else:
                    delta = None

            # now we know we're on the same tool call and we're streaming
            # arguments
            else:
                prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
                    "arguments"
        return self._extract_tool_calls_streaming(
            delta_text=delta_text, delta_token_ids=delta_token_ids
        )
                cur_arguments = current_tool_call.get("arguments")

                new_text = delta_text.replace("'", '"')
                if '"}' in new_text:
                    new_text = new_text[: new_text.rindex('"}')]

                if not cur_arguments and not prev_arguments:
                    delta = None
                elif not cur_arguments and prev_arguments:
                    logger.error(
                        "INVARIANT - impossible to have arguments reset mid-arguments"
                    )
                    delta = None
                elif cur_arguments and not prev_arguments:
                    cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False)[
                        :-2
                    ]
                    logger.debug("finding %s in %s", new_text, cur_arguments_json)

                    if new_text not in cur_arguments_json:
                        return None
                    arguments_delta = cur_arguments_json[
                        : cur_arguments_json.rindex(new_text) + len(new_text)
                    ]
                    logger.debug(
                        "First tokens in arguments received: %s", arguments_delta
                    )
                    delta = DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_id,
                                function=DeltaFunctionCall(
                                    arguments=arguments_delta
                                ).model_dump(exclude_none=True),
                            )
                        ]
                    )
                    self.streamed_args_for_tool[self.current_tool_id] += arguments_delta

                elif cur_arguments and prev_arguments:
                    cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
                    prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)
                    logger.debug(
                        "Searching for diff between \n%s\n%s",
                        cur_args_json,
                        prev_args_json,
                    )

                    argument_diff = extract_intermediate_diff(
                        cur_args_json, prev_args_json
                    )
                    logger.debug("got arguments diff: %s", argument_diff)
                    delta = DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_id,
                                function=DeltaFunctionCall(
                                    arguments=argument_diff
                                ).model_dump(exclude_none=True),
                            )
                        ]
                    )
                    self.streamed_args_for_tool[self.current_tool_id] += argument_diff
                else:
                    # try parsing it with regular JSON - if it works we're
                    # at the end, and we need to send the difference between
                    # tokens streamed so far and the valid JSON
                    delta = None

            # check to see if the name is defined and has been sent. if so,
            # stream the name - otherwise keep waiting
            # finish by setting old and returning None as base case
            self.prev_tool_call_arr = tool_call_arr
            return delta

        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction error"
            )
            return None

    def _extract_tool_calls_streaming(
        self,
        delta_text: str,
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """
        Extracts tool calls for Mistral models
        doing tool calls of the following format:
        `[TOOL_CALLS]add{"a": 3.5, "b": 4}`
        """
        additional_content: str = ""
        if self.streaming_state == StreamingState.WAITING_FOR_TOOL_START:
            # this is the first tool call
            assert self.bot_token_id in delta_token_ids
            if not delta_text.startswith(self.bot_token):
                additional_content += delta_text.split(self.bot_token)[0]
                delta_text = self.bot_token + "".join(
                    delta_text.split(self.bot_token)[1:]
                )

        delta_tool_calls = self._generate_delta_tool_call(delta_text)
        if not additional_content and len(delta_tool_calls) == 0:
            if self.streaming_state in [
                StreamingState.PARSING_ARGUMENTS,
                StreamingState.PARSING_ARGUMENTS_COMPLETED,
                StreamingState.TOOL_COMPLETE,
                StreamingState.ALL_TOOLS_COMPLETE,
            ]:
                # Return an empty DeltaMessage once the tool calls are all done
                # so that finish_reason gets set.
                return DeltaMessage()
            else:
                # return None when the tool is not likely to be finished.
                # This can occur, for example, while the name is being parsed
                # and we wait for the name to be complete
                # before sending the function name
                return None

        delta = DeltaMessage()
        if additional_content:
            delta.content = additional_content
        if len(delta_tool_calls) > 0:
            delta.tool_calls = delta_tool_calls

        # HACK: serving_chat.py inspects the internal state of tool parsers
        # when determining its final streaming delta, automatically
        # adding autocompleted JSON.
        # These two lines avoid that nonsense while ensuring finish_reason
        # is set to tool_calls when at least one tool is called.
        if delta_tool_calls and not self.prev_tool_call_arr:
            self.prev_tool_call_arr = [{"arguments": {}}]
        return delta

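    # Chunk-flow sketch (hypothetical deltas): feeding "[TOOL_CALLS]", "add",
    # '{"a"', ': 3}' into _extract_tool_calls_streaming yields None while the
    # name is still incomplete, then one DeltaToolCall carrying name="add"
    # plus the first argument fragment '{"a"', then argument-only deltas.
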
    def _generate_delta_tool_call(self, delta_text: str) -> list[DeltaToolCall]:
        if delta_text == "" or delta_text is None:
            return []
        delta_function_name = None
        tool_id = None
        if self.streaming_state not in [
            StreamingState.PARSING_NAME,
            StreamingState.PARSING_ARGUMENTS,
        ] and delta_text.startswith(self.bot_token):
            self.current_tool_id += 1
            self.streaming_state = StreamingState.PARSING_NAME
            delta_text = delta_text.replace(self.bot_token, "", 1)
        if self.streaming_state == StreamingState.PARSING_NAME:
            if self.current_tool_name is None:
                self.current_tool_name = ""
            # The name stops where the arguments start
            # And the arguments start with the `{` char
            if "{" in delta_text:
                tool_id = MistralToolCall.generate_random_id()
                delta_function_name = delta_text.split("{")[0]
                self.current_tool_name += delta_function_name
                delta_text = delta_text[len(delta_function_name) :]
                self.streaming_state = StreamingState.PARSING_ARGUMENTS
            else:
                # we want to send the tool name once it's complete
                self.current_tool_name += delta_text
                return []
        if self.streaming_state == StreamingState.PARSING_ARGUMENTS:
            next_function_text = None
            if self.bot_token in delta_text:
                # current tool call is over
                delta_arguments = ""
                delta_arguments += delta_text.split(self.bot_token)[0]
                next_function_text = delta_text[len(delta_arguments) :]
                self.streaming_state = StreamingState.TOOL_COMPLETE
            else:
                delta_arguments = delta_text
            ret = []
            if self.current_tool_name or delta_arguments:
                ret += [
                    DeltaToolCall(
                        index=self.current_tool_id,
                        type="function",
                        id=tool_id,
                        function=DeltaFunctionCall(
                            name=self.current_tool_name, arguments=delta_arguments
                        ).model_dump(exclude_none=True),
                    )
                ]
                self.current_tool_name = None
            if next_function_text:
                ret += self._generate_delta_tool_call(next_function_text)
            return ret
        # Should not happen
        return []

    @ijson.coroutine
    def update_stream_state_pre_v11_tokenizer(self):
        while True:
            (prefix, event, value) = yield

            if prefix == "item" and event == "start_map":
                self.streaming_state = StreamingState.WAITING_FOR_TOOL_KEY
            if prefix == "item" and event == "map_key" and value == "name":
                self.streaming_state = StreamingState.PARSING_NAME
            if prefix == "item.name" and event == "string":
                self.current_tool_name = value
                self.streaming_state = StreamingState.PARSING_NAME_COMPLETED
            if prefix == "item" and event == "map_key" and value == "arguments":
                self.streaming_state = StreamingState.WAITING_FOR_ARGUMENTS_START
            if prefix == "item.arguments" and event == "start_map":
                self.streaming_state = StreamingState.PARSING_ARGUMENTS
            if prefix == "item.arguments" and event == "end_map":
                self.streaming_state = StreamingState.PARSING_ARGUMENTS_COMPLETED
            if prefix == "item" and event == "end_map":
                self.streaming_state = StreamingState.TOOL_COMPLETE
            if prefix == "" and event == "end_array":
                self.streaming_state = StreamingState.ALL_TOOLS_COMPLETE

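    # For reference, ijson emits (prefix, event, value) triples; parsing
    # `[{"name": "add", "arguments": {"a": 1}}]` produces roughly
    #   ("", "start_array", None), ("item", "start_map", None),
    #   ("item", "map_key", "name"), ("item.name", "string", "add"),
    #   ("item", "map_key", "arguments"), ("item.arguments", "start_map", None),
    #   ("item.arguments", "map_key", "a"), ("item.arguments.a", "number", 1),
    #   ("item.arguments", "end_map", None), ("item", "end_map", None),
    #   ("", "end_array", None)
    # which is exactly what the coroutine above matches on.
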
    def _extract_tool_calls_streaming_pre_v11_tokenizer(
        self,
        delta_text: str,
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """
        Extracts tool calls for Mistral models
        doing tool calls of the following format:
        `[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}]`
        """
        assert self.parse_coro is not None
        content = None
        delta_tool_calls: list[DeltaToolCall] = []
        current_tool_call: DeltaToolCall = DeltaToolCall(
            index=self.current_tool_id, type="function"
        )
        current_tool_call_modified = False
        if self.bot_token_id in delta_token_ids:
            # this is the first tool call
            if not delta_text.startswith(self.bot_token):
                content = delta_text.split(self.bot_token)[0]
            delta_text = "".join(delta_text.split(self.bot_token)[1:])

        # Cut the delta text smartly to catch the ijson events,
        # as ijson does not give us the index in the text at each event.
        # We need to cut so that we know
        # where in the text the events are emitted from.
        while len(delta_text) > 0:
            streaming_state_before_parse = self.streaming_state

            if self.streaming_state == StreamingState.WAITING_FOR_TOOL_START:
                delta_to_be_parsed, delta_text = self._split_delta(
                    delta_text=delta_text,
                    stop_after_opening_curly_braces=1,
                )
            elif self.streaming_state == StreamingState.WAITING_FOR_TOOL_KEY:
                # Wait until another key is sent
                # or the current tool is completed
                delta_to_be_parsed, delta_text = self._split_delta(
                    delta_text=delta_text,
                    stop_after_colon=1,
                    stop_after_opening_curly_braces=1,
                    # if the tool ends, we want to separate
                    # at the start of the next tool
                )
            elif self.streaming_state == StreamingState.PARSING_NAME:
                delta_to_be_parsed, delta_text = self._split_delta(
                    delta_text=delta_text,
                    stop_after_comma=1,
                    stop_after_closing_brackets=1,
                )
            elif self.streaming_state == StreamingState.WAITING_FOR_ARGUMENTS_START:
                delta_to_be_parsed, delta_text = self._split_delta(
                    delta_text=delta_text,
                    stop_after_opening_curly_braces=1,
                )
            elif self.streaming_state == StreamingState.PARSING_ARGUMENTS:
                delta_to_be_parsed, delta_text = self._split_delta(
                    delta_text=delta_text,
                    stop_after_closing_curly_braces=1,
                    # we could be more clever
                    # by listening to item.arguments.* start_map events
                    # and know how many curly braces we can allow
                )
            elif self.streaming_state in [
                StreamingState.PARSING_ARGUMENTS_COMPLETED,
                StreamingState.PARSING_NAME_COMPLETED,
            ]:
                delta_to_be_parsed, delta_text = self._split_delta(
                    delta_text=delta_text,
                    stop_after_closing_curly_braces=1,
                    stop_after_closing_brackets=1,
                )
            elif self.streaming_state == StreamingState.TOOL_COMPLETE:
                delta_to_be_parsed, delta_text = self._split_delta(
                    delta_text=delta_text,
                    stop_after_opening_curly_braces=1,
                    stop_after_closing_brackets=1,
                )
            elif self.streaming_state == StreamingState.ALL_TOOLS_COMPLETE:
                content = delta_text
                delta_text = ""
            else:
                delta_to_be_parsed = delta_text
                delta_text = ""

            if self.streaming_state != StreamingState.ALL_TOOLS_COMPLETE:
                self.parse_coro.send(delta_to_be_parsed.encode("utf-8"))

            # Given the parsed text and the possible streaming state change,
            # let's add to the tool delta
            if (
                (streaming_state_before_parse != self.streaming_state)
                and streaming_state_before_parse
                in [StreamingState.WAITING_FOR_TOOL_START, StreamingState.TOOL_COMPLETE]
                and self.streaming_state
                not in [
                    StreamingState.ALL_TOOLS_COMPLETE,
                    StreamingState.TOOL_COMPLETE,
                    StreamingState.WAITING_FOR_TOOL_START,
                ]
            ):
                # starting a new tool call
                if current_tool_call_modified:
                    if self.current_tool_mistral_id is not None:
                        current_tool_call.id = self.current_tool_mistral_id
                        self.current_tool_mistral_id = None
                    delta_tool_calls.append(current_tool_call)
                    current_tool_call_modified = False
                self.current_tool_id += 1
                self.current_tool_mistral_id = MistralToolCall.generate_random_id()
                current_tool_call = DeltaToolCall(
                    index=self.current_tool_id,
                    type="function",
                )
            if current_tool_call.function is None:
                current_tool_call.function = DeltaFunctionCall()

            if self.current_tool_name is not None:
                # we have the complete tool name
                current_tool_call_modified = True
                current_tool_call.function.name = self.current_tool_name
                self.current_tool_name = None
                if self.streaming_state == StreamingState.PARSING_NAME_COMPLETED:
                    self.streaming_state = StreamingState.WAITING_FOR_TOOL_KEY
            if self.streaming_state in [
                StreamingState.PARSING_ARGUMENTS,
                StreamingState.PARSING_ARGUMENTS_COMPLETED,
            ]:
                if self.streaming_state == StreamingState.PARSING_ARGUMENTS_COMPLETED:
                    self.streaming_state = StreamingState.WAITING_FOR_TOOL_KEY
                # the delta_to_be_parsed is part of arguments.
                current_tool_call_modified = True
                if current_tool_call.function.arguments is None:
                    current_tool_call.function.arguments = delta_to_be_parsed
                else:
                    current_tool_call.function.arguments += delta_to_be_parsed
                if streaming_state_before_parse != StreamingState.PARSING_ARGUMENTS:
                    # It's the first chunk of arguments, so lstrip it
                    current_tool_call.function.arguments = (
                        current_tool_call.function.arguments.lstrip()
                    )

        if current_tool_call_modified:
            if self.current_tool_mistral_id is not None:
                current_tool_call.id = self.current_tool_mistral_id
                self.current_tool_mistral_id = None
            delta_tool_calls.append(current_tool_call)

        # HACK: serving_chat.py inspects the internal state of tool parsers
        # when determining its final streaming delta, automatically
        # adding autocompleted JSON.
        # These two lines avoid that nonsense while ensuring finish_reason
        # is set to tool_calls when at least one tool is called.
        if delta_tool_calls and not self.prev_tool_call_arr:
            self.prev_tool_call_arr = [{"arguments": {}}]

        if content or len(delta_tool_calls) > 0:
            delta_message = DeltaMessage()
            if content:
                delta_message.content = content
            if len(delta_tool_calls) > 0:
                delta_message.tool_calls = delta_tool_calls
            return delta_message
        else:
            if self.streaming_state == StreamingState.ALL_TOOLS_COMPLETE:
                return DeltaMessage()
            else:
                return None

    def _split_delta(
        self,
        delta_text: str,
        stop_after_quotes: int = -1,
        stop_after_opening_curly_braces: int = -1,
        stop_after_closing_curly_braces: int = -1,
        stop_after_closing_brackets: int = -1,
        stop_after_colon: int = -1,
        stop_after_comma: int = -1,
    ) -> tuple[str, str]:
        delta_to_be_parsed = ""
        for i, c in enumerate(delta_text):
            if c in ['"', "'"]:
                delta_to_be_parsed += c
                stop_after_quotes -= 1
                if stop_after_quotes == 0:
                    return (delta_to_be_parsed, delta_text[i + 1 :])
            elif c == "{":
                delta_to_be_parsed += c
                stop_after_opening_curly_braces -= 1
                if stop_after_opening_curly_braces == 0:
                    return (delta_to_be_parsed, delta_text[i + 1 :])
            elif c == "}":
                delta_to_be_parsed += c
                stop_after_closing_curly_braces -= 1
                if stop_after_closing_curly_braces == 0:
                    return (delta_to_be_parsed, delta_text[i + 1 :])
            elif c == "]":
                delta_to_be_parsed += c
                stop_after_closing_brackets -= 1
                if stop_after_closing_brackets == 0:
                    return (delta_to_be_parsed, delta_text[i + 1 :])
            elif c == ":":
                delta_to_be_parsed += c
                stop_after_colon -= 1
                if stop_after_colon == 0:
                    return (delta_to_be_parsed, delta_text[i + 1 :])
            elif c == ",":
                delta_to_be_parsed += c
                stop_after_comma -= 1
                if stop_after_comma == 0:
                    return (delta_to_be_parsed, delta_text[i + 1 :])
            else:
                delta_to_be_parsed += c

        return (delta_to_be_parsed, "")
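
    # Splitting example (hypothetical input): while waiting for a tool to
    # start, _split_delta('[{"name"', stop_after_opening_curly_braces=1)
    # returns ('[{', '"name"'); the '[{' prefix is sent to the ijson coroutine
    # first, and the remainder is re-split under the updated streaming state.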