# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
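"""Tests for the MistralTokenizer wrapper and the apply_chat_template preparation helper."""
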
from typing import Any
import pytest
from mistral_common.exceptions import InvalidMessageStructureException
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from vllm.tokenizers.mistral import (
MistralTokenizer,
_prepare_apply_chat_template_tools_and_messages,
)
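

# The first parametrized case below omits the tool's "parameters" field, which the
# helper is expected to backfill with an empty dict; the second case already
# provides "parameters": {} and should pass through unchanged.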
@pytest.mark.parametrize(
"openai_request,expected_mistral_output",
[
(
{
"messages": [
{
"role": "user",
"content": "What is the current local date and time?",
}
],
"tools": [
{
"type": "function",
"function": {
"description": "Fetch the current local date and time.",
"name": "get_current_time",
},
}
],
},
(
[
{
"role": "user",
"content": "What is the current local date and time?",
}
],
[
{
"type": "function",
"function": {
"description": "Fetch the current local date and time.",
"name": "get_current_time",
"parameters": {},
},
}
],
),
),
(
{
"messages": [
{
"role": "user",
"content": "What is the current local date and time?",
}
],
"tools": [
{
"type": "function",
"function": {
"description": "Fetch the current local date and time.",
"name": "get_current_time",
"parameters": {},
},
}
],
},
(
[
{
"role": "user",
"content": "What is the current local date and time?",
}
],
[
{
"type": "function",
"function": {
"description": "Fetch the current local date and time.",
"name": "get_current_time",
"parameters": {},
},
}
],
),
),
],
)
def test_prepare_apply_chat_template_tools_and_messages(
openai_request, expected_mistral_output
):
actual_request = _prepare_apply_chat_template_tools_and_messages(
openai_request["messages"], openai_request["tools"]
)
assert actual_request == expected_mistral_output


# Tool use with list-style tool content and a "reasoning" field: the "reasoning"
# key should be dropped from the assistant message, while the tool message's
# list content passes through unchanged.
@pytest.mark.parametrize(
"openai_request,expected_mistral_output",
[
(
{
"messages": [
{
"role": "user",
"content": "What's the weather in Paris?",
},
{
"role": "assistant",
"reasoning": None,
"content": None,
"tool_calls": [
{
"id": "call123",
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "Paris"}',
},
}
],
},
{
"role": "tool",
"content": [{"type": "text", "text": "Rainy"}],
"name": "get_weather",
"tool_call_id": "call123",
},
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Gets the current weather in a city.",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city name",
}
},
"required": ["city"],
},
},
}
],
},
(
[
{
"role": "user",
"content": "What's the weather in Paris?",
},
{
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "call123",
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "Paris"}',
},
}
],
},
{
"role": "tool",
"content": [{"type": "text", "text": "Rainy"}],
"name": "get_weather",
"tool_call_id": "call123",
},
],
[
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Gets the current weather in a city.",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city name",
}
},
"required": ["city"],
},
},
}
],
),
)
],
)
def test_prepare_apply_chat_template_tools_and_messages_list_content(
openai_request, expected_mistral_output
):
actual_request = _prepare_apply_chat_template_tools_and_messages(
openai_request["messages"], openai_request["tools"]
)
assert actual_request == expected_mistral_output
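

# add_generation_prompt rejects a conversation that ends with an assistant
# message, continue_final_message requires one that does, and the two flags
# cannot be combined.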
def test_prepare_apply_chat_template_generation_prompt_and_continue():
messages = [{"role": "assistant", "content": "Hello"}]
tools: list[dict[str, Any]] = []
with pytest.raises(ValueError):
_prepare_apply_chat_template_tools_and_messages(
messages, tools, add_generation_prompt=True
)
messages = [{"role": "user", "content": "Hello"}]
out_messages, _ = _prepare_apply_chat_template_tools_and_messages(
messages, tools, add_generation_prompt=True
)
assert out_messages == [{"role": "user", "content": "Hello"}]
with pytest.raises(ValueError):
_prepare_apply_chat_template_tools_and_messages(
messages, tools, add_generation_prompt=True, continue_final_message=True
)
messages = [{"role": "assistant", "content": "Hello"}]
out_messages, _ = _prepare_apply_chat_template_tools_and_messages(
messages, tools, add_generation_prompt=False, continue_final_message=True
)
assert out_messages == [{"role": "assistant", "content": "Hello"}]
messages = [{"role": "user", "content": "Hello"}]
with pytest.raises(ValueError):
_prepare_apply_chat_template_tools_and_messages(
messages, tools, add_generation_prompt=False, continue_final_message=True
)
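

# The module-scoped fixture is parametrized indirectly, so every test in the class
# below runs against both a SentencePiece-based tokenizer (Mistral-7B-Instruct-v0.3)
# and a Tekken-based one (Magistral-Small-2509). Expected values are therefore
# written as (SentencePiece, Tekken) pairs indexed by `mistral_tokenizer.is_tekken`.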
@pytest.fixture(scope="module")
def mistral_tokenizer(request) -> MistralTokenizer:
return MistralTokenizer.from_pretrained(request.param)
@pytest.mark.parametrize(
"mistral_tokenizer",
["mistralai/Mistral-7B-Instruct-v0.3", "mistralai/Magistral-Small-2509"],
indirect=True,
)
class TestMistralTokenizer:
def test_all_special_tokens(self, mistral_tokenizer: MistralTokenizer):
if mistral_tokenizer.is_tekken:
assert mistral_tokenizer.all_special_tokens == [
"<unk>",
"<s>",
"</s>",
"[INST]",
"[/INST]",
"[AVAILABLE_TOOLS]",
"[/AVAILABLE_TOOLS]",
"[TOOL_RESULTS]",
"[/TOOL_RESULTS]",
"[TOOL_CALLS]",
"[IMG]",
"<pad>",
"[IMG_BREAK]",
"[IMG_END]",
"[PREFIX]",
"[MIDDLE]",
"[SUFFIX]",
"[SYSTEM_PROMPT]",
"[/SYSTEM_PROMPT]",
"[TOOL_CONTENT]",
] + [f"<SPECIAL_{i}>" for i in range(20, 32)] + [
"[ARGS]",
"[CALL_ID]",
"[THINK]",
"[/THINK]",
] + [f"<SPECIAL_{i}>" for i in range(36, 1000)]
else:
assert mistral_tokenizer.all_special_tokens == [
"<s>",
"</s>",
"[INST]",
"[/INST]",
"[TOOL_CALLS]",
"[AVAILABLE_TOOLS]",
"[/AVAILABLE_TOOLS]",
"[TOOL_RESULTS]",
"[/TOOL_RESULTS]",
] + [f"[control_{i}]" for i in range(8, 769)]
    def test_get_vocab(self, mistral_tokenizer: MistralTokenizer):
assert (
mistral_tokenizer.get_vocab()
== mistral_tokenizer.transformers_tokenizer.get_vocab()
)
def test_get_added_vocab(self, mistral_tokenizer: MistralTokenizer):
assert mistral_tokenizer.get_added_vocab() == {}
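
    # encode() prepends the BOS token (id 1) unless add_special_tokens=False, and
    # truncates to max_length unless truncation=False is passed explicitly.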
def test_encode(self, mistral_tokenizer: MistralTokenizer):
token_ids = (
[1, 22177, 4304, 2662]
if mistral_tokenizer.is_tekken
else [1, 23325, 2294, 1686]
)
assert mistral_tokenizer.encode("Hello world !") == token_ids
assert mistral_tokenizer.encode("Hello world !", max_length=3) == token_ids[:-1]
assert (
mistral_tokenizer.encode("Hello world !", truncation=True, max_length=3)
== token_ids[:-1]
)
assert (
mistral_tokenizer.encode("Hello world !", truncation=False, max_length=3)
== token_ids
)
assert (
mistral_tokenizer.encode("Hello world !", add_special_tokens=True)
== token_ids
)
assert (
mistral_tokenizer.encode(
"Hello world !", add_special_tokens=True, max_length=3
)
== token_ids[:-1]
)
assert (
mistral_tokenizer.encode(
"Hello world !", add_special_tokens=True, truncation=False, max_length=3
)
== token_ids
)
assert (
mistral_tokenizer.encode("Hello world !", add_special_tokens=False)
== token_ids[1:]
)
assert mistral_tokenizer.encode("", add_special_tokens=False) == []
def test_call(self, mistral_tokenizer: MistralTokenizer):
token_ids = (
[1, 22177, 4304, 2662]
if mistral_tokenizer.is_tekken
else [1, 23325, 2294, 1686]
)
attn_mask = [1 for _ in range(len(token_ids))]
# Test 1: no special tokens
assert mistral_tokenizer("Hello world !", add_special_tokens=False) == {
"attention_mask": attn_mask[1:],
"input_ids": token_ids[1:],
}
# Test 2: special tokens
assert mistral_tokenizer("Hello world !", add_special_tokens=True) == {
"attention_mask": attn_mask,
"input_ids": token_ids,
}
# Test 3: special tokens + truncation
assert mistral_tokenizer(
"Hello world !", add_special_tokens=True, truncation=True, max_length=3
) == {
"attention_mask": attn_mask[:-1],
"input_ids": token_ids[:-1],
}
# Test 4: special tokens + no truncation + max length
assert mistral_tokenizer(
"Hello world !", add_special_tokens=True, max_length=3
) == {
"attention_mask": attn_mask,
"input_ids": token_ids,
}
# Test 5: empty string
assert mistral_tokenizer("", add_special_tokens=False) == {
"attention_mask": [],
"input_ids": [],
}
with pytest.raises(
ValueError,
match=(r"`text_pair` is not supported by `MistralTokenizer.__call__`."),
):
mistral_tokenizer("Hello world !", "invalid pair")
@pytest.mark.parametrize(
"openai_request,add_generation_prompt,continue_final_message,expected_output,decoded_expected_output",
[
(
{
"messages": [
{
"role": "user",
"content": "Hello world !",
}
],
},
True,
False,
([1, 3, 23325, 2294, 1686, 4], [1, 3, 22177, 4304, 2662, 4]),
("<s>[INST]▁Hello▁world▁![/INST]", ("<s>[INST]Hello world ![/INST]")),
),
(
{
"messages": [
{
"role": "system",
"content": "I am an AI",
},
{
"role": "user",
"content": "Hello world !",
},
],
},
True,
False,
(
[1, 3, 1083, 1605, 1164, 16875, 781, 781, 16998, 2294, 1686, 4],
[1, 17, 1073, 1855, 1420, 26554, 18, 3, 22177, 4304, 2662, 4],
),
(
"<s>[INST]▁I▁am▁an▁AI<0x0A><0x0A>Hello▁world▁![/INST]",
(
"<s>[SYSTEM_PROMPT]I am an AI[/SYSTEM_PROMPT][INST]Hello world ![/INST]" # noqa: E501
),
),
),
(
{
"messages": [
{
"role": "system",
"content": "I am an AI",
},
{
"role": "user",
"content": "Hello world !",
},
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Gets the current weather in a city.",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city name",
}
},
"required": ["city"],
},
},
}
],
},
True,
False,
(
[
1,
6,
1501,
7567,
1891,
2032,
1113,
3396,
1316,
1113,
3396,
2032,
10598,
1629,
2032,
1113,
1295,
29498,
1537,
1991,
1316,
1113,
7286,
2032,
1113,
2226,
29481,
1040,
2636,
8854,
1065,
1032,
3758,
9959,
1113,
12206,
2032,
10598,
1891,
2032,
1113,
3582,
1316,
1113,
11491,
2032,
10598,
19141,
2032,
10598,
1891,
2032,
1113,
2195,
1316,
1113,
7286,
2032,
1113,
1782,
3758,
1909,
29507,
11549,
1113,
11661,
2032,
8135,
19141,
3010,
1743,
10925,
7,
3,
1083,
1605,
1164,
16875,
781,
781,
16998,
2294,
1686,
4,
],
[
1,
17,
1073,
1855,
1420,
26554,
18,
5,
1091,
19227,
4994,
2811,
1429,
5165,
1897,
1429,
5165,
2811,
16753,
2391,
2811,
1429,
1689,
1095,
45629,
1897,
1429,
14653,
2811,
1429,
1071,
3083,
1278,
3519,
17253,
1294,
1261,
5970,
39249,
1429,
26204,
2811,
16753,
4994,
2811,
1429,
6371,
1897,
1429,
48649,
2811,
16753,
29363,
2811,
16753,
4994,
2811,
1429,
3607,
1897,
1429,
14653,
2811,
1429,
1784,
5970,
2564,
1034,
47579,
1429,
15760,
2811,
12161,
29363,
4964,
2821,
27028,
6,
3,
22177,
4304,
2662,
4,
],
),
(
'<s>[AVAILABLE_TOOLS]▁[{"type":▁"function",▁"function":▁{"name":▁"get_weather",▁"description":▁"Gets▁the▁current▁weather▁in▁a▁city.",▁"parameters":▁{"type":▁"object",▁"properties":▁{"city":▁{"type":▁"string",▁"description":▁"The▁city▁name"}},▁"required":▁["city"]}}}][/AVAILABLE_TOOLS][INST]▁I▁am▁an▁AI<0x0A><0x0A>Hello▁world▁![/INST]',
(
'<s>[SYSTEM_PROMPT]I am an AI[/SYSTEM_PROMPT][AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_weather", "description": "Gets the current weather in a city.", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city name"}}, "required": ["city"]}}}][/AVAILABLE_TOOLS][INST]Hello world ![/INST]' # noqa: E501
),
),
),
(
{
"messages": [
{
"role": "system",
"content": "I am an AI",
},
{
"role": "user",
"content": "Hello world !",
},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "123456789",
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "Paris"}',
},
}
],
},
{
"role": "tool",
"tool_call_id": "123456789",
"content": '{"temperature": 20, "unit": "celsius"}',
},
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Gets the current weather in a city.",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city name",
}
},
"required": ["city"],
},
},
}
],
},
True,
False,
(
[
1,
6,
1501,
7567,
1891,
2032,
1113,
3396,
1316,
1113,
3396,
2032,
10598,
1629,
2032,
1113,
1295,
29498,
1537,
1991,
1316,
1113,
7286,
2032,
1113,
2226,
29481,
1040,
2636,
8854,
1065,
1032,
3758,
9959,
1113,
12206,
2032,
10598,
1891,
2032,
1113,
3582,
1316,
1113,
11491,
2032,
10598,
19141,
2032,
10598,
1891,
2032,
1113,
2195,
1316,
1113,
7286,
2032,
1113,
1782,
3758,
1909,
29507,
11549,
1113,
11661,
2032,
8135,
19141,
3010,
1743,
10925,
7,
3,
1083,
1605,
1164,
16875,
781,
781,
16998,
2294,
1686,
4,
5,
1501,
7567,
1629,
2032,
1113,
1295,
29498,
1537,
1991,
1316,
1113,
17452,
2032,
10598,
19141,
2032,
1113,
4684,
1046,
8474,
1113,
1081,
2032,
1113,
29508,
29518,
29538,
29549,
29550,
29552,
29555,
29551,
29542,
29507,
10925,
2,
8,
10598,
4557,
2032,
10598,
29475,
17329,
2032,
29473,
29518,
29502,
29493,
1113,
6074,
2032,
1113,
29485,
1958,
3938,
8474,
1113,
3613,
29498,
1081,
2032,
1113,
29508,
29518,
29538,
29549,
29550,
29552,
29555,
29551,
29542,
18163,
9,
],
[
1,
17,
1073,
1855,
1420,
26554,
18,
5,
1091,
19227,
4994,
2811,
1429,
5165,
1897,
1429,
5165,
2811,
16753,
2391,
2811,
1429,
1689,
1095,
45629,
1897,
1429,
14653,
2811,
1429,
1071,
3083,
1278,
3519,
17253,
1294,
1261,
5970,
39249,
1429,
26204,
2811,
16753,
4994,
2811,
1429,
6371,
1897,
1429,
48649,
2811,
16753,
29363,
2811,
16753,
4994,
2811,
1429,
3607,
1897,
1429,
14653,
2811,
1429,
1784,
5970,
2564,
1034,
47579,
1429,
15760,
2811,
12161,
29363,
4964,
2821,
27028,
6,
3,
22177,
4304,
2662,
4,
9,
1689,
1095,
45629,
32,
19227,
29363,
2811,
1429,
42572,
46005,
2,
7,
19227,
113824,
2811,
1032,
1050,
1048,
1044,
1429,
8979,
2811,
1429,
1099,
79092,
46005,
8,
],
),
(
'<s>[AVAILABLE_TOOLS]▁[{"type":▁"function",▁"function":▁{"name":▁"get_weather",▁"description":▁"Gets▁the▁current▁weather▁in▁a▁city.",▁"parameters":▁{"type":▁"object",▁"properties":▁{"city":▁{"type":▁"string",▁"description":▁"The▁city▁name"}},▁"required":▁["city"]}}}][/AVAILABLE_TOOLS][INST]▁I▁am▁an▁AI<0x0A><0x0A>Hello▁world▁![/INST][TOOL_CALLS]▁[{"name":▁"get_weather",▁"arguments":▁{"city":▁"Paris"},▁"id":▁"123456789"}]</s>[TOOL_RESULTS]▁{"content":▁{"temperature":▁20,▁"unit":▁"celsius"},▁"call_id":▁"123456789"}[/TOOL_RESULTS]',
(
'<s>[SYSTEM_PROMPT]I am an AI[/SYSTEM_PROMPT][AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_weather", "description": "Gets the current weather in a city.", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city name"}}, "required": ["city"]}}}][/AVAILABLE_TOOLS][INST]Hello world ![/INST][TOOL_CALLS]get_weather[ARGS]{"city": "Paris"}</s>[TOOL_RESULTS]{"temperature": 20, "unit": "celsius"}[/TOOL_RESULTS]' # noqa: E501
),
),
),
(
{
"messages": [
{
"role": "user",
"content": "Hello world !",
},
{
"role": "assistant",
"content": "Hello ",
},
],
},
False,
True,
(
[1, 3, 23325, 2294, 1686, 4, 23325],
[1, 3, 22177, 4304, 2662, 4, 22177, 2],
),
(
"<s>[INST]▁Hello▁world▁![/INST]▁Hello",
("<s>[INST]Hello world ![/INST]Hello</s>"),
),
),
],
)
def test_apply_chat_template(
self,
mistral_tokenizer: MistralTokenizer,
openai_request: dict[str, Any],
add_generation_prompt: bool,
continue_final_message: bool,
expected_output: tuple[list[int], list[int]],
decoded_expected_output: tuple[str, str],
):
actual_output = mistral_tokenizer.apply_chat_template(
openai_request["messages"],
tools=openai_request.get("tools", []),
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
)
decoded_actual_output = mistral_tokenizer.tokenizer.decode(
actual_output, SpecialTokenPolicy.KEEP
)
assert actual_output == expected_output[mistral_tokenizer.is_tekken]
assert (
decoded_actual_output
== decoded_expected_output[mistral_tokenizer.is_tekken]
)
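
    # Invalid flag/message combinations: both flags at once, continuing a final
    # user message, requesting a generation prompt after a final assistant
    # message, and a trailing assistant message with neither flag (rejected by
    # mistral-common).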
def test_apply_chat_template_error(self, mistral_tokenizer: MistralTokenizer):
messages = [{"role": "user", "content": "Hello world !"}]
with pytest.raises(ValueError):
mistral_tokenizer.apply_chat_template(
messages,
tools=[],
add_generation_prompt=True,
continue_final_message=True,
)
with pytest.raises(ValueError):
mistral_tokenizer.apply_chat_template(
messages,
tools=[],
add_generation_prompt=False,
continue_final_message=True,
)
messages = [
{"role": "user", "content": "Hello world !"},
{"role": "assistant", "content": "Hello "},
]
with pytest.raises(ValueError):
mistral_tokenizer.apply_chat_template(
messages,
tools=[],
add_generation_prompt=True,
continue_final_message=False,
)
messages = [
{"role": "user", "content": "Hello world !"},
{"role": "assistant", "content": "Hello "},
]
with pytest.raises(InvalidMessageStructureException):
mistral_tokenizer.apply_chat_template(
messages,
tools=[],
add_generation_prompt=False,
continue_final_message=False,
)
@pytest.mark.parametrize(
"skip_special_tokens,expected_tokens",
(
(
False,
(
"<s>[INST]▁Hello▁world▁![/INST]▁Hello</s>",
"<s>[INST]Hello world ![/INST]Hello</s>",
),
),
(True, ("Hello world ! Hello", "Hello world !Hello")),
),
)
def test_decode(
self,
mistral_tokenizer: MistralTokenizer,
skip_special_tokens: bool,
expected_tokens: tuple[str, str],
):
ids = (
[1, 3, 23325, 2294, 1686, 4, 23325, 2],
[1, 3, 22177, 4304, 2662, 4, 22177, 2],
)
assert (
mistral_tokenizer.decode(
ids[mistral_tokenizer.is_tekken],
skip_special_tokens=skip_special_tokens,
)
== expected_tokens[mistral_tokenizer.is_tekken]
)
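        # Decoding the same ids a second time should return an identical result.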
assert (
mistral_tokenizer.decode(
ids[mistral_tokenizer.is_tekken],
skip_special_tokens=skip_special_tokens,
)
== expected_tokens[mistral_tokenizer.is_tekken]
)
def test_decode_empty(
self,
mistral_tokenizer: MistralTokenizer,
):
assert (
mistral_tokenizer.decode(
[],
)
== ""
)
def test_decode_int(
self,
mistral_tokenizer: MistralTokenizer,
):
ids = 1
assert (
mistral_tokenizer.decode(
ids,
skip_special_tokens=False,
)
== "<s>"
)
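
    # The token lists below correspond to the tool-calling conversation used in
    # test_apply_chat_template, one spelling per tokenizer.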
def test_convert_tokens_to_string(self, mistral_tokenizer: MistralTokenizer):
tokens = (
[
"<s>",
"[AVAILABLE_TOOLS]",
"▁[",
'{"',
"type",
'":',
'"',
"function",
'",',
'"',
"function",
'":',
'{"',
"name",
'":',
'"',
"get",
"_",
"we",
"ather",
'",',
'"',
"description",
'":',
'"',
"Get",
"s",
"▁the",
"▁current",
"▁weather",
"▁in",
"▁a",
"▁city",
'.",',
'"',
"parameters",
'":',
'{"',
"type",
'":',
'"',
"object",
'",',
'"',
"properties",
'":',
'{"',
"city",
'":',
'{"',
"type",
'":',
'"',
"string",
'",',
'"',
"description",
'":',
'"',
"The",
"▁city",
"▁name",
'"',
"}},",
'"',
"required",
'":',
'▁["',
"city",
'"]',
"}}",
"}]",
"[/AVAILABLE_TOOLS]",
"[INST]",
"▁I",
"▁am",
"▁an",
"▁AI",
"<0x0A>",
"<0x0A>",
"Hello",
"▁world",
"▁!",
"[/INST]",
"[TOOL_CALLS]",
"▁[",
'{"',
"name",
'":',
'"',
"get",
"_",
"we",
"ather",
'",',
'"',
"arguments",
'":',
'{"',
"city",
'":',
'"',
"Par",
"is",
'"},',
'"',
"id",
'":',
'"',
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
'"',
"}]",
"</s>",
"[TOOL_RESULTS]",
'{"',
"content",
'":',
'{"',
"t",
"emperature",
'":',
"",
"2",
"0",
",",
'"',
"unit",
'":',
'"',
"c",
"els",
"ius",
'"},',
'"',
"call",
"_",
"id",
'":',
'"',
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
'"}',
"[/TOOL_RESULTS]",
],
[
"<s>",
"[SYSTEM_PROMPT]",
"I",
" am",
" an",
" AI",
"[/SYSTEM_PROMPT]",
"[AVAILABLE_TOOLS]",
"[",
'{"',
"type",
'":',
' "',
"function",
'",',
' "',
"function",
'":',
' {"',
"name",
'":',
' "',
"get",
"_",
"weather",
'",',
' "',
"description",
'":',
' "',
"G",
"ets",
" the",
" current",
" weather",
" in",
" a",
" city",
'.",',
' "',
"parameters",
'":',
' {"',
"type",
'":',
' "',
"object",
'",',
' "',
"properties",
'":',
' {"',
"city",
'":',
' {"',
"type",
'":',
' "',
"string",
'",',
' "',
"description",
'":',
' "',
"The",
" city",
" name",
'"',
"}},",
' "',
"required",
'":',
' ["',
"city",
'"]',
"}}",
"}]",
"[/AVAILABLE_TOOLS]",
"[INST]",
"Hello",
" world",
" !",
"[/INST]",
"[TOOL_CALLS]",
"get",
"_",
"weather",
"[ARGS]",
'{"',
"city",
'":',
' "',
"Paris",
'"}',
"</s>",
"[TOOL_RESULTS]",
'{"',
"temperature",
'":',
" ",
"2",
"0",
",",
' "',
"unit",
'":',
' "',
"c",
"elsius",
'"}',
"[/TOOL_RESULTS]",
],
)
expected_strings = (
'[{"type": "function", "function": {"name": "get_weather", "description": "Gets the current weather in a city.", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city name"}}, "required": ["city"]}}}] I am an AI\n\nHello world ![TOOL_CALLS][{"name": "get_weather", "arguments": {"city": "Paris"}, "id": "123456789"}] {"content": {"temperature": 20, "unit": "celsius"}, "call_id": "123456789"}', # noqa: E501
'I am an AI[{"type": "function", "function": {"name": "get_weather", "description": "Gets the current weather in a city.", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city name"}}, "required": ["city"]}}}]Hello world ![TOOL_CALLS]get_weather{"city": "Paris"}{"temperature": 20, "unit": "celsius"}', # noqa: E501
)
assert (
mistral_tokenizer.convert_tokens_to_string(
tokens[mistral_tokenizer.is_tekken]
)
== expected_strings[mistral_tokenizer.is_tekken]
)
assert mistral_tokenizer.convert_tokens_to_string([]) == ""
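
    # With skip_special_tokens=True, most control tokens (<s>, [INST],
    # [AVAILABLE_TOOLS], ...) are dropped from the returned tokens, although
    # [TOOL_CALLS] is still present in both expected lists; with False the full
    # token stream is returned.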
@pytest.mark.parametrize(
"skip_special_tokens,tuple_expected_tokens",
(
(
True,
(
[
"▁[",
'{"',
"type",
'":',
'"',
"function",
'",',
'"',
"function",
'":',
'{"',
"name",
'":',
'"',
"get",
"_",
"we",
"ather",
'",',
'"',
"description",
'":',
'"',
"Get",
"s",
"▁the",
"▁current",
"▁weather",
"▁in",
"▁a",
"▁city",
'.",',
'"',
"parameters",
'":',
'{"',
"type",
'":',
'"',
"object",
'",',
'"',
"properties",
'":',
'{"',
"city",
'":',
'{"',
"type",
'":',
'"',
"string",
'",',
'"',
"description",
'":',
'"',
"The",
"▁city",
"▁name",
'"',
"}},",
'"',
"required",
'":',
'▁["',
"city",
'"]',
"}}",
"}]",
"▁I",
"▁am",
"▁an",
"▁AI",
"<0x0A>",
"<0x0A>",
"Hello",
"▁world",
"▁!",
"[TOOL_CALLS]",
"▁[",
'{"',
"name",
'":',
'"',
"get",
"_",
"we",
"ather",
'",',
'"',
"arguments",
'":',
'{"',
"city",
'":',
'"',
"Par",
"is",
'"},',
'"',
"id",
'":',
'"',
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
'"',
"}]",
'{"',
"content",
'":',
'{"',
"t",
"emperature",
'":',
"",
"2",
"0",
",",
'"',
"unit",
'":',
'"',
"c",
"els",
"ius",
'"},',
'"',
"call",
"_",
"id",
'":',
'"',
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
'"}',
],
[
"I",
" am",
" an",
" AI",
"[",
'{"',
"type",
'":',
' "',
"function",
'",',
' "',
"function",
'":',
' {"',
"name",
'":',
' "',
"get",
"_",
"weather",
'",',
' "',
"description",
'":',
' "',
"G",
"ets",
" the",
" current",
" weather",
" in",
" a",
" city",
'.",',
' "',
"parameters",
'":',
' {"',
"type",
'":',
' "',
"object",
'",',
' "',
"properties",
'":',
' {"',
"city",
'":',
' {"',
"type",
'":',
' "',
"string",
'",',
' "',
"description",
'":',
' "',
"The",
" city",
" name",
'"',
"}},",
' "',
"required",
'":',
' ["',
"city",
'"]',
"}}",
"}]",
"Hello",
" world",
" !",
"[TOOL_CALLS]",
"get",
"_",
"weather",
'{"',
"city",
'":',
' "',
"Paris",
'"}',
'{"',
"temperature",
'":',
" ",
"2",
"0",
",",
' "',
"unit",
'":',
' "',
"c",
"elsius",
'"}',
],
),
),
(
False,
(
[
"<s>",
"[AVAILABLE_TOOLS]",
"▁[",
'{"',
"type",
'":',
'"',
"function",
'",',
'"',
"function",
'":',
'{"',
"name",
'":',
'"',
"get",
"_",
"we",
"ather",
'",',
'"',
"description",
'":',
'"',
"Get",
"s",
"▁the",
"▁current",
"▁weather",
"▁in",
"▁a",
"▁city",
'.",',
'"',
"parameters",
'":',
'{"',
"type",
'":',
'"',
"object",
'",',
'"',
"properties",
'":',
'{"',
"city",
'":',
'{"',
"type",
'":',
'"',
"string",
'",',
'"',
"description",
'":',
'"',
"The",
"▁city",
"▁name",
'"',
"}},",
'"',
"required",
'":',
'▁["',
"city",
'"]',
"}}",
"}]",
"[/AVAILABLE_TOOLS]",
"[INST]",
"▁I",
"▁am",
"▁an",
"▁AI",
"<0x0A>",
"<0x0A>",
"Hello",
"▁world",
"▁!",
"[/INST]",
"[TOOL_CALLS]",
"▁[",
'{"',
"name",
'":',
'"',
"get",
"_",
"we",
"ather",
'",',
'"',
"arguments",
'":',
'{"',
"city",
'":',
'"',
"Par",
"is",
'"},',
'"',
"id",
'":',
'"',
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
'"',
"}]",
"</s>",
"[TOOL_RESULTS]",
'{"',
"content",
'":',
'{"',
"t",
"emperature",
'":',
"",
"2",
"0",
",",
'"',
"unit",
'":',
'"',
"c",
"els",
"ius",
'"},',
'"',
"call",
"_",
"id",
'":',
'"',
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
'"}',
"[/TOOL_RESULTS]",
],
[
"<s>",
"[SYSTEM_PROMPT]",
"I",
" am",
" an",
" AI",
"[/SYSTEM_PROMPT]",
"[AVAILABLE_TOOLS]",
"[",
'{"',
"type",
'":',
' "',
"function",
'",',
' "',
"function",
'":',
' {"',
"name",
'":',
' "',
"get",
"_",
"weather",
'",',
' "',
"description",
'":',
' "',
"G",
"ets",
" the",
" current",
" weather",
" in",
" a",
" city",
'.",',
' "',
"parameters",
'":',
' {"',
"type",
'":',
' "',
"object",
'",',
' "',
"properties",
'":',
' {"',
"city",
'":',
' {"',
"type",
'":',
' "',
"string",
'",',
' "',
"description",
'":',
' "',
"The",
" city",
" name",
'"',
"}},",
' "',
"required",
'":',
' ["',
"city",
'"]',
"}}",
"}]",
"[/AVAILABLE_TOOLS]",
"[INST]",
"Hello",
" world",
" !",
"[/INST]",
"[TOOL_CALLS]",
"get",
"_",
"weather",
"[ARGS]",
'{"',
"city",
'":',
' "',
"Paris",
'"}',
"</s>",
"[TOOL_RESULTS]",
'{"',
"temperature",
'":',
" ",
"2",
"0",
",",
' "',
"unit",
'":',
' "',
"c",
"elsius",
'"}',
"[/TOOL_RESULTS]",
],
),
),
),
)
def test_convert_ids_to_tokens(
self,
mistral_tokenizer: MistralTokenizer,
skip_special_tokens: bool,
tuple_expected_tokens: tuple[list[str], list[str]],
):
tuple_ids = (
[
1,
6,
1501,
7567,
1891,
2032,
1113,
3396,
1316,
1113,
3396,
2032,
10598,
1629,
2032,
1113,
1295,
29498,
1537,
1991,
1316,
1113,
7286,
2032,
1113,
2226,
29481,
1040,
2636,
8854,
1065,
1032,
3758,
9959,
1113,
12206,
2032,
10598,
1891,
2032,
1113,
3582,
1316,
1113,
11491,
2032,
10598,
19141,
2032,
10598,
1891,
2032,
1113,
2195,
1316,
1113,
7286,
2032,
1113,
1782,
3758,
1909,
29507,
11549,
1113,
11661,
2032,
8135,
19141,
3010,
1743,
10925,
7,
3,
1083,
1605,
1164,
16875,
781,
781,
16998,
2294,
1686,
4,
5,
1501,
7567,
1629,
2032,
1113,
1295,
29498,
1537,
1991,
1316,
1113,
17452,
2032,
10598,
19141,
2032,
1113,
4684,
1046,
8474,
1113,
1081,
2032,
1113,
29508,
29518,
29538,
29549,
29550,
29552,
29555,
29551,
29542,
29507,
10925,
2,
8,
10598,
4557,
2032,
10598,
29475,
17329,
2032,
29473,
29518,
29502,
29493,
1113,
6074,
2032,
1113,
29485,
1958,
3938,
8474,
1113,
3613,
29498,
1081,
2032,
1113,
29508,
29518,
29538,
29549,
29550,
29552,
29555,
29551,
29542,
18163,
9,
],
[
1,
17,
1073,
1855,
1420,
26554,
18,
5,
1091,
19227,
4994,
2811,
1429,
5165,
1897,
1429,
5165,
2811,
16753,
2391,
2811,
1429,
1689,
1095,
45629,
1897,
1429,
14653,
2811,
1429,
1071,
3083,
1278,
3519,
17253,
1294,
1261,
5970,
39249,
1429,
26204,
2811,
16753,
4994,
2811,
1429,
6371,
1897,
1429,
48649,
2811,
16753,
29363,
2811,
16753,
4994,
2811,
1429,
3607,
1897,
1429,
14653,
2811,
1429,
1784,
5970,
2564,
1034,
47579,
1429,
15760,
2811,
12161,
29363,
4964,
2821,
27028,
6,
3,
22177,
4304,
2662,
4,
9,
1689,
1095,
45629,
32,
19227,
29363,
2811,
1429,
42572,
46005,
2,
7,
19227,
113824,
2811,
1032,
1050,
1048,
1044,
1429,
8979,
2811,
1429,
1099,
79092,
46005,
8,
],
)
ids = tuple_ids[mistral_tokenizer.is_tekken]
expected_tokens = tuple_expected_tokens[mistral_tokenizer.is_tekken]
actual_tokens = mistral_tokenizer.convert_ids_to_tokens(
ids, skip_special_tokens=skip_special_tokens
)
assert actual_tokens == expected_tokens
assert mistral_tokenizer.convert_ids_to_tokens([]) == []