[Model] Add reasoning_parser and tool_parser for Ernie45 thinking (#25027)

Signed-off-by: wangyafeng <wangyafeng@baidu.com>
This commit is contained in:
CSWYF3634076
2025-10-13 15:55:20 +08:00
committed by GitHub
parent 98f30b8cba
commit 782505ed8e
7 changed files with 870 additions and 0 deletions

View File

@@ -11,6 +11,8 @@ vLLM currently supports the following reasoning models:
| Model Series | Parser Name | Structured Output Support | Tool Calling |
|--------------|-------------|------------------|-------------|
| [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `json`, `regex` | ❌ |
| [ERNIE-4.5-VL series](https://huggingface.co/baidu/ERNIE-4.5-VL-28B-A3B-PT) | `ernie45` | `json`, `regex` | ❌ |
| [ERNIE-4.5-21B-A3B-Thinking](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Thinking) | `ernie45` | `json`, `regex` | ✅ |
| [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `json`, `regex` | ✅ |
| [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` | ❌ | ❌ |
| [Qwen3 series](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) | `qwen3` | `json`, `regex` | ✅ |

View File

@@ -0,0 +1,124 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from transformers import AutoTokenizer
from tests.reasoning.utils import run_reasoning_extraction
from vllm.reasoning import ReasoningParser, ReasoningParserManager
parser_name = "ernie45"
REASONING_MODEL_NAME = "baidu/ERNIE-4.5-21B-A3B-Thinking"
@pytest.fixture(scope="module")
def ernie45_tokenizer():
return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
# 带 </think>非stream
WITH_THINK = {
"output": "abc</think>def",
"reasoning_content": "abc",
"content": "def",
}
# 带 </think>stream
WITH_THINK_STREAM = {
"output": "abc</think>def",
"reasoning_content": "abc",
"content": "def",
}
# without </think>, all is reasoning_content
WITHOUT_THINK = {
"output": "abc",
"reasoning_content": "abc",
"content": None,
}
# without </think>, all is reasoning_content
WITHOUT_THINK_STREAM = {
"output": "abc",
"reasoning_content": "abc",
"content": None,
}
COMPLETE_REASONING = {
"output": "abc</think>",
"reasoning_content": "abc",
"content": None,
}
MULTILINE_REASONING = {
"output": "abc\nABC</think>def\nDEF",
"reasoning_content": "abc\nABC",
"content": "def\nDEF",
}
TEST_CASES = [
pytest.param(
False,
WITH_THINK,
id="with_think",
),
pytest.param(
True,
WITH_THINK_STREAM,
id="with_think_stream",
),
pytest.param(
False,
WITHOUT_THINK,
id="without_think",
),
pytest.param(
True,
WITHOUT_THINK_STREAM,
id="without_think_stream",
),
pytest.param(
False,
COMPLETE_REASONING,
id="complete_reasoning",
),
pytest.param(
True,
COMPLETE_REASONING,
id="complete_reasoning_stream",
),
pytest.param(
False,
MULTILINE_REASONING,
id="multiline_reasoning",
),
pytest.param(
True,
MULTILINE_REASONING,
id="multiline_reasoning_stream",
),
]
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_reasoning(
streaming: bool,
param_dict: dict,
ernie45_tokenizer,
):
output = ernie45_tokenizer.tokenize(param_dict["output"])
output_tokens: list[str] = []
for token in output:
one_token = ernie45_tokenizer.convert_tokens_to_string([token])
if one_token:
output_tokens.append(one_token)
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)(
ernie45_tokenizer
)
reasoning, content = run_reasoning_extraction(
parser, output_tokens, streaming=streaming
)
print()
assert reasoning == param_dict["reasoning_content"]
assert content == param_dict["content"]

View File

@@ -0,0 +1,359 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import json
from collections.abc import Generator
import pytest
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaMessage,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers import Ernie45ToolParser
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
# Use a common model that is likely to be available
MODEL = "baidu/ERNIE-4.5-21B-A3B-Thinking"
@pytest.fixture(scope="module")
def ernie45_tokenizer():
return get_tokenizer(tokenizer_name=MODEL, trust_remote_code=True)
@pytest.fixture
def ernie45_tool_parser(ernie45_tokenizer):
return Ernie45ToolParser(ernie45_tokenizer)
def assert_tool_calls(
actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
):
assert len(actual_tool_calls) == len(expected_tool_calls)
for actual_tool_call, expected_tool_call in zip(
actual_tool_calls, expected_tool_calls
):
assert isinstance(actual_tool_call.id, str)
assert len(actual_tool_call.id) > 0
assert actual_tool_call.type == "function"
assert actual_tool_call.function.name == expected_tool_call.function.name
# Compare arguments as JSON objects to handle formatting differences
actual_args = json.loads(actual_tool_call.function.arguments)
expected_args = json.loads(expected_tool_call.function.arguments)
assert actual_args == expected_args
def test_extract_tool_calls_no_tools(ernie45_tool_parser):
model_output = "This is a test"
extracted_tool_calls = ernie45_tool_parser.extract_tool_calls(
model_output, request=None
) # type: ignore[arg-type]
assert not extracted_tool_calls.tools_called
assert extracted_tool_calls.tool_calls == []
assert extracted_tool_calls.content == model_output
@pytest.mark.parametrize(
ids=[
"single_tool_call",
"multiple_tool_calls",
"tool_call_with_content_before",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
(
"""<tool_call>
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
</tool_call>
""",
[
ToolCall(
function=FunctionCall(
name="get_current_temperature",
arguments=json.dumps(
{
"location": "Beijing",
}
),
)
)
],
None,
),
(
"""<tool_call>
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
</tool_call>
<tool_call>
{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}}
</tool_call>
""",
[
ToolCall(
function=FunctionCall(
name="get_current_temperature",
arguments=json.dumps(
{
"location": "Beijing",
}
),
)
),
ToolCall(
function=FunctionCall(
name="get_temperature_unit",
arguments=json.dumps(
{
"location": "Guangzhou",
"unit": "c",
}
),
)
),
],
None,
),
(
"""I need to call two tools to handle these two issues separately.
</think>
<tool_call>
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
</tool_call>
<tool_call>
{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}}
</tool_call>
""",
[
ToolCall(
function=FunctionCall(
name="get_current_temperature",
arguments=json.dumps(
{
"location": "Beijing",
}
),
)
),
ToolCall(
function=FunctionCall(
name="get_temperature_unit",
arguments=json.dumps(
{
"location": "Guangzhou",
"unit": "c",
}
),
)
),
],
"I need to call two tools to handle these two issues separately.\n</think>",
),
],
)
def test_extract_tool_calls(
ernie45_tool_parser, model_output, expected_tool_calls, expected_content
):
extracted_tool_calls = ernie45_tool_parser.extract_tool_calls(
model_output, request=None
) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
assert extracted_tool_calls.content == expected_content
def stream_delta_message_generator(
ernie45_tool_parser: Ernie45ToolParser,
ernie45_tokenizer: AnyTokenizer,
model_output: str,
request: ChatCompletionRequest | None = None,
) -> Generator[DeltaMessage, None, None]:
all_token_ids = ernie45_tokenizer.encode(model_output, add_special_tokens=False)
previous_text = ""
previous_tokens = None
prefix_offset = 0
read_offset = 0
for i, delta_token in enumerate(all_token_ids):
delta_token_ids = [delta_token]
previous_token_ids = all_token_ids[:i]
current_token_ids = all_token_ids[: i + 1]
(new_tokens, delta_text, new_prefix_offset, new_read_offset) = (
detokenize_incrementally(
tokenizer=ernie45_tokenizer,
all_input_ids=current_token_ids,
prev_tokens=previous_tokens,
prefix_offset=prefix_offset,
read_offset=read_offset,
skip_special_tokens=False,
spaces_between_special_tokens=True,
)
)
current_text = previous_text + delta_text
delta_message = ernie45_tool_parser.extract_tool_calls_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
delta_token_ids,
request=request,
)
if delta_message:
yield delta_message
previous_text = current_text
previous_tokens = (
previous_tokens + new_tokens if previous_tokens else new_tokens
)
prefix_offset = new_prefix_offset
read_offset = new_read_offset
@pytest.mark.parametrize(
ids=[
"single_tool_call",
"multiple_tool_calls",
"tool_call_with_content_before",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
(
"""<tool_call>
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
</tool_call>
""",
[
ToolCall(
function=FunctionCall(
name="get_current_temperature",
arguments=json.dumps(
{
"location": "Beijing",
}
),
)
)
],
None,
),
(
"""<tool_call>
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
</tool_call>
<tool_call>
{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}}
</tool_call>
""",
[
ToolCall(
function=FunctionCall(
name="get_current_temperature",
arguments=json.dumps(
{
"location": "Beijing",
}
),
)
),
ToolCall(
function=FunctionCall(
name="get_temperature_unit",
arguments=json.dumps(
{
"location": "Guangzhou",
"unit": "c",
}
),
)
),
],
None,
),
(
"""I need to call two tools to handle these two issues separately.
</think>
<tool_call>
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
</tool_call>
<tool_call>
{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}}
</tool_call>
""",
[
ToolCall(
function=FunctionCall(
name="get_current_temperature",
arguments=json.dumps(
{
"location": "Beijing",
}
),
)
),
ToolCall(
function=FunctionCall(
name="get_temperature_unit",
arguments=json.dumps(
{
"location": "Guangzhou",
"unit": "c",
}
),
)
),
],
"I need to call two tools to handle these two issues separately.\n</think>",
),
],
)
def test_extract_tool_calls_streaming_incremental(
ernie45_tool_parser,
ernie45_tokenizer,
model_output,
expected_tool_calls,
expected_content,
):
"""Verify the Ernie45 Parser streaming behavior by verifying each chunk is as expected.""" # noqa: E501
request = ChatCompletionRequest(model=MODEL, messages=[], tools=[])
tool_calls_dict = {}
for delta_message in stream_delta_message_generator(
ernie45_tool_parser, ernie45_tokenizer, model_output, request
):
if (
delta_message.role is None
and delta_message.content is None
and delta_message.reasoning_content is None
and len(delta_message.tool_calls) == 0
):
continue
tool_calls = delta_message.tool_calls
for tool_call_chunk in tool_calls:
index = tool_call_chunk.index
if index not in tool_calls_dict:
if tool_call_chunk.function.arguments is None:
tool_call_chunk.function.arguments = ""
tool_calls_dict[index] = tool_call_chunk
else:
tool_calls_dict[
index
].function.arguments += tool_call_chunk.function.arguments
actual_tool_calls = list(tool_calls_dict.values())
assert len(actual_tool_calls) > 0
# check tool call format
assert_tool_calls(actual_tool_calls, expected_tool_calls)

View File

@@ -4,6 +4,7 @@
from .abstract_tool_parser import ToolParser, ToolParserManager
from .deepseekv3_tool_parser import DeepSeekV3ToolParser
from .deepseekv31_tool_parser import DeepSeekV31ToolParser
from .ernie45_tool_parser import Ernie45ToolParser
from .glm4_moe_tool_parser import Glm4MoeModelToolParser
from .granite_20b_fc_tool_parser import Granite20bFCToolParser
from .granite_tool_parser import GraniteToolParser
@@ -42,6 +43,7 @@ __all__ = [
"Phi4MiniJsonToolParser",
"DeepSeekV3ToolParser",
"DeepSeekV31ToolParser",
"Ernie45ToolParser",
"xLAMToolParser",
"MinimaxToolParser",
"KimiK2ToolParser",

View File

@@ -0,0 +1,212 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
import regex as re
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
ToolParserManager,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__)
@ToolParserManager.register_module("ernie45")
class Ernie45ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
"""
Ernie thinking model format:
abc\n</think>\n\n\n<tool_call>\ndef\n</tool_call>\n
"""
super().__init__(tokenizer)
self.current_tool_name_sent = False
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id = -1
self.streamed_args_for_tool: list[str] = []
self.think_end_token = "</think>"
self.response_start_token: str = "<response>"
self.response_end_token: str = "</response>"
self.tool_call_start_token = "<tool_call>"
self.tool_call_end_token = "</tool_call>"
self.tool_calls_start_token = self.tool_call_start_token
self.newline_token: str = "<0x0A>"
self.tool_call_regex = re.compile(
r"<tool_call>\s*(?P<json>\{.*?\})\s*</tool_call>", re.DOTALL
)
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction."
)
self.think_end_token_id = self.vocab.get(self.think_end_token)
self.response_start_token_id = self.vocab.get(self.response_start_token)
self.response_end_token_id = self.vocab.get(self.response_end_token)
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
self.newline_token_id = self.vocab.get(self.newline_token)
self.parser_token_ids = [
self.think_end_token_id,
self.response_start_token_id,
self.response_end_token_id,
]
self._buffer = ""
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
# sanity check; avoid unnecessary processing
if self.tool_calls_start_token not in model_output:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
else:
try:
tool_call_json_list = self.tool_call_regex.findall(model_output)
tool_calls = []
for tool_call_json in tool_call_json_list:
tool_call_dict = json.loads(tool_call_json)
args_str = json.dumps(
tool_call_dict.get("arguments", {}), ensure_ascii=False
)
tool_calls.append(
ToolCall(
type="function",
function=FunctionCall(
name=tool_call_dict.get("name", ""),
arguments=args_str,
),
)
)
content = model_output[
: model_output.find(self.tool_calls_start_token)
].rstrip("\n")
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content if content else None,
)
except Exception:
logger.exception("Error in extracting tool call from response.")
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
self._buffer += delta_text
cur_text = self._buffer
start_idx = cur_text.find(self.tool_call_start_token)
if start_idx == -1:
self._buffer = ""
# At least one toolcall has been completed
if self.current_tool_id > 0:
cur_text = ""
if self.current_tool_id == -1 and all(
token_id == self.newline_token_id for token_id in previous_token_ids
):
cur_text = cur_text.strip("\n")
# handle <response> </response> when tool_call is not triggered
# cur_text === delta_text
content = cur_text
if self.response_start_token_id in delta_token_ids:
content = content.lstrip("\n")
response_start_idx = content.find(self.response_start_token)
content = content[response_start_idx + len(self.response_start_token) :]
# if have </response>, remove it
response_end_idx = content.rfind(self.response_end_token)
if response_end_idx != -1:
content = content[:response_end_idx]
elif self.response_end_token_id in delta_token_ids:
response_end_idx = content.rfind(self.response_end_token)
content = content[:response_end_idx]
# remove \n after </think> or <response> or </response>
if (
len(previous_token_ids) > 0
and previous_token_ids[-1] in self.parser_token_ids
) and (
len(delta_token_ids) > 0 and delta_token_ids[0] == self.newline_token_id
):
content = content.lstrip("\n")
return DeltaMessage(content=content if content else None)
logger.debug("cur_text = %s", cur_text)
end_idx = cur_text.find(self.tool_call_end_token)
if end_idx != -1:
if self.current_tool_id == -1:
self.current_tool_id = 0
self.prev_tool_call_arr = []
self.streamed_args_for_tool = []
while len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({})
while len(self.streamed_args_for_tool) <= self.current_tool_id:
self.streamed_args_for_tool.append("")
extracted_tool_calls = self.extract_tool_calls(
cur_text[: end_idx + len(self.tool_call_end_token)], request
)
if len(extracted_tool_calls.tool_calls) == 0:
logger.warning("Failed to extract any tool calls.")
return None
tool_call = extracted_tool_calls.tool_calls[0]
self.prev_tool_call_arr[self.current_tool_id] = {
"name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments),
}
self.streamed_args_for_tool[self.current_tool_id] = (
tool_call.function.arguments
)
delta = DeltaMessage(
content=extracted_tool_calls.content,
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
id=tool_call.id,
type=tool_call.type,
function=DeltaFunctionCall(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
),
)
],
)
self.current_tool_id += 1
self._buffer = cur_text[end_idx + len(self.tool_call_end_token) :]
return delta
self._buffer = cur_text[start_idx:]
content = cur_text[:start_idx].rstrip("\n")
return DeltaMessage(content=content if content else None)

View File

@@ -4,6 +4,7 @@
from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
from .basic_parsers import BaseThinkingReasoningParser
from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
from .ernie45_reasoning_parser import Ernie45ReasoningParser
from .glm4_moe_reasoning_parser import Glm4MoeModelReasoningParser
from .gptoss_reasoning_parser import GptOssReasoningParser
from .granite_reasoning_parser import GraniteReasoningParser
@@ -19,6 +20,7 @@ __all__ = [
"BaseThinkingReasoningParser",
"ReasoningParserManager",
"DeepSeekR1ReasoningParser",
"Ernie45ReasoningParser",
"GraniteReasoningParser",
"HunyuanA13BReasoningParser",
"Qwen3ReasoningParser",

View File

@@ -0,0 +1,169 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
logger = init_logger(__name__)
@ReasoningParserManager.register_module("ernie45")
class Ernie45ReasoningParser(BaseThinkingReasoningParser):
"""
Reasoning parser for Ernie45 thinking model.
The Ernie45 thinking model ouput format is
abc\n</think>\n\n<response>\ndef\n</response>\n
or abc\n</think>\ndef
"""
response_start_token: str = "<response>"
response_end_token: str = "</response>"
newline_token: str = "<0x0A>"
@property
def start_token(self) -> str:
"""The token that starts reasoning content."""
return "<think>"
@property
def end_token(self) -> str:
"""The token that ends reasoning content."""
return "</think>"
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ReasoningParser "
"constructor during construction."
)
self.start_token_id = self.vocab.get(self.start_token)
self.end_token_id = self.vocab.get(self.end_token)
self.response_start_token_id = self.vocab.get(self.response_start_token)
self.response_end_token_id = self.vocab.get(self.response_end_token)
self.newline_token_id = self.vocab.get(self.newline_token)
self.parser_token_ids = [self.end_token_id, self.response_end_token_id]
if self.start_token_id is None or self.end_token_id is None:
raise RuntimeError(
"Ernie45 reasoning parser could not locate think start/end "
"tokens in the tokenizer!"
)
def extract_reasoning_content_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
) -> DeltaMessage | None:
"""
Extract reasoning content from a delta message.
Handles streaming output where previous + delta = current.
Uses token IDs for faster processing.
The Ernie45 thinking model ouput format is
abc\n</think>\n\n<response>\ndef\n</response>\n
or abc\n</think>\ndef
- 'abc' goes to reasoning_content
- 'def' goes to content
"""
# Skip single special tokens
if len(delta_token_ids) == 1 and (
delta_token_ids[0]
in [
self.start_token_id,
self.end_token_id,
self.response_start_token_id,
self.response_end_token_id,
]
):
return None
# No <think> in previous or delta, also need to check for </think>.
# Because the model may have generated </think> without <think>
if self.end_token_id in delta_token_ids:
# </think> in delta with more tokens,
# extract reasoning content and content
think_end_index = delta_text.find(self.end_token)
reasoning_content = delta_text[:think_end_index]
content = delta_text[think_end_index + len(self.end_token) :]
content = content.lstrip("\n")
response_start_idx = content.find(self.response_start_token)
response_end_idx = content.rfind(self.response_end_token)
if response_start_idx != -1:
content = content[response_start_idx + len(self.response_start_token) :]
if response_end_idx != -1:
content = content[:response_end_idx]
return DeltaMessage(
reasoning_content=reasoning_content,
content=content if content else None,
)
elif self.end_token_id in previous_token_ids:
# </think> in previous, thinking content ends
content = delta_text
if self.response_start_token_id in delta_token_ids:
content = content.lstrip("\n")
response_start_idx = content.find(self.response_start_token)
content = content[response_start_idx + len(self.response_start_token) :]
# if have </response>, remove it
response_end_idx = content.rfind(self.response_end_token)
if response_end_idx != -1:
content = content[:response_end_idx]
elif self.response_end_token_id in delta_token_ids:
response_end_idx = content.rfind(self.response_end_token)
content = content[:response_end_idx]
# remove \n after </think> or </response>
if previous_token_ids[-1] in self.parser_token_ids and (
len(delta_token_ids) > 0 and delta_token_ids[0] == self.newline_token_id
):
content = content.lstrip("\n")
# remove \n after </think>\n
if (
len(previous_token_ids) > 1
and previous_token_ids[-2] == self.end_token_id
) and (
len(delta_token_ids) > 0 and delta_token_ids[0] == self.newline_token_id
):
content = content.lstrip("\n")
return DeltaMessage(content=content if content else None)
else:
# no </think> in previous or delta, reasoning content continues
return DeltaMessage(reasoning_content=delta_text)
def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest
) -> tuple[str | None, str | None]:
"""
Extract reasoning content from the model output.
The Ernie45 thinking model ouput format is
abc\n</think>\n\n\n<response>\ndef\n</response>\n
or abc\n</think>\ndef
- 'abc' goes to reasoning_content
- 'def' goes to content
Returns:
tuple[Optional[str], Optional[str]]: reasoning content and content
"""
reasoning_content, content = super().extract_reasoning_content(
model_output, request
)
if content:
start_idx = content.find(self.response_start_token)
end_idx = content.rfind(self.response_end_token)
# Simultaneously existing and in the correct order
if start_idx != -1 and end_idx != -1 and start_idx < end_idx:
content = content[start_idx + len(self.response_start_token) : end_idx]
final_content = content or None
return reasoning_content, final_content