[Model] Add reasoning_parser and tool_parser for Ernie45 thinking (#25027)
Signed-off-by: wangyafeng <wangyafeng@baidu.com>
This commit is contained in:
@@ -11,6 +11,8 @@ vLLM currently supports the following reasoning models:
|
||||
| Model Series | Parser Name | Structured Output Support | Tool Calling |
|
||||
|--------------|-------------|------------------|-------------|
|
||||
| [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `json`, `regex` | ❌ |
|
||||
| [ERNIE-4.5-VL series](https://huggingface.co/baidu/ERNIE-4.5-VL-28B-A3B-PT) | `ernie45` | `json`, `regex` | ❌ |
|
||||
| [ERNIE-4.5-21B-A3B-Thinking](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Thinking) | `ernie45` | `json`, `regex` | ✅ |
|
||||
| [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `json`, `regex` | ✅ |
|
||||
| [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` | ❌ | ❌ |
|
||||
| [Qwen3 series](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) | `qwen3` | `json`, `regex` | ✅ |
|
||||
|
||||
124
tests/reasoning/test_ernie45_reasoning_parser.py
Normal file
124
tests/reasoning/test_ernie45_reasoning_parser.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from tests.reasoning.utils import run_reasoning_extraction
|
||||
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||
|
||||
parser_name = "ernie45"
|
||||
|
||||
REASONING_MODEL_NAME = "baidu/ERNIE-4.5-21B-A3B-Thinking"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ernie45_tokenizer():
|
||||
return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
|
||||
|
||||
|
||||
# 带 </think>,非stream
|
||||
WITH_THINK = {
|
||||
"output": "abc</think>def",
|
||||
"reasoning_content": "abc",
|
||||
"content": "def",
|
||||
}
|
||||
# 带 </think>,stream
|
||||
WITH_THINK_STREAM = {
|
||||
"output": "abc</think>def",
|
||||
"reasoning_content": "abc",
|
||||
"content": "def",
|
||||
}
|
||||
# without </think>, all is reasoning_content
|
||||
WITHOUT_THINK = {
|
||||
"output": "abc",
|
||||
"reasoning_content": "abc",
|
||||
"content": None,
|
||||
}
|
||||
# without </think>, all is reasoning_content
|
||||
WITHOUT_THINK_STREAM = {
|
||||
"output": "abc",
|
||||
"reasoning_content": "abc",
|
||||
"content": None,
|
||||
}
|
||||
|
||||
COMPLETE_REASONING = {
|
||||
"output": "abc</think>",
|
||||
"reasoning_content": "abc",
|
||||
"content": None,
|
||||
}
|
||||
MULTILINE_REASONING = {
|
||||
"output": "abc\nABC</think>def\nDEF",
|
||||
"reasoning_content": "abc\nABC",
|
||||
"content": "def\nDEF",
|
||||
}
|
||||
|
||||
TEST_CASES = [
|
||||
pytest.param(
|
||||
False,
|
||||
WITH_THINK,
|
||||
id="with_think",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
WITH_THINK_STREAM,
|
||||
id="with_think_stream",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
WITHOUT_THINK,
|
||||
id="without_think",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
WITHOUT_THINK_STREAM,
|
||||
id="without_think_stream",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
COMPLETE_REASONING,
|
||||
id="complete_reasoning",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
COMPLETE_REASONING,
|
||||
id="complete_reasoning_stream",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
MULTILINE_REASONING,
|
||||
id="multiline_reasoning",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
MULTILINE_REASONING,
|
||||
id="multiline_reasoning_stream",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
|
||||
def test_reasoning(
|
||||
streaming: bool,
|
||||
param_dict: dict,
|
||||
ernie45_tokenizer,
|
||||
):
|
||||
output = ernie45_tokenizer.tokenize(param_dict["output"])
|
||||
output_tokens: list[str] = []
|
||||
for token in output:
|
||||
one_token = ernie45_tokenizer.convert_tokens_to_string([token])
|
||||
if one_token:
|
||||
output_tokens.append(one_token)
|
||||
|
||||
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)(
|
||||
ernie45_tokenizer
|
||||
)
|
||||
|
||||
reasoning, content = run_reasoning_extraction(
|
||||
parser, output_tokens, streaming=streaming
|
||||
)
|
||||
|
||||
print()
|
||||
|
||||
assert reasoning == param_dict["reasoning_content"]
|
||||
assert content == param_dict["content"]
|
||||
359
tests/tool_use/test_ernie45_moe_tool_parser.py
Normal file
359
tests/tool_use/test_ernie45_moe_tool_parser.py
Normal file
@@ -0,0 +1,359 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa: E501
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaMessage,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.entrypoints.openai.tool_parsers import Ernie45ToolParser
|
||||
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
|
||||
|
||||
# Use a common model that is likely to be available
|
||||
MODEL = "baidu/ERNIE-4.5-21B-A3B-Thinking"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ernie45_tokenizer():
|
||||
return get_tokenizer(tokenizer_name=MODEL, trust_remote_code=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ernie45_tool_parser(ernie45_tokenizer):
|
||||
return Ernie45ToolParser(ernie45_tokenizer)
|
||||
|
||||
|
||||
def assert_tool_calls(
|
||||
actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
|
||||
):
|
||||
assert len(actual_tool_calls) == len(expected_tool_calls)
|
||||
|
||||
for actual_tool_call, expected_tool_call in zip(
|
||||
actual_tool_calls, expected_tool_calls
|
||||
):
|
||||
assert isinstance(actual_tool_call.id, str)
|
||||
assert len(actual_tool_call.id) > 0
|
||||
|
||||
assert actual_tool_call.type == "function"
|
||||
assert actual_tool_call.function.name == expected_tool_call.function.name
|
||||
# Compare arguments as JSON objects to handle formatting differences
|
||||
actual_args = json.loads(actual_tool_call.function.arguments)
|
||||
expected_args = json.loads(expected_tool_call.function.arguments)
|
||||
assert actual_args == expected_args
|
||||
|
||||
|
||||
def test_extract_tool_calls_no_tools(ernie45_tool_parser):
|
||||
model_output = "This is a test"
|
||||
extracted_tool_calls = ernie45_tool_parser.extract_tool_calls(
|
||||
model_output, request=None
|
||||
) # type: ignore[arg-type]
|
||||
assert not extracted_tool_calls.tools_called
|
||||
assert extracted_tool_calls.tool_calls == []
|
||||
assert extracted_tool_calls.content == model_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"single_tool_call",
|
||||
"multiple_tool_calls",
|
||||
"tool_call_with_content_before",
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
(
|
||||
"""<tool_call>
|
||||
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
|
||||
</tool_call>
|
||||
""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_temperature",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"location": "Beijing",
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
None,
|
||||
),
|
||||
(
|
||||
"""<tool_call>
|
||||
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
|
||||
</tool_call>
|
||||
<tool_call>
|
||||
{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}}
|
||||
</tool_call>
|
||||
""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_temperature",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"location": "Beijing",
|
||||
}
|
||||
),
|
||||
)
|
||||
),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_temperature_unit",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"location": "Guangzhou",
|
||||
"unit": "c",
|
||||
}
|
||||
),
|
||||
)
|
||||
),
|
||||
],
|
||||
None,
|
||||
),
|
||||
(
|
||||
"""I need to call two tools to handle these two issues separately.
|
||||
</think>
|
||||
|
||||
<tool_call>
|
||||
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
|
||||
</tool_call>
|
||||
<tool_call>
|
||||
{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}}
|
||||
</tool_call>
|
||||
""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_temperature",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"location": "Beijing",
|
||||
}
|
||||
),
|
||||
)
|
||||
),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_temperature_unit",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"location": "Guangzhou",
|
||||
"unit": "c",
|
||||
}
|
||||
),
|
||||
)
|
||||
),
|
||||
],
|
||||
"I need to call two tools to handle these two issues separately.\n</think>",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls(
|
||||
ernie45_tool_parser, model_output, expected_tool_calls, expected_content
|
||||
):
|
||||
extracted_tool_calls = ernie45_tool_parser.extract_tool_calls(
|
||||
model_output, request=None
|
||||
) # type: ignore[arg-type]
|
||||
assert extracted_tool_calls.tools_called
|
||||
|
||||
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
|
||||
|
||||
assert extracted_tool_calls.content == expected_content
|
||||
|
||||
|
||||
def stream_delta_message_generator(
|
||||
ernie45_tool_parser: Ernie45ToolParser,
|
||||
ernie45_tokenizer: AnyTokenizer,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest | None = None,
|
||||
) -> Generator[DeltaMessage, None, None]:
|
||||
all_token_ids = ernie45_tokenizer.encode(model_output, add_special_tokens=False)
|
||||
|
||||
previous_text = ""
|
||||
previous_tokens = None
|
||||
prefix_offset = 0
|
||||
read_offset = 0
|
||||
for i, delta_token in enumerate(all_token_ids):
|
||||
delta_token_ids = [delta_token]
|
||||
previous_token_ids = all_token_ids[:i]
|
||||
current_token_ids = all_token_ids[: i + 1]
|
||||
|
||||
(new_tokens, delta_text, new_prefix_offset, new_read_offset) = (
|
||||
detokenize_incrementally(
|
||||
tokenizer=ernie45_tokenizer,
|
||||
all_input_ids=current_token_ids,
|
||||
prev_tokens=previous_tokens,
|
||||
prefix_offset=prefix_offset,
|
||||
read_offset=read_offset,
|
||||
skip_special_tokens=False,
|
||||
spaces_between_special_tokens=True,
|
||||
)
|
||||
)
|
||||
|
||||
current_text = previous_text + delta_text
|
||||
|
||||
delta_message = ernie45_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text,
|
||||
current_text,
|
||||
delta_text,
|
||||
previous_token_ids,
|
||||
current_token_ids,
|
||||
delta_token_ids,
|
||||
request=request,
|
||||
)
|
||||
if delta_message:
|
||||
yield delta_message
|
||||
|
||||
previous_text = current_text
|
||||
previous_tokens = (
|
||||
previous_tokens + new_tokens if previous_tokens else new_tokens
|
||||
)
|
||||
prefix_offset = new_prefix_offset
|
||||
read_offset = new_read_offset
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"single_tool_call",
|
||||
"multiple_tool_calls",
|
||||
"tool_call_with_content_before",
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
(
|
||||
"""<tool_call>
|
||||
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
|
||||
</tool_call>
|
||||
""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_temperature",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"location": "Beijing",
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
None,
|
||||
),
|
||||
(
|
||||
"""<tool_call>
|
||||
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
|
||||
</tool_call>
|
||||
<tool_call>
|
||||
{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}}
|
||||
</tool_call>
|
||||
""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_temperature",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"location": "Beijing",
|
||||
}
|
||||
),
|
||||
)
|
||||
),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_temperature_unit",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"location": "Guangzhou",
|
||||
"unit": "c",
|
||||
}
|
||||
),
|
||||
)
|
||||
),
|
||||
],
|
||||
None,
|
||||
),
|
||||
(
|
||||
"""I need to call two tools to handle these two issues separately.
|
||||
</think>
|
||||
|
||||
<tool_call>
|
||||
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
|
||||
</tool_call>
|
||||
<tool_call>
|
||||
{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}}
|
||||
</tool_call>
|
||||
""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_temperature",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"location": "Beijing",
|
||||
}
|
||||
),
|
||||
)
|
||||
),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_temperature_unit",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"location": "Guangzhou",
|
||||
"unit": "c",
|
||||
}
|
||||
),
|
||||
)
|
||||
),
|
||||
],
|
||||
"I need to call two tools to handle these two issues separately.\n</think>",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls_streaming_incremental(
|
||||
ernie45_tool_parser,
|
||||
ernie45_tokenizer,
|
||||
model_output,
|
||||
expected_tool_calls,
|
||||
expected_content,
|
||||
):
|
||||
"""Verify the Ernie45 Parser streaming behavior by verifying each chunk is as expected.""" # noqa: E501
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=[])
|
||||
|
||||
tool_calls_dict = {}
|
||||
for delta_message in stream_delta_message_generator(
|
||||
ernie45_tool_parser, ernie45_tokenizer, model_output, request
|
||||
):
|
||||
if (
|
||||
delta_message.role is None
|
||||
and delta_message.content is None
|
||||
and delta_message.reasoning_content is None
|
||||
and len(delta_message.tool_calls) == 0
|
||||
):
|
||||
continue
|
||||
tool_calls = delta_message.tool_calls
|
||||
for tool_call_chunk in tool_calls:
|
||||
index = tool_call_chunk.index
|
||||
if index not in tool_calls_dict:
|
||||
if tool_call_chunk.function.arguments is None:
|
||||
tool_call_chunk.function.arguments = ""
|
||||
tool_calls_dict[index] = tool_call_chunk
|
||||
else:
|
||||
tool_calls_dict[
|
||||
index
|
||||
].function.arguments += tool_call_chunk.function.arguments
|
||||
actual_tool_calls = list(tool_calls_dict.values())
|
||||
|
||||
assert len(actual_tool_calls) > 0
|
||||
# check tool call format
|
||||
assert_tool_calls(actual_tool_calls, expected_tool_calls)
|
||||
@@ -4,6 +4,7 @@
|
||||
from .abstract_tool_parser import ToolParser, ToolParserManager
|
||||
from .deepseekv3_tool_parser import DeepSeekV3ToolParser
|
||||
from .deepseekv31_tool_parser import DeepSeekV31ToolParser
|
||||
from .ernie45_tool_parser import Ernie45ToolParser
|
||||
from .glm4_moe_tool_parser import Glm4MoeModelToolParser
|
||||
from .granite_20b_fc_tool_parser import Granite20bFCToolParser
|
||||
from .granite_tool_parser import GraniteToolParser
|
||||
@@ -42,6 +43,7 @@ __all__ = [
|
||||
"Phi4MiniJsonToolParser",
|
||||
"DeepSeekV3ToolParser",
|
||||
"DeepSeekV31ToolParser",
|
||||
"Ernie45ToolParser",
|
||||
"xLAMToolParser",
|
||||
"MinimaxToolParser",
|
||||
"KimiK2ToolParser",
|
||||
|
||||
212
vllm/entrypoints/openai/tool_parsers/ernie45_tool_parser.py
Normal file
212
vllm/entrypoints/openai/tool_parsers/ernie45_tool_parser.py
Normal file
@@ -0,0 +1,212 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
ToolParserManager,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("ernie45")
|
||||
class Ernie45ToolParser(ToolParser):
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
"""
|
||||
Ernie thinking model format:
|
||||
abc\n</think>\n\n\n<tool_call>\ndef\n</tool_call>\n
|
||||
"""
|
||||
super().__init__(tokenizer)
|
||||
self.current_tool_name_sent = False
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
self.current_tool_id = -1
|
||||
self.streamed_args_for_tool: list[str] = []
|
||||
self.think_end_token = "</think>"
|
||||
self.response_start_token: str = "<response>"
|
||||
self.response_end_token: str = "</response>"
|
||||
self.tool_call_start_token = "<tool_call>"
|
||||
self.tool_call_end_token = "</tool_call>"
|
||||
self.tool_calls_start_token = self.tool_call_start_token
|
||||
self.newline_token: str = "<0x0A>"
|
||||
|
||||
self.tool_call_regex = re.compile(
|
||||
r"<tool_call>\s*(?P<json>\{.*?\})\s*</tool_call>", re.DOTALL
|
||||
)
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ToolParser "
|
||||
"constructor during construction."
|
||||
)
|
||||
|
||||
self.think_end_token_id = self.vocab.get(self.think_end_token)
|
||||
self.response_start_token_id = self.vocab.get(self.response_start_token)
|
||||
self.response_end_token_id = self.vocab.get(self.response_end_token)
|
||||
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
|
||||
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||
self.newline_token_id = self.vocab.get(self.newline_token)
|
||||
self.parser_token_ids = [
|
||||
self.think_end_token_id,
|
||||
self.response_start_token_id,
|
||||
self.response_end_token_id,
|
||||
]
|
||||
|
||||
self._buffer = ""
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
# sanity check; avoid unnecessary processing
|
||||
if self.tool_calls_start_token not in model_output:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
tool_call_json_list = self.tool_call_regex.findall(model_output)
|
||||
|
||||
tool_calls = []
|
||||
for tool_call_json in tool_call_json_list:
|
||||
tool_call_dict = json.loads(tool_call_json)
|
||||
args_str = json.dumps(
|
||||
tool_call_dict.get("arguments", {}), ensure_ascii=False
|
||||
)
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=tool_call_dict.get("name", ""),
|
||||
arguments=args_str,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
content = model_output[
|
||||
: model_output.find(self.tool_calls_start_token)
|
||||
].rstrip("\n")
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content if content else None,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error in extracting tool call from response.")
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> DeltaMessage | None:
|
||||
self._buffer += delta_text
|
||||
cur_text = self._buffer
|
||||
start_idx = cur_text.find(self.tool_call_start_token)
|
||||
if start_idx == -1:
|
||||
self._buffer = ""
|
||||
# At least one toolcall has been completed
|
||||
if self.current_tool_id > 0:
|
||||
cur_text = ""
|
||||
if self.current_tool_id == -1 and all(
|
||||
token_id == self.newline_token_id for token_id in previous_token_ids
|
||||
):
|
||||
cur_text = cur_text.strip("\n")
|
||||
|
||||
# handle <response> </response> when tool_call is not triggered
|
||||
# cur_text === delta_text
|
||||
content = cur_text
|
||||
if self.response_start_token_id in delta_token_ids:
|
||||
content = content.lstrip("\n")
|
||||
response_start_idx = content.find(self.response_start_token)
|
||||
content = content[response_start_idx + len(self.response_start_token) :]
|
||||
# if have </response>, remove it
|
||||
response_end_idx = content.rfind(self.response_end_token)
|
||||
if response_end_idx != -1:
|
||||
content = content[:response_end_idx]
|
||||
elif self.response_end_token_id in delta_token_ids:
|
||||
response_end_idx = content.rfind(self.response_end_token)
|
||||
content = content[:response_end_idx]
|
||||
# remove \n after </think> or <response> or </response>
|
||||
if (
|
||||
len(previous_token_ids) > 0
|
||||
and previous_token_ids[-1] in self.parser_token_ids
|
||||
) and (
|
||||
len(delta_token_ids) > 0 and delta_token_ids[0] == self.newline_token_id
|
||||
):
|
||||
content = content.lstrip("\n")
|
||||
|
||||
return DeltaMessage(content=content if content else None)
|
||||
logger.debug("cur_text = %s", cur_text)
|
||||
end_idx = cur_text.find(self.tool_call_end_token)
|
||||
if end_idx != -1:
|
||||
if self.current_tool_id == -1:
|
||||
self.current_tool_id = 0
|
||||
self.prev_tool_call_arr = []
|
||||
self.streamed_args_for_tool = []
|
||||
while len(self.prev_tool_call_arr) <= self.current_tool_id:
|
||||
self.prev_tool_call_arr.append({})
|
||||
while len(self.streamed_args_for_tool) <= self.current_tool_id:
|
||||
self.streamed_args_for_tool.append("")
|
||||
|
||||
extracted_tool_calls = self.extract_tool_calls(
|
||||
cur_text[: end_idx + len(self.tool_call_end_token)], request
|
||||
)
|
||||
|
||||
if len(extracted_tool_calls.tool_calls) == 0:
|
||||
logger.warning("Failed to extract any tool calls.")
|
||||
return None
|
||||
tool_call = extracted_tool_calls.tool_calls[0]
|
||||
self.prev_tool_call_arr[self.current_tool_id] = {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": json.loads(tool_call.function.arguments),
|
||||
}
|
||||
self.streamed_args_for_tool[self.current_tool_id] = (
|
||||
tool_call.function.arguments
|
||||
)
|
||||
delta = DeltaMessage(
|
||||
content=extracted_tool_calls.content,
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
id=tool_call.id,
|
||||
type=tool_call.type,
|
||||
function=DeltaFunctionCall(
|
||||
name=tool_call.function.name,
|
||||
arguments=tool_call.function.arguments,
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
self.current_tool_id += 1
|
||||
self._buffer = cur_text[end_idx + len(self.tool_call_end_token) :]
|
||||
return delta
|
||||
|
||||
self._buffer = cur_text[start_idx:]
|
||||
content = cur_text[:start_idx].rstrip("\n")
|
||||
return DeltaMessage(content=content if content else None)
|
||||
@@ -4,6 +4,7 @@
|
||||
from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
|
||||
from .basic_parsers import BaseThinkingReasoningParser
|
||||
from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
|
||||
from .ernie45_reasoning_parser import Ernie45ReasoningParser
|
||||
from .glm4_moe_reasoning_parser import Glm4MoeModelReasoningParser
|
||||
from .gptoss_reasoning_parser import GptOssReasoningParser
|
||||
from .granite_reasoning_parser import GraniteReasoningParser
|
||||
@@ -19,6 +20,7 @@ __all__ = [
|
||||
"BaseThinkingReasoningParser",
|
||||
"ReasoningParserManager",
|
||||
"DeepSeekR1ReasoningParser",
|
||||
"Ernie45ReasoningParser",
|
||||
"GraniteReasoningParser",
|
||||
"HunyuanA13BReasoningParser",
|
||||
"Qwen3ReasoningParser",
|
||||
|
||||
169
vllm/reasoning/ernie45_reasoning_parser.py
Normal file
169
vllm/reasoning/ernie45_reasoning_parser.py
Normal file
@@ -0,0 +1,169 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
|
||||
from vllm.logger import init_logger
|
||||
from vllm.reasoning import ReasoningParserManager
|
||||
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ReasoningParserManager.register_module("ernie45")
|
||||
class Ernie45ReasoningParser(BaseThinkingReasoningParser):
|
||||
"""
|
||||
Reasoning parser for Ernie45 thinking model.
|
||||
The Ernie45 thinking model ouput format is
|
||||
abc\n</think>\n\n<response>\ndef\n</response>\n
|
||||
or abc\n</think>\ndef
|
||||
"""
|
||||
|
||||
response_start_token: str = "<response>"
|
||||
response_end_token: str = "</response>"
|
||||
newline_token: str = "<0x0A>"
|
||||
|
||||
@property
|
||||
def start_token(self) -> str:
|
||||
"""The token that starts reasoning content."""
|
||||
return "<think>"
|
||||
|
||||
@property
|
||||
def end_token(self) -> str:
|
||||
"""The token that ends reasoning content."""
|
||||
return "</think>"
|
||||
|
||||
def __init__(self, tokenizer: PreTrainedTokenizerBase):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ReasoningParser "
|
||||
"constructor during construction."
|
||||
)
|
||||
|
||||
self.start_token_id = self.vocab.get(self.start_token)
|
||||
self.end_token_id = self.vocab.get(self.end_token)
|
||||
self.response_start_token_id = self.vocab.get(self.response_start_token)
|
||||
self.response_end_token_id = self.vocab.get(self.response_end_token)
|
||||
self.newline_token_id = self.vocab.get(self.newline_token)
|
||||
|
||||
self.parser_token_ids = [self.end_token_id, self.response_end_token_id]
|
||||
|
||||
if self.start_token_id is None or self.end_token_id is None:
|
||||
raise RuntimeError(
|
||||
"Ernie45 reasoning parser could not locate think start/end "
|
||||
"tokens in the tokenizer!"
|
||||
)
|
||||
|
||||
def extract_reasoning_content_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
) -> DeltaMessage | None:
|
||||
"""
|
||||
Extract reasoning content from a delta message.
|
||||
Handles streaming output where previous + delta = current.
|
||||
Uses token IDs for faster processing.
|
||||
The Ernie45 thinking model ouput format is
|
||||
abc\n</think>\n\n<response>\ndef\n</response>\n
|
||||
or abc\n</think>\ndef
|
||||
- 'abc' goes to reasoning_content
|
||||
- 'def' goes to content
|
||||
"""
|
||||
# Skip single special tokens
|
||||
if len(delta_token_ids) == 1 and (
|
||||
delta_token_ids[0]
|
||||
in [
|
||||
self.start_token_id,
|
||||
self.end_token_id,
|
||||
self.response_start_token_id,
|
||||
self.response_end_token_id,
|
||||
]
|
||||
):
|
||||
return None
|
||||
|
||||
# No <think> in previous or delta, also need to check for </think>.
|
||||
# Because the model may have generated </think> without <think>
|
||||
if self.end_token_id in delta_token_ids:
|
||||
# </think> in delta with more tokens,
|
||||
# extract reasoning content and content
|
||||
think_end_index = delta_text.find(self.end_token)
|
||||
reasoning_content = delta_text[:think_end_index]
|
||||
content = delta_text[think_end_index + len(self.end_token) :]
|
||||
content = content.lstrip("\n")
|
||||
response_start_idx = content.find(self.response_start_token)
|
||||
response_end_idx = content.rfind(self.response_end_token)
|
||||
if response_start_idx != -1:
|
||||
content = content[response_start_idx + len(self.response_start_token) :]
|
||||
if response_end_idx != -1:
|
||||
content = content[:response_end_idx]
|
||||
return DeltaMessage(
|
||||
reasoning_content=reasoning_content,
|
||||
content=content if content else None,
|
||||
)
|
||||
elif self.end_token_id in previous_token_ids:
|
||||
# </think> in previous, thinking content ends
|
||||
content = delta_text
|
||||
if self.response_start_token_id in delta_token_ids:
|
||||
content = content.lstrip("\n")
|
||||
response_start_idx = content.find(self.response_start_token)
|
||||
content = content[response_start_idx + len(self.response_start_token) :]
|
||||
# if have </response>, remove it
|
||||
response_end_idx = content.rfind(self.response_end_token)
|
||||
if response_end_idx != -1:
|
||||
content = content[:response_end_idx]
|
||||
elif self.response_end_token_id in delta_token_ids:
|
||||
response_end_idx = content.rfind(self.response_end_token)
|
||||
content = content[:response_end_idx]
|
||||
# remove \n after </think> or </response>
|
||||
if previous_token_ids[-1] in self.parser_token_ids and (
|
||||
len(delta_token_ids) > 0 and delta_token_ids[0] == self.newline_token_id
|
||||
):
|
||||
content = content.lstrip("\n")
|
||||
# remove \n after </think>\n
|
||||
if (
|
||||
len(previous_token_ids) > 1
|
||||
and previous_token_ids[-2] == self.end_token_id
|
||||
) and (
|
||||
len(delta_token_ids) > 0 and delta_token_ids[0] == self.newline_token_id
|
||||
):
|
||||
content = content.lstrip("\n")
|
||||
|
||||
return DeltaMessage(content=content if content else None)
|
||||
else:
|
||||
# no </think> in previous or delta, reasoning content continues
|
||||
return DeltaMessage(reasoning_content=delta_text)
|
||||
|
||||
def extract_reasoning_content(
|
||||
self, model_output: str, request: ChatCompletionRequest
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""
|
||||
Extract reasoning content from the model output.
|
||||
The Ernie45 thinking model ouput format is
|
||||
abc\n</think>\n\n\n<response>\ndef\n</response>\n
|
||||
or abc\n</think>\ndef
|
||||
- 'abc' goes to reasoning_content
|
||||
- 'def' goes to content
|
||||
Returns:
|
||||
tuple[Optional[str], Optional[str]]: reasoning content and content
|
||||
"""
|
||||
reasoning_content, content = super().extract_reasoning_content(
|
||||
model_output, request
|
||||
)
|
||||
if content:
|
||||
start_idx = content.find(self.response_start_token)
|
||||
end_idx = content.rfind(self.response_end_token)
|
||||
# Simultaneously existing and in the correct order
|
||||
if start_idx != -1 and end_idx != -1 and start_idx < end_idx:
|
||||
content = content[start_idx + len(self.response_start_token) : end_idx]
|
||||
final_content = content or None
|
||||
|
||||
return reasoning_content, final_content
|
||||
Reference in New Issue
Block a user